diff --git a/CMakeLists.txt b/CMakeLists.txt index aca1c3c12ad7808a99ef03b096f32e5d3ce8d3de..ddad9ce707da45cb6becbdd5a68fc16a3117c30e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,7 @@ project (rays CXX) #------------------------------------------------------------------------------- option (USE_PCH "Enable the use of precompiled headers" ON) option (USE_STATIC "Limits the dyamics for testing." OFF) +option (SAVE_KERNEL_SOURCE "Writes the kernel source code to a file." OFF) #------------------------------------------------------------------------------- # Set the cmake module path. @@ -60,9 +61,12 @@ else () find_package (CUDAToolkit REQUIRED) + option (USE_CUDA_TEXTURES "Enable the use of cuda textures" OFF) + target_compile_definitions (cuda_lib INTERFACE USE_CUDA + $<$:USE_CUDA_TEXTURES> CUDA_INCLUDE="${CUDAToolkit_INCLUDE_DIRS}" ) target_link_libraries (cuda_lib @@ -73,12 +77,18 @@ else () endif () endif () +option (USE_INPUT_CACHE "Cache the values kernel input values." 
OFF) + add_library (gpu_lib INTERFACE) target_link_libraries (gpu_lib INTERFACE $<$:metal_lib> $<$:cuda_lib> ) +target_compile_definitions (gpu_lib + INTERFACE + $<$:USE_INPUT_CACHE> +) #------------------------------------------------------------------------------- # Sanitizer options @@ -234,6 +244,9 @@ add_dependencies (cuda-resource-headers pull_llvm) add_dependencies (scan-build-py pull_llvm) add_dependencies (x86-resource-headers pull_llvm) add_dependencies (obj.clangSupport pull_llvm) +add_dependencies (arm-common-resource-headers pull_llvm) +add_dependencies (arm-resource-headers pull_llvm) +add_dependencies (aarch64-resource-headers pull_llvm) add_library (llvm_dep INTERFACE) target_include_directories (llvm_dep @@ -259,6 +272,8 @@ target_link_libraries (llvm_dep clangCodeGen LLVM${LLVM_NATIVE_ARCH}CodeGen LLVMOrcJIT + LLVMOrcDebugging + LLVMOrcTargetProcess ) #------------------------------------------------------------------------------- diff --git a/graph_benchmark/xrays_bench.cpp b/graph_benchmark/xrays_bench.cpp index 9496977d688155a3c748ce87a90b131e14cd8769..0af790dc2543aeb2307b54286866d6d19c009f7e 100644 --- a/graph_benchmark/xrays_bench.cpp +++ b/graph_benchmark/xrays_bench.cpp @@ -38,10 +38,14 @@ void bench_runner() { const size_t batch = NUM_RAYS/threads.size(); const size_t extra = NUM_RAYS%threads.size(); - timeing::measure_diagnostic_threaded timing; + timeing::measure_diagnostic_threaded time_setup("Setup Time"); + timeing::measure_diagnostic_threaded time_init("Init Time"); + timeing::measure_diagnostic_threaded time_compile("Compile Time"); + timeing::measure_diagnostic_threaded time_steps("Time Steps"); for (size_t i = 0, ie = threads.size(); i < ie; i++) { - threads[i] = std::thread([&timing, batch, extra] (const size_t thread_number) -> void { + threads[i] = std::thread([&time_setup, &time_init, &time_compile, &time_steps, batch, extra] (const size_t thread_number) -> void { + time_setup.start_time(thread_number); const size_t 
local_num_rays = batch + (extra > thread_number ? 1 : 0); @@ -78,25 +82,33 @@ void bench_runner() { eq, "", local_num_rays, thread_number); + time_setup.end_time(thread_number); + time_init.start_time(thread_number); solve.init(kx); + time_init.end_time(thread_number); + time_compile.start_time(thread_number); solve.compile(); + time_compile.end_time(thread_number); - timing.start_time(thread_number); + time_steps.start_time(thread_number); for (size_t j = 0; j < num_steps; j++) { for (size_t k = 0; k < SUB_STEPS; k++) { solve.step(); } } solve.sync_host(); - timing.end_time(thread_number); + time_steps.end_time(thread_number); }, i); } for (std::thread &t : threads) { t.join(); } - timing.print(); + time_setup.print(); + time_init.print(); + time_compile.print(); + time_steps.print(); std::cout << "--------------------------------------------------------------------------------" << std::endl << std::endl; diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index 4817ca9acd2b5a8d07f749b11d9d5838c1dc70bd..855cfac47a7e3e27fb2992dd3d8d34686eb5ae6e 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -886,7 +886,7 @@ isa = PBXProject; attributes = { BuildIndependentTargetsInParallel = YES; - LastUpgradeCheck = 1530; + LastUpgradeCheck = 1540; ORGANIZATIONNAME = "Cianciosa, Mark R."; TargetAttributes = { C73690302A38C498001733B0 = { @@ -1282,7 +1282,6 @@ "VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"", "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", USE_METAL, - "CXX=\\\"c++\\ -I/Users/m4c/Projects/graph_framework/graph_framework\\ -std=gnu++2a\\\"", "$(inherited)", ); MACOSX_DEPLOYMENT_TARGET = 13.3; @@ -1338,8 +1337,7 @@ "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", "VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"", USE_METAL, - "CXX_FLAGS=\\\"-g\\\"", - 
"\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -std=gnu++2a\\\"\"", + "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"", STATIC, "DEBUG=1", "$(inherited)", @@ -1366,9 +1364,69 @@ OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", - "-rpath", - /usr/local/lib, - "-lLLVM", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + 
"-lLLVMCodeGenTypes", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", "-lclangEdit", @@ -1383,6 +1441,8 @@ "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", + "-rpath", + /usr/local/lib, ); SDKROOT = macosx; SYSTEM_HEADER_SEARCH_PATHS = ""; @@ -1441,7 +1501,7 @@ "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", "VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"", USE_METAL, - "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -std=gnu++2a\\\"\"", + "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"", "$(inherited)", ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; @@ -1466,9 +1526,69 @@ OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", - "-rpath", - /usr/local/lib, - "-lLLVM", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcDebugging", + 
"-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", "-lclangEdit", @@ -1483,6 +1603,8 @@ "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", + "-rpath", + /usr/local/lib, ); SDKROOT = macosx; SYSTEM_HEADER_SEARCH_PATHS = ""; @@ -1747,11 +1869,94 @@ GCC_PREPROCESSOR_DEFINITIONS = ( "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", USE_METAL, - "CXX=\\\"c++\\\"", "DEBUG=1", "$(inherited)", ); MACOSX_DEPLOYMENT_TARGET = 13.3; + OTHER_LDFLAGS = ( + "-lnetcdf", + "-ld_classic", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + 
"-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", + "-lclangFrontend", + "-lclangBasic", + "-lclangEdit", + "-lclangLex", + "-lclangDriver", + "-lclangSerialization", + "-lclangAST", + "-lclangSema", + "-lclangAnalysis", + "-lclangASTMatchers", + "-lclangSupport", + "-lclangParse", + "-lclangAPINotes", + "-lclangCodeGen", + "-rpath", + /usr/local/lib, + ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Debug; @@ -1766,10 +1971,93 @@ GCC_PREPROCESSOR_DEFINITIONS = ( "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", USE_METAL, - "CXX=\\\"c++\\\"", "$(inherited)", ); MACOSX_DEPLOYMENT_TARGET = 13.3; + OTHER_LDFLAGS = ( + "-lnetcdf", + "-ld_classic", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + 
"-lLLVMOrcTargetProcess", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", + "-lclangFrontend", + "-lclangBasic", + "-lclangEdit", + "-lclangLex", + "-lclangDriver", + "-lclangSerialization", + "-lclangAST", + "-lclangSema", + "-lclangAnalysis", + "-lclangASTMatchers", + "-lclangSupport", + "-lclangParse", + "-lclangAPINotes", + "-lclangCodeGen", + "-rpath", + /usr/local/lib, + ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Release; diff --git a/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme b/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme index 6ff6a04d549c6998a02e75846b6f2128d99676ac..42cdf873640301daae963a2257bf63901bc07bd2 100644 --- a/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme +++ b/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme @@ -1,6 +1,6 @@ :HEADER_DIR="$"> 
$<$:STATIC> + $<$:SAVE_KERNEL_SOURCE> ) target_include_directories (rays diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index a7d5322c3c6b6ae015c1a69185c1997a93e61824..ad30bd831ebf2b4430c6a4bf0b3186b04be2827e 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -11,6 +11,93 @@ #include "trigonometry.hpp" namespace graph { +//------------------------------------------------------------------------------ +/// @brief Check if nodes are constant combineable. +/// +/// @tparam T Base type of the nodes. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @params[in] a Opperand A +/// @params[in] b Opperand B +/// @returns True if a and b are combinable. +//------------------------------------------------------------------------------ + template + bool is_constant_combineable(shared_leaf a, + shared_leaf b) { + if (a->is_constant() && b->is_constant()) { + auto a1 = piecewise_1D_cast(a); + auto a2 = piecewise_2D_cast(a); + auto b2 = piecewise_2D_cast(b); + + return constant_cast(a).get() || + constant_cast(b).get() || + (a1.get() && a1->is_arg_match(b)) || + (a2.get() && a2->is_arg_match(b)) || + (a2.get() && (a2->is_row_match(b) || a2->is_col_match(b))) || + (b2.get() && (b2->is_row_match(a) || b2->is_col_match(a))); + } + return false; + } + +//------------------------------------------------------------------------------ +/// @brief Check if the constants are promotable. +/// +/// @tparam T Base type of the nodes. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @params[in] a Opperand A +/// @params[in] b Opperand B +/// @returns True if a is promoteable over b. 
+//------------------------------------------------------------------------------ + template + bool is_constant_promotable(shared_leaf a, + shared_leaf b) { + + auto b1 = piecewise_1D_cast(b); + auto b2 = piecewise_2D_cast(b); + + return a->is_constant() && + (!b->is_constant() || + (constant_cast(a).get() && (b1.get() || b2.get())) || + (piecewise_1D_cast(a).get() && b2.get())); + } + +//------------------------------------------------------------------------------ +/// @brief Check if the variable is combinable. +/// +/// @tparam T Base type of the nodes. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @params[in] a Opperand A +/// @params[in] b Opperand B +/// @returns True if a and b are combinable. +//------------------------------------------------------------------------------ + template + bool is_variable_combinable(shared_leaf a, + shared_leaf b) { + return a->get_power_base()->is_match(b->get_power_base()); + } + +//------------------------------------------------------------------------------ +/// @brief Check if the exponent is greater than the other. +/// +/// @tparam T Base type of the nodes. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @params[in] a Opperand A +/// @params[in] b Opperand B +/// @returns True if both exponents are constants and |exponent of a| > |exponent of b|. +//------------------------------------------------------------------------------ + template + bool is_greater_exponent(shared_leaf a, + shared_leaf b) { + auto ae = constant_cast(a->get_power_exponent()); + auto be = constant_cast(b->get_power_exponent()); + + return ae.get() && be.get() && + std::abs(ae->evaluate().at(0)) > std::abs(be->evaluate().at(0)); + } + //****************************************************************************** // Add node. //****************************************************************************** @@ -107,6 +194,37 @@ namespace graph { pr2->get_right()); } +// Combine 2D and 1D piecewise constants if a row or column matches. 
+ if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.add_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.add_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.add_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.add_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } + // Idenity reductions. if (this->left->is_match(this->right)) { return two ()*this->left; @@ -199,14 +317,40 @@ namespace graph { auto rfma = fma_cast(this->right); if (lfma.get()) { // fma(c,d,e) + a -> fma(c,d,e + a) - return fma(lfma->get_left(), lfma->get_middle(), + return fma(lfma->get_left(), + lfma->get_middle(), lfma->get_right() + this->right); } else if (rfma.get()) { // a + fma(c,d,e) -> fma(c,d,a + e) - return fma(rfma->get_left(), rfma->get_middle(), + return fma(rfma->get_left(), + rfma->get_middle(), this->left + rfma->get_right()); } +// fma(b,a,d) + fma(c,a,e) -> fma(a,b + c, d + e) +// fma(a,b,d) + fma(c,a,e) -> fma(a,b + c, d + e) +// fma(b,a,d) + fma(a,c,e) -> fma(a,b + c, d + e) +// fma(a,b,d) + fma(a,c,e) -> fma(a,b + c, d + e) + if (lfma.get() && rfma.get()) { + if (lfma->get_middle()->is_match(rfma->get_middle())) { + return fma(lfma->get_middle(), + lfma->get_left() + rfma->get_left(), + lfma->get_right() + rfma->get_right()); + } else if (lfma->get_left()->is_match(rfma->get_middle())) { + return fma(lfma->get_left(), + 
lfma->get_middle() + rfma->get_left(), + lfma->get_right() + rfma->get_right()); + } else if (lfma->get_middle()->is_match(rfma->get_left())) { + return fma(lfma->get_middle(), + lfma->get_left() + rfma->get_middle(), + lfma->get_right() + rfma->get_right()); + } else if (lfma->get_left()->is_match(rfma->get_left())) { + return fma(lfma->get_left(), + lfma->get_middle() + rfma->get_middle(), + lfma->get_right() + rfma->get_right()); + } + } + // Handle cases like: // (a/y)^e + b/y^e -> (a^2 + b)/(y^e) // b/y^e + (a/y)^e -> (b + a^2)/(y^e) @@ -277,22 +421,28 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map &registers) { + jit::register_map &registers, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = " << registers[l.get()] << " + " - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -558,6 +708,37 @@ namespace graph { pr2->get_right()); } +// Combine 2D and 1D piecewise constants if a row or column matches. 
+ if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.subtract_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.subtract_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.subtract_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.subtract_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } + // Common factor reduction. If the left and right are both muliply nodes check // for a common factor. So you can change a*b - a*c -> a*(b - c). auto lm = multiply_cast(this->left); @@ -568,16 +749,23 @@ namespace graph { // v1 - -c*v2 -> v1 + c*v2 if (rm.get()) { auto rmc = constant_cast(rm->get_left()); - if (rmc.get() && rmc->evaluate().is_none()) { + if (rmc.get() && rmc->is(-1)) { return this->left + rm->get_right(); } else if (rmc.get() && rmc->evaluate().is_negative()) { return this->left + (none ()*rm->get_left())*rm->get_right(); } } + if (lm.get()) { +// Assume constants are on the left. 
+// -a - b -> -(a + b) + auto lmc = constant_cast(lm->get_left()); + if (lmc.get() && lmc->is(-1)) { + return lm->get_left()*(lm->get_right() + this->right); + } + // a*v - v = (a - 1)*v // v*a - v = (a - 1)*v - if (lm.get()) { if (this->right->is_match(lm->get_right())) { return (lm->get_left() - one ())*this->right; } else if (this->right->is_match(lm->get_left())) { @@ -609,10 +797,11 @@ namespace graph { return lm->get_right()*(lm->get_left() - rm->get_left()); } -// Change cases like c1*a - c2*b -> c1*(a - c2*b) - auto lmc = constant_cast(lm->get_left()); - auto rmc = constant_cast(rm->get_left()); - if (lmc.get() && rmc.get()) { +// Change cases like c1*a - c2*b -> c1*(a - c2/c1*b) +// Note need to make sure c1 doesn't contain any zeros. + if (lm->get_left()->is_constant() && + rm->get_left()->is_constant() && + !lm->get_left()->has_constant_zero()) { return lm->get_left()*(lm->get_right() - (rm->get_left()/lm->get_left())*rm->get_right()); } @@ -667,16 +856,20 @@ namespace graph { auto rmrd = divide_cast(rm->get_right()); if (lmld.get() && rmld.get() && lmld->get_right()->is_match(rmld->get_right())) { - return (lmld->get_left()*lm->get_right() - rmld->get_left()*rm->get_right())/lmld->get_right(); + return (lmld->get_left()*lm->get_right() - + rmld->get_left()*rm->get_right())/lmld->get_right(); } else if (lmld.get() && rmrd.get() && lmld->get_right()->is_match(rmrd->get_right())) { - return (lmld->get_left()*lm->get_right() - rmrd->get_left()*rm->get_left())/lmld->get_right(); + return (lmld->get_left()*lm->get_right() - + rmrd->get_left()*rm->get_left())/lmld->get_right(); } else if (lmrd.get() && rmld.get() && lmrd->get_right()->is_match(rmld->get_right())) { - return (lmrd->get_left()*lm->get_left() - rmld->get_left()*rm->get_right())/lmrd->get_right(); + return (lmrd->get_left()*lm->get_left() - + rmld->get_left()*rm->get_right())/lmrd->get_right(); } else if (lmrd.get() && rmrd.get() && lmrd->get_right()->is_match(rmrd->get_right())) { - return 
(lmrd->get_left()*lm->get_left() - rmrd->get_left()*rm->get_left())/lmrd->get_right(); + return (lmrd->get_left()*lm->get_left() - + rmrd->get_left()*rm->get_left())/lmrd->get_right(); } } @@ -688,19 +881,23 @@ namespace graph { if (lrm->get_left()->is_match(rm->get_left())) { // (a - c*b) - c*d -> a - (b + d)*c return ls->get_left() - - (lrm->get_right() + rm->get_right())*rm->get_left(); + (lrm->get_right() + + rm->get_right())*rm->get_left(); } else if (lrm->get_left()->is_match(rm->get_right())) { // (a - c*b) - d*c -> a - (b + d)*c return ls->get_left() - - (lrm->get_right() + rm->get_left())*rm->get_right(); + (lrm->get_right() + + rm->get_left())*rm->get_right(); } else if (lrm->get_right()->is_match(rm->get_left())) { // (a - c*b) - c*d -> a - (b + d)*c return ls->get_left() - - (lrm->get_left() + rm->get_right())*rm->get_left(); + (lrm->get_left() + + rm->get_right())*rm->get_left(); } else if (lrm->get_right()->is_match(rm->get_right())) { // (a - c*b) - d*c -> a - (b + d)*c return ls->get_left() - - (lrm->get_left() + rm->get_left())*rm->get_right(); + (lrm->get_left() + + rm->get_left())*rm->get_right(); } } } @@ -780,6 +977,13 @@ namespace graph { } } +// fma(c,d,e) - a -> fma(c,d,e - a) + if (lfma.get() && !this->right->is_all_variables()) { + return fma(lfma->get_left(), + lfma->get_middle(), + lfma->get_right() - this->right); + } + // Reduce cases chained subtract multiply divide. if (ls.get()) { // (a - b*c) - d*e -> a - (b*c + d*e) @@ -820,22 +1024,28 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map &registers) { + jit::register_map &registers, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = " << registers[l.get()] << " - " - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -1092,20 +1302,44 @@ namespace graph { pr2->get_right()); } -// Move constants to the left. - if (r.get() && !l.get()) { - return this->right*this->left; +// Combine 2D and 1D piecewise constants if a row or column matches. 
+ if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.multiply_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.multiply_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.multiply_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.multiply_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); } -// Move piecewise constants to the left. - if ((pr1.get() || pr2.get()) && - (!pl1.get() && !pl2.get() && !l.get())) { +// Move constants to the left. + if (is_constant_promotable(this->right, this->left)) { return this->right*this->left; } // Move constant like to the left. - if (this->right->is_constant_like() && - !this->left->is_constant_like()) { + if (is_constant_promotable(this->right, this->left)) { return this->right*this->left; } @@ -1113,7 +1347,7 @@ namespace graph { // Disable if the left is a constant like to avoid an infinite loop. if (this->left->is_power_like() && !this->right->is_power_like() && - !this->left->is_constant_like()) { + !this->left->is_constant()) { return this->right*this->left; } @@ -1156,7 +1390,8 @@ namespace graph { // Promote constants before variables. 
// (c*v1)*v2 -> c*(v1*v2) - if (lm->get_left()->is_constant_like()) { + if (is_constant_promotable(lm->get_left(), + lm->get_right())) { return lm->get_left()*(lm->get_right()*this->right); } @@ -1179,8 +1414,12 @@ namespace graph { if (rm.get()) { // Assume constants are on the left. // c1*(c2*v) -> c3*v - if (constant_cast(rm->get_left()).get() && l.get()) { - return (this->left*rm->get_left())*rm->get_right(); + if (is_constant_combineable(this->left, + rm->get_left())) { + auto temp = this->left*rm->get_left(); + if (temp->is_normal()) { + return temp*rm->get_right(); + } } if (this->left->is_match(rm->get_left())) { @@ -1191,7 +1430,8 @@ namespace graph { } // v1*(c*v2) -> c*(v1*v2) - if (rm.get() && constant_cast(rm->get_left()).get()) { + if (rm.get() && + is_constant_promotable(rm->get_left(), this->left)) { return rm->get_left()*(this->left*rm->get_right()); } @@ -1209,29 +1449,37 @@ namespace graph { } else if (rm.get() && (sin_cast(rm->get_right()).get() || cos_cast(rm->get_right()).get()) && - !this->left->is_constant_like()) { + !this->left->is_constant()) { return (this->left*rm->get_left())*rm->get_right(); } // Factor out common constants c*b*c*d -> c*c*b*d. c*c will get reduced to c on // the second pass. 
if (lm.get() && rm.get()) { - if (constant_cast(lm->get_left()).get() && - constant_cast(rm->get_left()).get()) { - return (lm->get_left()*rm->get_left()) * - (lm->get_right()*rm->get_right()); - } else if (constant_cast(lm->get_left()).get() && - constant_cast(rm->get_right()).get()) { - return (lm->get_left()*rm->get_right()) * - (lm->get_right()*rm->get_left()); - } else if (constant_cast(lm->get_right()).get() && - constant_cast(rm->get_left()).get()) { - return (lm->get_right()*rm->get_left()) * - (lm->get_left()*rm->get_right()); - } else if (constant_cast(lm->get_right()).get() && - constant_cast(rm->get_right()).get()) { - return (lm->get_right()*rm->get_right()) * - (lm->get_left()*rm->get_left()); + if (is_constant_combineable(lm->get_left(), + rm->get_left())) { + auto temp = lm->get_left()*rm->get_left(); + if (temp->is_normal()) { + return temp*(lm->get_right()*rm->get_right()); + } + } else if (is_constant_combineable(lm->get_left(), + rm->get_right())) { + auto temp = lm->get_left()*rm->get_right(); + if (temp->is_normal()) { + return temp*(lm->get_right()*rm->get_left()); + } + } else if (is_constant_combineable(lm->get_right(), + rm->get_left())) { + auto temp = lm->get_right()*rm->get_left(); + if (temp->is_normal()) { + return temp*(lm->get_left()*rm->get_right()); + } + } else if (is_constant_combineable(lm->get_right(), + rm->get_right())) { + auto temp = lm->get_right()*rm->get_right(); + if (temp->is_normal()) { + return temp*(lm->get_left()*rm->get_left()); + } } // Gather common terms. This will help reduce sqrt(a)*sqrt(a). 
@@ -1256,26 +1504,22 @@ namespace graph { if (ld.get()) { // (c/v1)*v2 -> c*(v2/v1) - if (constant_cast(ld->get_left()).get() || - piecewise_1D_cast(ld->get_left()).get() || - piecewise_2D_cast(ld->get_left()).get()) { + if (ld->get_left()->is_constant()) { return ld->get_left()*(this->right/ld->get_right()); } } // c1*(c2/v) -> c3/v -// c1*(v/c2) -> v/c3 - if (rd.get() && l.get()) { - if (constant_cast(rd->get_left()).get()) { - return (this->left*rd->get_left())/rd->get_right(); - } else if (constant_cast(rd->get_right()).get()) { - return rd->get_left()/(this->left*rd->get_right()); - } + if (rd.get() && this->left->is_constant() && + rd->get_left()->is_constant()) { + return (this->left*rd->get_left())/rd->get_right(); } +// (a/b)*(c/a) -> c/b +// (b/a)*(a/c) -> c/b if (ld.get() && rd.get()) { if (ld->get_left()->is_match(rd->get_right())) { - return ld->get_right()/rd->get_left(); + return rd->get_left()/ld->get_right(); } else if (ld->get_right()->is_match(rd->get_left())) { return ld->get_left()/rd->get_right(); } @@ -1543,14 +1787,20 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -1581,8 +1831,8 @@ namespace graph { stream << " : "; } stream << registers[l.get()] << "*" - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -1835,18 +2085,44 @@ namespace graph { pr2->get_right()); } - if (this->left->is_match(this->right)) { - if (l.get() && l->is(1)) { - return this->left; - } +// Combine 2D and 1D piecewise constants if a row or column matches. 
+ if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.divide_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.divide_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.divide_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.divide_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } + if (this->left->is_match(this->right)) { return one (); } // Reduce cases of a/c1 -> c2*a - if (r.get()) { - return (one ()/this->right) * - this->left; + if (this->right->is_constant()) { + return (one ()/this->right)*this->left; } // fma(a,d,c*d)/d -> a + c @@ -1877,35 +2153,64 @@ namespace graph { auto lm = multiply_cast(this->left); auto rm = multiply_cast(this->right); -// Assume constants are always on the left. 
// c1/(c2*v) -> c3/v -// (c1*v)/c2 -> c3*v - if (rm.get() && l.get()) { - if (constant_cast(rm->get_left()).get()) { - return (this->left/rm->get_left())/rm->get_right(); +// c1/(c2*c3) -> c4/c3 + if (rm.get()) { + if (is_constant_combineable(rm->get_left(), + this->left)) { + auto temp = this->left/rm->get_left(); + if (temp->is_normal()) { + return temp/rm->get_right(); + } } - } else if (lm.get() && r.get()) { - if (constant_cast(lm->get_left()).get()) { - return (lm->get_left()/this->right)*lm->get_right(); + if (is_constant_combineable(rm->get_right(), + this->left)) { + auto temp = this->left/rm->get_right(); + if (temp->is_normal()) { + return temp/rm->get_left(); + } } } if (lm.get() && rm.get()) { // Test for constants that can be reduced out. - if (constant_cast(lm->get_left()).get() && - constant_cast(rm->get_left()).get()) { - return (lm->get_left()/rm->get_left())*(lm->get_right()/rm->get_right()); - } else if (constant_cast(lm->get_left()).get() && - constant_cast(rm->get_right()).get()) { - return (lm->get_left()/rm->get_right())*(lm->get_right()/rm->get_left()); - } else if (constant_cast(lm->get_right()).get() && - constant_cast(rm->get_left()).get()) { - return (lm->get_right()/rm->get_left())*(lm->get_left()/rm->get_right()); - } else if (constant_cast(lm->get_right()).get() && - constant_cast(rm->get_right()).get()) { - return (lm->get_right()/rm->get_right())*(lm->get_left()/rm->get_left()); - } - +// (c1*a)/(c2*b) -> c3*a/b +// (a*c1)/(c2*b) -> c3*a/b +// (c1*a)/(b*c2) -> c3*a/b +// (a*c1)/(b*c2) -> c3*a/b + if (is_constant_combineable(lm->get_left(), + rm->get_left())) { + auto temp = lm->get_left()/rm->get_left(); + if (temp->is_normal()) { + return temp*lm->get_right()/rm->get_right(); + } + } + if (is_constant_combineable(lm->get_left(), + rm->get_right())) { + auto temp = lm->get_left()/rm->get_right(); + if (temp->is_normal()) { + return temp*lm->get_right()/rm->get_left(); + } + } + if (is_constant_combineable(lm->get_right(), + 
rm->get_left())) { + auto temp = lm->get_right()/rm->get_left(); + if (temp->is_normal()) { + return temp*lm->get_left()/rm->get_right(); + } + } + if (is_constant_combineable(lm->get_right(), + rm->get_right())) { + auto temp = lm->get_right()/rm->get_right(); + if (temp->is_normal()) { + return temp*lm->get_left()/rm->get_left(); + } + } + +// (a*b)/(a*c) -> b/c +// (b*a)/(a*c) -> b/c +// (a*b)/(c*a) -> b/c +// (b*a)/(c*a) -> b/c if (lm->get_left()->is_match(rm->get_left())) { return lm->get_right()/rm->get_right(); } else if (lm->get_left()->is_match(rm->get_right())) { @@ -1950,11 +2255,8 @@ namespace graph { } // (c*v1)/v2 -> c*(v1/v2) - if (lm.get() && constant_cast(lm->get_left()).get()) { - return lm->get_left()*(lm->get_right()/this->right); - } - - if (lm.get() && lm->get_left()->is_constant_like()) { + if (lm.get() && lm->get_left()->is_constant() && + !lm->get_right()->is_constant()) { return lm->get_left()*(lm->get_right()/this->right); } @@ -2093,14 +2395,20 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -2124,8 +2432,8 @@ namespace graph { stream << " : "; } stream << registers[l.get()] << "/" - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); } @@ -2344,22 +2652,35 @@ namespace graph { return constant (this->evaluate()); } else if (l.get() && m.get()) { return this->left*this->middle + this->right; - } else if (l.get() && l->evaluate().is_none()) { + } else if (l.get() && l->is(-1)) { return this->right - this->middle; - } else if (m.get() && m->evaluate().is_none()) { + } else if (m.get() && m->is(-1)) { return this->right - this->left; + } else if (l.get() && l->is(1)) { + return this->middle + this->right; + } else if (m.get() && m->is(1)) { + return this->left + this->right; } - auto pl1 = piecewise_1D_cast(this->left); - auto pm1 = piecewise_1D_cast(this->middle); - auto pl2 = piecewise_2D_cast(this->left); - auto pm2 = piecewise_2D_cast(this->middle); +// Check if the left and middle are combinable. This will be constant merged in +// multiply reduction. 
+ if (is_constant_combineable(this->left, this->middle) || + is_variable_combinable(this->left, this->middle)) { + return (this->left*this->middle) + this->right; + } + +// fma(c2,c1,a) -> fma(c1,c2,a) + if (is_constant_promotable(this->middle, + this->left)) { + return fma(this->middle, this->left, this->right); + } - if ((pl1.get() && (m.get() || pl1->is_arg_match(this->middle))) || - (pm1.get() && (l.get() || pm1->is_arg_match(this->left))) || - (pl2.get() && (m.get() || pl2->is_arg_match(this->middle))) || - (pm2.get() && (l.get() || pm2->is_arg_match(this->left)))) { - return (this->left*this->middle) + this->right; +// fma(a,b,a) -> a*(1 + b) +// fma(b,a,a) -> a*(1 + b) + if (this->left->is_match(this->right)) { + return this->left*(one () + this->middle); + } else if (this->middle->is_match(this->right)) { + return this->middle*(one () + this->left); } // Common factor reduction. If the left and right are both multiply nodes check @@ -2378,12 +2699,50 @@ namespace graph { return this->middle*(this->left + rm->get_left()); } -// Change cases like c1*a + c2*b -> c1*(c3*b + a) - auto rmc = constant_cast(rm->get_left()); - if (rmc.get() && l.get()) { - return this->left*fma(rm->get_left()/this->left, - rm->get_right(), - this->middle); +// Change cases like +// fma(c1,a,c2*b) -> c1*fma(c3,b,a) +// fma(a,c1,c2*b) -> c1*fma(c3,b,a) +// fma(c1,a,b*c2) -> c1*fma(c3,b,a) +// fma(a,c1,b*c2) -> c1*fma(c3,b,a) + if (is_constant_combineable(this->left, + rm->get_left()) && + !this->left->has_constant_zero()) { + auto temp = rm->get_left()/this->left; + if (temp->is_normal()) { + return this->left*fma(temp, + rm->get_right(), + this->middle); + } + } + if (is_constant_combineable(this->middle, + rm->get_left()) && + !this->middle->has_constant_zero()) { + auto temp = rm->get_left()/this->middle; + if (temp->is_normal()) { + return this->middle*fma(temp, + rm->get_right(), + this->left); + } + } + if (is_constant_combineable(this->left, + rm->get_right()) && + 
!this->left->has_constant_zero()) { + auto temp = rm->get_right()/this->left; + if (temp->is_normal()) { + return this->left*fma(temp, + rm->get_left(), + this->middle); + } + } + if (is_constant_combineable(this->middle, + rm->get_right()) && + !this->middle->has_constant_zero()) { + auto temp = rm->get_right()/this->middle; + if (temp->is_normal()) { + return this->middle*fma(temp, + rm->get_left(), + this->left); + } } // Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c) @@ -2398,48 +2757,112 @@ namespace graph { // Handle cases like. // fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d) +// fma(a*c1,b,c2*d) -> c1*(a*b + c2/c1*d) +// fma(c1*a,b,d*c2*d) -> c1*(a*b + c2/c1*d) +// fma(a*c1,b,d*c2*d) -> c1*(a*b + c2/c1*d) if (lm.get() && rm.get()) { - auto rmc = constant_cast(rm->get_left()); - if (rmc.get()) { - return lm->get_left()*fma(lm->get_right(), - this->middle, - (rm->get_left()/lm->get_left())*rm->get_right()); + if (is_constant_combineable(rm->get_left(), + lm->get_left()) && + !lm->get_left()->has_constant_zero()) { + auto temp = rm->get_left()/lm->get_left(); + if (temp->is_normal()){ + return lm->get_left()*fma(lm->get_right(), + this->middle, + temp*rm->get_right()); + } + } + if (is_constant_combineable(rm->get_left(), + lm->get_right()) && + !lm->get_right()->has_constant_zero()) { + auto temp = rm->get_left()/lm->get_right(); + if (temp->is_normal()){ + return lm->get_right()*fma(lm->get_left(), + this->middle, + temp*rm->get_right()); + } + } + if (is_constant_combineable(rm->get_right(), + lm->get_left()) && + !lm->get_left()->has_constant_zero()) { + auto temp = rm->get_right()/lm->get_left(); + if (temp->is_normal()) { + return lm->get_left()*fma(lm->get_right(), + this->middle, + temp*rm->get_left()); + } + } + if (is_constant_combineable(rm->get_right(), + lm->get_right()) && + !lm->get_right()->has_constant_zero()) { + auto temp = rm->get_right()/lm->get_right(); + if (temp->is_normal()) { + return lm->get_right()*fma(lm->get_left(), + this->middle, + 
temp*rm->get_left()); + } } } // Move constant multiplies to the left. if (lm.get()) { - auto lmc = constant_cast(lm->get_left()); - if (lmc.get()) { +// fma(c1*a,b,c) -> fma(c1,a*b,c) + if (is_constant_promotable(lm->get_left(), + lm->get_right())) { return fma(lm->get_left(), lm->get_right()*this->middle, this->right); } } else if (mm.get()) { - auto mmc = constant_cast(mm->get_left()); - auto mmpw1c = piecewise_1D_cast(mm->get_left()); - auto mmpw2c = piecewise_2D_cast(mm->get_left()); - if (mmc.get() || mmpw1c.get() || mmpw2c.get()) { - if (l.get() || pl1.get() || pl2.get()) { - return fma(this->left*mm->get_left(), +// fma(c1,c2*a,b) -> fma(c3,a,b) +// fma(c1,a*c2,b) -> fma(c3,a,b) +// fma(a,c1*b,c) -> fma(c1,a*b,c) + if (is_constant_combineable(this->left, + mm->get_left())) { + auto temp = this->left*mm->get_left(); + if (temp->is_normal()) { + return fma(temp, mm->get_right(), this->right); - } else { - return fma(mm->get_left(), - this->left*mm->get_right(), + } + } + if (is_constant_combineable(this->left, + mm->get_right())) { + auto temp = this->left*mm->get_right(); + if (temp->is_normal()) { + return fma(temp, + mm->get_left(), this->right); } } + if (is_constant_promotable(mm->get_left(), + this->left)) { + return fma(mm->get_left(), + this->left*mm->get_right(), + this->right); + } } -// fma(c1,a,c2/b) -> c1*(a + c1/(c2*b)) -// fma(c1,a,b/c2) -> c1*(a + b/(c1*c2)) +// fma(c1,a,c2/b) -> c1*(a + c3/b) +// fma(a,c1,c2/b) -> c1*(a + c3/b) auto rd = divide_cast(this->right); - if (l.get() && rd.get()) { - if (constant_cast(rd->get_left()).get() || - constant_cast(rd->get_right()).get()) { - return this->left*(this->middle + - rd->get_left()/(this->left*rd->get_right())); + if (rd.get()) { + if (is_constant_combineable(this->left, + rd->get_left()) && + !this->left->has_constant_zero()) { + auto temp = rd->get_left()/this->left; + if (temp->is_normal()) { + return this->left*(this->middle + + temp/rd->get_right()); + } + } + if 
(is_constant_combineable(this->middle, + rd->get_left()) && + !this->middle->has_constant_zero()) { + auto temp = rd->get_left()/this->middle; + if (temp->is_normal()) { + return this->middle*(this->left + + temp/rd->get_right()); + } } } @@ -2476,6 +2899,342 @@ namespace graph { // Chained fma reductions. auto rfma = fma_cast(this->right); if (rfma.get()) { +// fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) +// fma(b, a, fma(c, b, d)) -> fma(b, a + c, d) +// fma(a, b, fma(b, c, d)) -> fma(b, a + c, d) +// fma(b, a, fma(b, c, d)) -> fma(b, a + c, d) + if (this->middle->is_match(rfma->get_middle())) { + return fma(this->middle, + this->left + rfma->get_left(), + rfma->get_right()); + } else if (this->left->is_match(rfma->get_middle())) { + return fma(this->left, + this->middle + rfma->get_left(), + rfma->get_right()); + } else if (this->middle->is_match(rfma->get_left())) { + return fma(this->middle, + this->left + rfma->get_middle(), + rfma->get_right()); + } else if (this->left->is_match(rfma->get_left())) { + return fma(this->left, + this->middle + rfma->get_middle(), + rfma->get_right()); + } + + if (mm.get()) { +// fma(a, e*b, fma(c, b, d)) -> fma(b, fma(a, e, c), d) +// fma(a, b*e, fma(c, b, d)) -> fma(b, fma(a, e, c), d) +// fma(a, e*b, fma(b, c, d)) -> fma(b, fma(a, e, c), d) +// fma(a, b*e, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + if (mm->get_right()->is_match(rfma->get_middle())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_left()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfma->get_middle())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_left()), + rfma->get_right()); + } else if (mm->get_right()->is_match(rfma->get_left())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_middle()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfma->get_left())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + 
rfma->get_middle()), + rfma->get_right()); + } + } else if (lm.get()) { +// fma(e*b, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d) +// fma(b*e, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d) +// fma(e*b, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d) +// fma(e*d, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + if (lm->get_right()->is_match(rfma->get_middle())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_left()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfma->get_middle())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_left()), + rfma->get_right()); + } else if (lm->get_right()->is_match(rfma->get_left())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_middle()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfma->get_left())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_middle()), + rfma->get_right()); + } + } + + auto rfmamm = multiply_cast(rfma->get_middle()); + auto rfmalm = multiply_cast(rfma->get_left()); + if (rfmamm.get()) { +// fma(a, b, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d) +// fma(b, a, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d) +// fma(a, b, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d) +// fma(b, a, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d) + if (rfmamm->get_right()->is_match(this->middle)) { + return fma(this->middle, + fma(rfma->get_left(), + rfmamm->get_left(), + this->left), + rfma->get_right()); + } else if (rfmamm->get_right()->is_match(this->left)) { + return fma(this->left, + fma(rfma->get_left(), + rfmamm->get_left(), + this->middle), + rfma->get_right()); + } else if (rfmamm->get_left()->is_match(this->middle)) { + return fma(this->middle, + fma(rfma->get_left(), + rfmamm->get_right(), + this->left), + rfma->get_right()); + } else if (rfmamm->get_left()->is_match(this->left)) { + return fma(this->left, + fma(rfma->get_left(), + rfmamm->get_right(), + this->middle), + 
rfma->get_right()); + } + } else if (rfmalm.get()) { +// fma(a, b, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d) +// fma(b, a, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d) +// fma(a, b, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d) +// fma(b, a, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d) + if (rfmalm->get_right()->is_match(this->middle)) { + return fma(this->middle, + fma(rfma->get_middle(), + rfmalm->get_left(), + this->left), + rfma->get_right()); + } else if (rfmalm->get_right()->is_match(this->left)) { + return fma(this->left, + fma(rfma->get_middle(), + rfmalm->get_left(), + this->middle), + rfma->get_right()); + } else if (rfmalm->get_left()->is_match(this->middle)) { + return fma(this->middle, + fma(rfma->get_middle(), + rfmalm->get_right(), + this->left), + rfma->get_right()); + } else if (rfmalm->get_left()->is_match(this->left)) { + return fma(this->left, + fma(rfma->get_middle(), + rfmalm->get_right(), + this->middle), + rfma->get_right()); + } + } + + if (mm.get() && rfmamm.get()) { +// fma(a, f*b, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, b*f, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, f*b, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, b*f, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + if (mm->get_right()->is_match(rfmamm->get_right())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_left()*rfmamm->get_left()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfmamm->get_right())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_left()*rfmamm->get_left()), + rfma->get_right()); + } else if (mm->get_right()->is_match(rfmamm->get_left())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_left()*rfmamm->get_right()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfmamm->get_left())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_left()*rfmamm->get_right()), + 
rfma->get_right()); + } + } else if (lm.get() && rfmamm.get()) { +// fma(f*b, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) +// fma(b*f, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) +// fma(f*b, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) +// fma(b*f, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + if (lm->get_right()->is_match(rfmamm->get_right())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_left()*rfmamm->get_left()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfmamm->get_right())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_left()*rfmamm->get_left()), + rfma->get_right()); + } else if (lm->get_right()->is_match(rfmamm->get_left())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_left()*rfmamm->get_right()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfmamm->get_left())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_left()*rfmamm->get_right()), + rfma->get_right()); + } + } else if (mm.get() && rfmalm.get()) { +// fma(a, f*b, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, b*f, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, f*b, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, b*f, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + if (mm->get_right()->is_match(rfmalm->get_right())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_middle()*rfmalm->get_left()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfmalm->get_right())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_middle()*rfmalm->get_left()), + rfma->get_right()); + } else if (mm->get_right()->is_match(rfmalm->get_left())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_middle()*rfmalm->get_right()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfmalm->get_left())) { + 
return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_middle()*rfmalm->get_right()), + rfma->get_right()); + } + } else if (lm.get() && rfmalm.get()) { +// fma(f*b, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(b*f, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(f*b, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(b*f, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + if (lm->get_right()->is_match(rfmalm->get_right())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_middle()*rfmalm->get_left()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfmalm->get_right())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_middle()*rfmalm->get_left()), + rfma->get_right()); + } else if (lm->get_right()->is_match(rfmalm->get_left())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_middle()*rfmalm->get_right()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfmalm->get_left())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_middle()*rfmalm->get_right()), + rfma->get_right()); + } + } + + if (is_variable_combinable(this->middle, rfma->get_middle())) { + if (is_greater_exponent(this->middle, rfma->get_middle())) { +// fma(a,x^b,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + return fma(rfma->get_middle(), + fma(this->middle/rfma->get_middle(), + this->left, + rfma->get_left()), + rfma->get_right()); + } else { +// fma(a,x^b,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + return fma(this->middle, + fma(rfma->get_middle()/this->middle, + rfma->get_left(), + this->left), + rfma->get_right()); + } + } else if (is_variable_combinable(this->left, rfma->get_middle())) { + if (is_greater_exponent(this->left, rfma->get_middle())) { +// fma(x^b,a,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + return fma(rfma->get_middle(), + 
fma(this->left/rfma->get_middle(), + this->middle, + rfma->get_left()), + rfma->get_right()); + } else { +// fma(x^b,a,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + return fma(this->left, + fma(rfma->get_middle()/this->left, + rfma->get_left(), + this->middle), + rfma->get_right()); + } + } else if (is_variable_combinable(this->middle, rfma->get_left())) { + if (is_greater_exponent(this->middle, rfma->get_left())) { +// fma(a,x^b,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + return fma(rfma->get_left(), + fma(this->middle/rfma->get_left(), + this->left, + rfma->get_middle()), + rfma->get_right()); + } else { +// fma(a,x^b,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + return fma(this->middle, + fma(rfma->get_left()/this->middle, + rfma->get_middle(), + this->left), + rfma->get_right()); + } + } else if (is_variable_combinable(this->left, rfma->get_left())) { + if (is_greater_exponent(this->left, rfma->get_left())) { +// fma(x^b,a,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + return fma(rfma->get_left(), + fma(this->left/rfma->get_left(), + this->middle, + rfma->get_middle()), + rfma->get_right()); + } else { +// fma(x^b,a,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + return fma(this->left, + fma(rfma->get_left()/this->left, + rfma->get_middle(), + this->middle), + rfma->get_right()); + } + } + // fma(a,b,fma(a,b,c)) -> fma(2*a,b,c) // fma(a,b,fma(b,a,c)) -> fma(2*a,b,c) if (this->left->is_match(rfma->get_left()) && @@ -2518,7 +3277,7 @@ namespace graph { } // Check to see if it is worth moving nodes out of a fma nodes. These should be -// restricted to variable like nodes. Only do this reduction is the complexity +// restricted to variable like nodes. Only do this reduction if the complexity // reduces. if (this->left->is_all_variables()) { auto rdl = this->right/this->left; @@ -2535,12 +3294,15 @@ namespace graph { } // Promote constants out to the left. 
- if (l.get() && r.get()) { - return this->left*(this->middle + this->right/this->left); + if (is_constant_combineable(this->left, this->right) && + !this->left->has_constant_zero()) { + auto temp = this->right/this->left; + if (temp->is_normal()) { + return this->left*(this->middle + temp); + } } - -// Change negative eponents to divide so that can be factored out. +// Change negative exponents to divide so that can be factored out. // fma(a,b^-c,d) = a/b^c + d // fma(b^-c,a,d) = a/b^c + d auto lp = pow_cast(this->left); @@ -2920,15 +3682,23 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf m = this->middle->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf m = this->middle->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -2954,15 +3724,14 @@ namespace graph { if constexpr (jit::is_complex ()) { stream << registers[l.get()] << "*" << registers[m.get()] << " + " - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << ";"; } else { stream << "fma(" << registers[l.get()] << ", " << registers[m.get()] << ", " - << registers[r.get()] << ");" - << std::endl; + << registers[r.get()] << ");"; } + stream << " // used " << usage.at(this) << std::endl; } return this->shared_from_this(); diff --git 
a/graph_framework/backend.hpp b/graph_framework/backend.hpp index e90dcda9c441fa2aa96cbce0ca4251cadab4daa9..254f73a8659573743327ce41ba304b8d0185b825 100644 --- a/graph_framework/backend.hpp +++ b/graph_framework/backend.hpp @@ -10,6 +10,7 @@ #include #include +#include #include "special_functions.hpp" #include "register.hpp" @@ -153,7 +154,7 @@ namespace backend { /// @returns Returns true if every element is zero. //------------------------------------------------------------------------------ bool is_zero() const { - for (T d : memory) { + for (const T &d : memory) { if (d != static_cast (0.0)) { return false; } @@ -161,14 +162,29 @@ namespace backend { return true; } + +//------------------------------------------------------------------------------ +/// @brief Is every element zero. +/// +/// @returns Returns true if every element is zero. +//------------------------------------------------------------------------------ + bool has_zero() const { + for (const T &d : memory) { + if (d == static_cast (0.0)) { + return true; + } + } + return false; + } + //------------------------------------------------------------------------------ /// @brief Is every element negative. /// /// @returns Returns true if every element is negative. //------------------------------------------------------------------------------ bool is_negative() const { - for (T d : memory) { + for (const T &d : memory) { if (std::real(d) > std::real(static_cast (0.0))) { return false; } @@ -183,7 +199,7 @@ namespace backend { /// @returns Returns true if every element is negative one. //------------------------------------------------------------------------------ bool is_none() const { - for (T d : memory) { + for (const T &d : memory) { if (d != static_cast (-1.0)) { return false; } @@ -256,6 +272,475 @@ namespace backend { return memory.data(); } +//------------------------------------------------------------------------------ +/// @brief Check for normal values. 
+/// +/// @returns False if any NaN or Inf is found. +//------------------------------------------------------------------------------ + bool is_normal() const { + for (const T &x : memory) { + if constexpr (jit::is_complex ()) { + if (std::isnan(std::real(x)) || std::isinf(std::real(x)) || + std::isnan(std::imag(x)) || std::isinf(std::imag(x))) { + return false; + } + } else { + if (std::isnan(x) || std::isinf(x)) { + return false; + } + } + } + return true; + } + +//------------------------------------------------------------------------------ +/// @brief Add row operation. +/// +/// Adds m_ij + v_i or v_i + m_ij. This will resize the buffer if it needs to +/// be. +/// +/// @params[in] x The right operand. +//------------------------------------------------------------------------------ + void add_row(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_rows + j] += x[i]; + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = memory[i] + x[i*num_colmns + j]; + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Add col operation. +/// +/// Adds m_ij + v_j or v_j + m_ij. This will resize the buffer if it needs to +/// be. +/// +/// @params[in] x The other operand. 
+//------------------------------------------------------------------------------
+        void add_col(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix. x holds one entry per
+                // column so its length is the number of columns.
+                const size_t num_columns = x.size();
+                const size_t num_rows = size()/x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] += x[j];
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector with one entry per column. The
+                // result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = size();
+                const size_t num_rows = x.size()/size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = memory[j] + x[i*num_columns + j];
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Subtract row operation.
+///
+/// Subtracts m_ij - v_i or v_i - m_ij. This will resize the buffer if it
+/// needs to be.
+///
+/// @params[in] x The right operand.
+//------------------------------------------------------------------------------
+        void subtract_row(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix and x holds one entry
+                // per row.
+                const size_t num_columns = size()/x.size();
+                const size_t num_rows = x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] -= x[i];
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector. The result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = x.size()/size();
+                const size_t num_rows = size();
+                // The outer index walks the rows so memory[i] stays inside
+                // the vector for non square matrices.
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = memory[i] - x[i*num_columns + j];
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Subtract col operation.
+///
+/// Subtracts m_ij - v_j or v_j - m_ij. This will resize the buffer if it
+/// needs to be.
+///
+/// @params[in] x The other operand.
+//------------------------------------------------------------------------------
+        void subtract_col(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix. x holds one entry per
+                // column so its length is the number of columns.
+                const size_t num_columns = x.size();
+                const size_t num_rows = size()/x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] -= x[j];
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector with one entry per column. The
+                // result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = size();
+                const size_t num_rows = x.size()/size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = memory[j] - x[i*num_columns + j];
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Multiply row operation.
+///
+/// Multiplies m_ij * v_i or v_i * m_ij. This will resize the buffer if it
+/// needs to be.
+///
+/// @params[in] x The right operand.
+//------------------------------------------------------------------------------
+        void multiply_row(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix and x holds one entry
+                // per row.
+                const size_t num_columns = size()/x.size();
+                const size_t num_rows = x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] *= x[i];
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector. The result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = x.size()/size();
+                const size_t num_rows = size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = memory[i]*x[i*num_columns + j];
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Multiply col operation.
+///
+/// Multiplies m_ij * v_j or v_j * m_ij. This will resize the buffer if it
+/// needs to be.
+///
+/// @params[in] x The other operand.
+//------------------------------------------------------------------------------
+        void multiply_col(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix. x holds one entry per
+                // column so its length is the number of columns.
+                const size_t num_columns = x.size();
+                const size_t num_rows = size()/x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] *= x[j];
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector with one entry per column. The
+                // result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = size();
+                const size_t num_rows = x.size()/size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = memory[j]*x[i*num_columns + j];
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Divide row operation.
+///
+/// Divides m_ij / v_i or v_i / m_ij. This will resize the buffer if it needs
+/// to be.
+///
+/// @params[in] x The right operand.
+//------------------------------------------------------------------------------
+        void divide_row(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix and x holds one entry
+                // per row.
+                const size_t num_columns = size()/x.size();
+                const size_t num_rows = x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] /= x[i];
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector and stays the numerator. The
+                // result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = x.size()/size();
+                const size_t num_rows = size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = memory[i]/x[i*num_columns + j];
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Divide col operation.
+///
+/// Divides m_ij / v_j or v_j / m_ij. This will resize the buffer if it needs
+/// to be.
+///
+/// @params[in] x The other operand.
+//------------------------------------------------------------------------------
+        void divide_col(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix. x holds one entry per
+                // column so its length is the number of columns.
+                const size_t num_columns = x.size();
+                const size_t num_rows = size()/x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] /= x[j];
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector with one entry per column and
+                // stays the numerator. The result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = size();
+                const size_t num_rows = x.size()/size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = memory[j]/x[i*num_columns + j];
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Atan row operation.
+///
+/// Computes atan(m_ij, v_i) or atan(v_i, m_ij). This will resize the buffer if
+/// it needs to be.
+///
+/// @params[in] x The right operand.
+//------------------------------------------------------------------------------
+        void atan_row(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix and x holds one entry
+                // per row.
+                const size_t num_columns = size()/x.size();
+                const size_t num_rows = x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        // x is always the first atan2 argument and this
+                        // buffer the second. Complex types use atan of the
+                        // ratio instead of atan2.
+                        if constexpr (jit::is_complex<T> ()) {
+                            memory[i*num_columns + j] = std::atan(x[i]/memory[i*num_columns + j]);
+                        } else {
+                            memory[i*num_columns + j] = std::atan2(x[i], memory[i*num_columns + j]);
+                        }
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector. The result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = x.size()/size();
+                const size_t num_rows = size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        if constexpr (jit::is_complex<T> ()) {
+                            m[i*num_columns + j] = std::atan(x[i*num_columns + j]/memory[i]);
+                        } else {
+                            m[i*num_columns + j] = std::atan2(x[i*num_columns + j], memory[i]);
+                        }
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Atan col operation.
+///
+/// Computes atan(m_ij, v_j) or atan(v_j, m_ij). This will resize the buffer if
+/// it needs to be.
+///
+/// @params[in] x The other operand.
+//------------------------------------------------------------------------------
+        void atan_col(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix. x holds one entry per
+                // column so its length is the number of columns.
+                const size_t num_columns = x.size();
+                const size_t num_rows = size()/x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        // x is always the first atan2 argument and this
+                        // buffer the second. Complex types use atan of the
+                        // ratio instead of atan2.
+                        if constexpr (jit::is_complex<T> ()) {
+                            memory[i*num_columns + j] = std::atan(x[j]/memory[i*num_columns + j]);
+                        } else {
+                            memory[i*num_columns + j] = std::atan2(x[j], memory[i*num_columns + j]);
+                        }
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector with one entry per column. The
+                // result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = size();
+                const size_t num_rows = x.size()/size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        if constexpr (jit::is_complex<T> ()) {
+                            m[i*num_columns + j] = std::atan(x[i*num_columns + j]/memory[j]);
+                        } else {
+                            m[i*num_columns + j] = std::atan2(x[i*num_columns + j], memory[j]);
+                        }
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Pow row operation.
+///
+/// Computes pow(m_ij, v_i) or pow(v_i, m_ij). This will resize the buffer if
+/// it needs to be.
+///
+/// @params[in] x The right operand.
+//------------------------------------------------------------------------------
+        void pow_row(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix and x holds one entry
+                // per row.
+                const size_t num_columns = size()/x.size();
+                const size_t num_rows = x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] = std::pow(memory[i*num_columns + j], x[i]);
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector and stays the base. The result
+                // takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = x.size()/size();
+                const size_t num_rows = size();
+                // The outer index walks the rows so memory[i] stays inside
+                // the vector for non square matrices.
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = std::pow(memory[i], x[i*num_columns + j]);
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
+//------------------------------------------------------------------------------
+/// @brief Pow col operation.
+///
+/// Computes pow(m_ij, v_j) or pow(v_j, m_ij). This will resize the buffer if
+/// it needs to be.
+///
+/// @params[in] x The other operand.
+//------------------------------------------------------------------------------
+        void pow_col(const buffer &x) {
+            if (size() > x.size()) {
+                assert(size()%x.size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the row major matrix. x holds one entry per
+                // column so its length is the number of columns.
+                const size_t num_columns = x.size();
+                const size_t num_rows = size()/x.size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        memory[i*num_columns + j] = std::pow(memory[i*num_columns + j], x[j]);
+                    }
+                }
+            } else {
+                assert(x.size()%size() == 0 &&
+                       "Vector operand size is not a multiple of matrix operand size");
+
+                // This buffer is the vector with one entry per column and
+                // stays the base. The result takes the size of x.
+                std::vector<T> m(x.size());
+                const size_t num_columns = size();
+                const size_t num_rows = x.size()/size();
+                for (size_t i = 0; i < num_rows; i++) {
+                    for (size_t j = 0; j < num_columns; j++) {
+                        m[i*num_columns + j] = std::pow(memory[j], x[i*num_columns + j]);
+                    }
+                }
+                // Swap the storage instead of copying the temporary.
+                memory.swap(m);
+            }
+        }
+
 /// Type def to retrieve the backend T type. typedef T base; }; diff --git a/graph_framework/cpu_context.hpp b/graph_framework/cpu_context.hpp index ca5a92bd7453f3a6bb9af46dc648606edddba566..a41d4849c65871dc480359a11799deea37bc0be2 100644 --- a/graph_framework/cpu_context.hpp +++ b/graph_framework/cpu_context.hpp @@ -24,6 +24,10 @@ #include "clang/Lex/PreprocessorOptions.h" #include "llvm/Support/TargetSelect.h" #include "llvm/ExecutionEngine/Orc/LLJIT.h" +#ifndef NDEBUG +#include "llvm/ExecutionEngine/Orc/Debugging/DebuggerSupport.h" +#include "llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h" +#endif #include "llvm/Support/raw_ostream.h" #include "llvm/ADT/IntrusiveRefCntPtr.h" #include "llvm/ADT/SmallVector.h" @@ -32,6 +36,16 @@ #include "node.hpp" +#ifndef NDEBUG +//------------------------------------------------------------------------------ +/// @brief This just exposes the functions so the debugger links.
+//------------------------------------------------------------------------------ +LLVM_ATTRIBUTE_USED void linkComponents() { + llvm::errs() << (void *)&llvm_orc_registerJITLoaderGDBWrapper + << (void *)&llvm_orc_registerJITLoaderGDBAllocAction; +} +#endif + namespace gpu { //------------------------------------------------------------------------------ /// @brief Split a string by the space delimiter. @@ -75,6 +89,9 @@ namespace gpu { std::map *, size_t> arg_index; public: +/// Remaining constant memory in bytes. NOT USED. + int remaining_const_memory; + //------------------------------------------------------------------------------ /// @brief Get the maximum number of concurrent instances. /// @@ -132,6 +149,8 @@ namespace gpu { args.push_back(filename.c_str()); #ifdef NDEBUG args.push_back("-O3"); +#else + args.push_back("-debug-info-kind=standalone"); #endif if (jit::verbose) { for (auto &arg : args) { @@ -173,7 +192,13 @@ namespace gpu { auto ir_module = action.takeModule(); auto context = std::unique_ptr (action.takeLLVMContext()); - auto jit_try = llvm::orc::LLJITBuilder().create(); + auto jit_try = llvm::orc::LLJITBuilder() +#ifndef NDEBUG + .setPrePlatformSetup([](llvm::orc::LLJIT &J) { + return llvm::orc::enableDebuggerSupport(J); + }) +#endif + .create(); if (auto jiterror = jit_try.takeError()) { std::cerr << "Failed to build JIT : " << toString(std::move(jiterror)) << std::endl; exit(-1); @@ -187,16 +212,20 @@ namespace gpu { //------------------------------------------------------------------------------ /// @brief Create a kernel calling function. /// -/// @params[in] kernel_name Name of the kernel for later reference. -/// @params[in] inputs Input nodes of the kernel. -/// @params[in] outputs Output nodes of the kernel. -/// @params[in] num_rays Number of rays to trace. +/// @params[in] kernel_name Name of the kernel for later reference. +/// @params[in] inputs Input nodes of the kernel. +/// @params[in] outputs Output nodes of the kernel. 
+/// @params[in] num_rays Number of rays to trace. +/// @params[in] tex1d_list List of 1D textures. +/// @params[in] tex2d_list List of 2D textures. /// @returns A lambda function to run the kernel. //------------------------------------------------------------------------------ std::function create_kernel_call(const std::string kernel_name, graph::input_nodes inputs, graph::output_nodes outputs, - const size_t num_rays) { + const size_t num_rays, + const jit::texture1d_list &tex1d_list, + const jit::texture2d_list &tex2d_list) { auto entry = std::move(jit->lookup(kernel_name)).get(); auto kernel = entry.toPtr &)> (); @@ -353,14 +382,22 @@ namespace gpu { /// @params[in] inputs Input variables of the kernel. /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] size Size of the input buffer. +/// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. +/// @params[in] textures1d List of 1D kernel textures. +/// @params[in] textures2d List of 2D kernel textures. 
//------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes &inputs, graph::output_nodes &outputs, - const size_t size, - jit::register_map ®isters) { + const size_t size, + const std::vector &is_constant, + jit::register_map ®isters, + const jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { source_buffer << std::endl; source_buffer << "extern \"C\" void " << name << "(" << std::endl; @@ -368,11 +405,14 @@ namespace gpu { jit::add_type (source_buffer); source_buffer << " *> &args) {" << std::endl; - for (auto &input : inputs) { + for (size_t i = 0, ie = inputs.size(); i < ie; i++) { source_buffer << " "; + if (is_constant[i]) { + source_buffer << "const "; + } jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('v', input.get()) - << " = args[" << reinterpret_cast (input.get()) + source_buffer << " *" << jit::to_string('v', inputs[i].get()) + << " = args[" << reinterpret_cast (inputs[i].get()) << "];" << std::endl; } for (auto &output : outputs) { @@ -391,7 +431,8 @@ namespace gpu { jit::add_type (source_buffer); source_buffer << " " << registers[input.get()] << " = " << jit::to_string('v', input.get()) - << "[i]; //" << input->get_symbol() << std::endl; + << "[i]; // " << input->get_symbol() + << " used " << usage.at(input.get()) << std::endl; } } @@ -402,13 +443,17 @@ namespace gpu { /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] setters Map outputs back to input values. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. 
//------------------------------------------------------------------------------ void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes &outputs, graph::map_nodes &setters, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('v', in.get()); source_buffer << "[i] = "; if constexpr (SAFE_MATH) { @@ -431,7 +476,9 @@ namespace gpu { } } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('o', out.get()); source_buffer << "[i] = "; if constexpr (SAFE_MATH) { diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index ea0ce7fb842b1784196dcf1c7aef57e6436a981d..e0912b01f31adfdca185654318a52ffd2eff0c92 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -9,12 +9,16 @@ #define cuda_context_h #include +#include #include #include #include "node.hpp" +#define MAX_REG 128 +#define MAX_CONSTANT_MEMORY + namespace gpu { //------------------------------------------------------------------------------ /// @brief Check results of realtime compile. @@ -40,6 +44,9 @@ namespace gpu { #ifndef NDEBUG const char *error; cuGetErrorString(result, &error); + if (result != CUDA_SUCCESS) { + std::cerr << name << " " << std::string(error) << std::endl; + } assert(result == CUDA_SUCCESS && error); #endif } @@ -72,6 +79,10 @@ namespace gpu { CUmodule module; /// Argument map. std::map *, CUdeviceptr> kernel_arguments; +#ifdef USE_CUDA_TEXTURES +/// Textures. + std::map texture_arguments; +#endif /// Result buffer. CUdeviceptr result_buffer; /// Cuda stream. 
@@ -93,6 +104,9 @@ namespace gpu { } public: +/// Remaining constant memory in bytes. + int remaining_const_memory; + //------------------------------------------------------------------------------ /// @brief Get the maximum number of concurrent instances. /// @@ -120,7 +134,11 @@ namespace gpu { check_error(cuDeviceGet(&device, index), "cuDeviceGet"); check_error(cuDevicePrimaryCtxRetain(&context, device), "cuDevicePrimaryCtxRetain"); check_error(cuCtxSetCurrent(context), "cuCtxSetCurrent"); + check_error(cuCtxSetCacheConfig(CU_FUNC_CACHE_PREFER_L1), "cuCtxSetCacheConfig"); check_error(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "cuStreamCreate"); + check_error(cuDeviceGetAttribute(&remaining_const_memory, + CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY, + device), "cuDeviceGetAttribute"); } //------------------------------------------------------------------------------ @@ -136,6 +154,17 @@ namespace gpu { check_error(cuMemFree(value), "cuMemFree"); } +#ifdef USE_CUDA_TEXTURES + for (auto &[key, value] : texture_arguments) { + CUDA_RESOURCE_DESC resource; + check_error(cuTexObjectGetResourceDesc(&resource, value), + "cuTexObjectGetResourceDesc"); + + check_error(cuArrayDestroy(resource.res.array.hArray), "cuArrayDestroy"); + check_error(cuTexObjectDestroy(value), "cuTexObjectDestroy"); + } +#endif + if (result_buffer) { check_error(cuMemFree(result_buffer), "cuMemFree"); result_buffer = 0; @@ -203,13 +232,14 @@ namespace gpu { } const std::string temp = arch.str(); - std::array options({ + std::array options({ temp.c_str(), "--std=c++17", + "--relocatable-device-code=false", "--include-path=" CUDA_INCLUDE, "--include-path=" HEADER_DIR, "--extra-device-vectorization", - "--device-as-default-execution-space" + "--device-as-default-execution-space" }); if (nvrtcCompileProgram(kernel_program, options.size(), options.data())) { @@ -242,7 +272,20 @@ namespace gpu { check_nvrtc_error(nvrtcDestroyProgram(&kernel_program), "nvrtcDestroyProgram"); - 
check_error(cuModuleLoadDataEx(&module, ptx, 0, NULL, NULL), "cuModuleLoadDataEx"); + std::array module_options = { + CU_JIT_MAX_REGISTERS, + CU_JIT_LTO, + CU_JIT_POSITION_INDEPENDENT_CODE + }; + std::array module_values = { + reinterpret_cast (MAX_REG), + reinterpret_cast (1), + reinterpret_cast (0) + }; + + check_error(cuModuleLoadDataEx(&module, ptx, 1, + module_options.data(), + module_values.data()), "cuModuleLoadDataEx"); free(ptx); } @@ -253,13 +296,17 @@ namespace gpu { /// @params[in] kernel_name Name of the kernel for later reference. /// @params[in] inputs Input nodes of the kernel. /// @params[in] outputs Output nodes of the kernel. -/// @params[in] num_rays Number of rays to trace. +/// @params[in] num_rays Number of rays to trace. +/// @params[in] tex1d_list List of 1D textures. +/// @params[in] tex2d_list List of 2D textures. /// @returns A lambda function to run the kernel. //------------------------------------------------------------------------------ std::function create_kernel_call(const std::string kernel_name, graph::input_nodes inputs, graph::output_nodes outputs, - const size_t num_rays) { + const size_t num_rays, + const jit::texture1d_list &tex1d_list, + const jit::texture2d_list &tex2d_list) { CUfunction function; check_error(cuModuleGetFunction(&function, module, kernel_name.c_str()), "cuModuleGetFunction"); @@ -292,16 +339,129 @@ namespace gpu { buffers.push_back(reinterpret_cast (&kernel_arguments[output.get()])); } +#ifdef USE_CUDA_TEXTURES + for (auto &[data, size] : tex1d_list) { + if (!texture_arguments.contains(data)) { + texture_arguments.try_emplace(data); + CUDA_RESOURCE_DESC resource_desc; + CUDA_TEXTURE_DESC texture_desc; + CUDA_ARRAY_DESCRIPTOR array_desc; + + array_desc.Width = size; + array_desc.Height = 1; + + memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); + memset(&texture_desc, 0, sizeof(CUDA_TEXTURE_DESC)); + + resource_desc.resType = CU_RESOURCE_TYPE_ARRAY; + texture_desc.addressMode[0] = 
CU_TR_ADDRESS_MODE_BORDER; + texture_desc.addressMode[1] = CU_TR_ADDRESS_MODE_BORDER; + texture_desc.addressMode[2] = CU_TR_ADDRESS_MODE_BORDER; + if constexpr (jit::is_float ()) { + array_desc.Format = CU_AD_FORMAT_FLOAT; + if constexpr (jit::is_complex ()) { + array_desc.NumChannels = 2; + } else { + array_desc.NumChannels = 1; + } + } else { + array_desc.Format = CU_AD_FORMAT_UNSIGNED_INT32; + if constexpr (jit::is_complex ()) { + array_desc.NumChannels = 4; + } else { + array_desc.NumChannels = 2; + } + } + check_error(cuArrayCreate(&resource_desc.res.array.hArray, &array_desc), + "cuArrayCreate"); + check_error(cuMemcpyHtoA(resource_desc.res.array.hArray, 0, data, + size*sizeof(float)*array_desc.NumChannels), + "cuMemcpyHtoA"); + + check_error(cuTexObjectCreate(&texture_arguments[data], + &resource_desc, &texture_desc, + NULL), + "cuTexObjectCreate"); + } + buffers.push_back(reinterpret_cast (&texture_arguments[data])); + } + for (auto &[data, size] : tex2d_list) { + if (!texture_arguments.contains(data)) { + texture_arguments.try_emplace(data); + CUDA_RESOURCE_DESC resource_desc; + CUDA_TEXTURE_DESC texture_desc; + CUDA_ARRAY_DESCRIPTOR array_desc; + + array_desc.Width = size[0]; + array_desc.Height = size[1]; + + memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); + memset(&texture_desc, 0, sizeof(CUDA_TEXTURE_DESC)); + + resource_desc.resType = CU_RESOURCE_TYPE_ARRAY; + texture_desc.addressMode[0] = CU_TR_ADDRESS_MODE_BORDER; + texture_desc.addressMode[1] = CU_TR_ADDRESS_MODE_BORDER; + texture_desc.addressMode[2] = CU_TR_ADDRESS_MODE_BORDER; + const size_t total = size[0]*size[1]; + if constexpr (jit::is_float ()) { + array_desc.Format = CU_AD_FORMAT_FLOAT; + if constexpr (jit::is_complex ()) { + array_desc.NumChannels = 2; + } else { + array_desc.NumChannels = 1; + } + } else { + array_desc.Format = CU_AD_FORMAT_UNSIGNED_INT32; + if constexpr (jit::is_complex ()) { + array_desc.NumChannels = 4; + } else { + array_desc.NumChannels = 2; + } + } + 
check_error(cuArrayCreate(&resource_desc.res.array.hArray, &array_desc), + "cuArrayCreate"); + + CUDA_MEMCPY2D copy_desc; + memset(&copy_desc, 0, sizeof(copy_desc)); + + copy_desc.srcPitch = size[0]*sizeof(float)*array_desc.NumChannels; + copy_desc.srcMemoryType = CU_MEMORYTYPE_HOST; + copy_desc.srcHost = data; + + copy_desc.dstMemoryType = CU_MEMORYTYPE_ARRAY; + copy_desc.dstArray = resource_desc.res.array.hArray; + + copy_desc.WidthInBytes = copy_desc.srcPitch; + /* Height counts rows so it is the second texture dimension. */ + copy_desc.Height = size[1]; + + check_error(cuMemcpy2D(&copy_desc), "cuMemcpy2D"); + + check_error(cuTexObjectCreate(&texture_arguments[data], + &resource_desc, &texture_desc, + NULL), + "cuTexObjectCreate"); + } + buffers.push_back(reinterpret_cast (&texture_arguments[data])); + } +#endif + int value; check_error(cuFuncGetAttribute(&value, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, function), "cuFuncGetAttribute"); unsigned int threads_per_group = value; unsigned int thread_groups = num_rays/threads_per_group + (num_rays%threads_per_group ? 
1 : 0); + + int min_grid; + check_error(cuOccupancyMaxPotentialBlockSize(&min_grid, &value, function, 0, 0, 0), + "cuOccupancyMaxPotentialBlockSize"); + if (jit::verbose) { std::cout << " Kernel name : " << kernel_name << std::endl; - std::cout << " Threads per group : " << threads_per_group << std::endl; - std::cout << " Number of groups : " << thread_groups << std::endl; - std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; + std::cout << " Threads per group : " << threads_per_group << std::endl; + std::cout << " Number of groups : " << thread_groups << std::endl; + std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; + std::cout << " Min grid size : " << min_grid << std::endl; + std::cout << " Suggested Block size : " << value << std::endl; } return [this, function, thread_groups, threads_per_group, buffers] () mutable { @@ -334,8 +494,15 @@ namespace gpu { check_error(cuModuleGetFunction(&function, module, "max_reduction"), "cuModuleGetFunction"); + int value; + int min_grid; + check_error(cuOccupancyMaxPotentialBlockSize(&min_grid, &value, function, 0, 0, 0), + "cuOccupancyMaxPotentialBlockSize"); + if (jit::verbose) { std::cout << " Kernel name : max_reduction" << std::endl; + std::cout << " Min grid size : " << min_grid << std::endl; + std::cout << " Suggested Block size : " << value << std::endl; } return [this, function, run, buffers] () mutable { @@ -425,10 +592,36 @@ namespace gpu { void create_header(std::ostringstream &source_buffer) { if constexpr (jit::is_complex ()) { source_buffer << "#define CUDA_DEVICE_CODE" << std::endl; - source_buffer << "#define M_PI " << M_PI << std::endl; + source_buffer << "#define M_PI " << M_PI << std::endl; source_buffer << "#include " << std::endl; source_buffer << "#include " << std::endl; +#ifdef USE_CUDA_TEXTURES + if constexpr (jit::is_float ()) { + source_buffer << "static __inline__ __device__ complex to_cmp_float(float2 p) {" + << std::endl + 
<< " return "; + jit::add_type (source_buffer); + source_buffer << " (p.x, p.y);" << std::endl + << "}" << std::endl; + } else { + source_buffer << "static __inline__ __device__ complex to_cmp_double(uint4 p) {" + << std::endl + << " return "; + jit::add_type (source_buffer); + source_buffer << " (__hiloint2double(p.y, p.x), __hiloint2double(p.w, p.z));" + << std::endl + << "}" << std::endl; + } + } else if constexpr (jit::is_double ()) { + source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" + << std::endl + << " return __hiloint2double(p.y, p.x);" + << std::endl + << "}" << std::endl; + } +#else } +#endif } //------------------------------------------------------------------------------ @@ -439,34 +632,76 @@ namespace gpu { /// @params[in] inputs Input variables of the kernel. /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] size Size of the input buffer. +/// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. +/// @params[in] textures1d List of 1D kernel textures. +/// @params[in] textures2d List of 2D kernel textures. 
//------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes &inputs, graph::output_nodes &outputs, const size_t size, - jit::register_map ®isters) { + const std::vector &is_constant, + jit::register_map ®isters, + const jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { source_buffer << std::endl; - source_buffer << "extern \"C\" __global__ void " << name << "(" - << std::endl; + source_buffer << "extern \"C\" __global__ void " + << name << "(" << std::endl; source_buffer << " "; + if (is_constant[0]) { + source_buffer << "const "; + } jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('v', inputs[0].get()); + source_buffer << " * __restrict__ " + << jit::to_string('v', inputs[0].get()); for (size_t i = 1, ie = inputs.size(); i < ie; i++) { - source_buffer << "," << std::endl; + source_buffer << ", // " << inputs[i - 1]->get_symbol() +#ifndef USE_INPUT_CACHE + << " used " << usage.at(inputs[i - 1].get()) +#endif + << std::endl; source_buffer << " "; + if (is_constant[i]) { + source_buffer << "const "; + } jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('v', inputs[i].get()); + source_buffer << " * __restrict__ " + << jit::to_string('v', inputs[i].get()); } - for (size_t i = 0, ie = outputs.size(); i < ie; i++) { - source_buffer << "," << std::endl; + source_buffer << ","; + if (i == 0) { + source_buffer << " // " + << inputs[inputs.size() - 1]->get_symbol(); +#ifndef USE_INPUT_CACHE + source_buffer << " used " + << usage.at(inputs[inputs.size() - 1].get()); +#endif + } + + source_buffer << std::endl; source_buffer << " "; jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('o', outputs[i].get()); + source_buffer << " * __restrict__ " + << jit::to_string('o', outputs[i].get()); + } +#ifdef USE_CUDA_TEXTURES + for (auto &[key, 
value] : textures1d) { + source_buffer << "," << std::endl; + source_buffer << " cudaTextureObject_t " + << jit::to_string('a', key); } + for (auto &[key, value] : textures2d) { + source_buffer << "," << std::endl; + source_buffer << " cudaTextureObject_t " + << jit::to_string('a', key); + } +#endif source_buffer << ") {" << std::endl; source_buffer << " const int index = blockIdx.x*blockDim.x + threadIdx.x;" @@ -474,12 +709,19 @@ namespace gpu { source_buffer << " if (index < " << size << ") {" << std::endl; for (auto &input : inputs) { - registers[input.get()] = jit::to_string('r', input.get()); - source_buffer << " const "; - jit::add_type (source_buffer); - source_buffer << " " << registers[input.get()] << " = " - << jit::to_string('v', input.get()) << "[index];" - << std::endl; +#ifdef USE_INPUT_CACHE + if (usage.at(input.get())) { + registers[input.get()] = jit::to_string('r', input.get()); + source_buffer << " const "; + jit::add_type (source_buffer); + source_buffer << " " << registers[input.get()] << " = " + << jit::to_string('v', input.get()) + << "[index]; // " << input->get_symbol() + << " used " << usage.at(input.get()) << std::endl; + } +#else + registers[input.get()] = jit::to_string('v', input.get()) + "[index]"; +#endif } } @@ -490,14 +732,17 @@ namespace gpu { /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] setters Map outputs back to input values. /// @params[in,out] registers Map of used registers. - +/// @params[in] usage List of register usage count. 
//------------------------------------------------------------------------------ void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes &outputs, graph::map_nodes &setters, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('v', in.get()) << "[index] = "; if constexpr (SAFE_MATH) { @@ -521,7 +766,9 @@ namespace gpu { } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('o', out.get()) << "[index] = "; if constexpr (SAFE_MATH) { @@ -557,12 +804,12 @@ namespace gpu { const size_t size) { source_buffer << std::endl; source_buffer << "extern \"C\" __global__ void max_reduction(" << std::endl; - source_buffer << " "; + source_buffer << " const "; jit::add_type (source_buffer); - source_buffer << " *input," << std::endl; + source_buffer << " * __restrict__ input," << std::endl; source_buffer << " "; jit::add_type (source_buffer); - source_buffer << " *result) {" << std::endl; + source_buffer << " * __restrict__ result) {" << std::endl; source_buffer << " const unsigned int i = threadIdx.x;" << std::endl; source_buffer << " const unsigned int j = threadIdx.x/32;" << std::endl; source_buffer << " const unsigned int k = threadIdx.x%32;" << std::endl; @@ -596,47 +843,6 @@ namespace gpu { source_buffer << "}" << std::endl << std::endl; } -//------------------------------------------------------------------------------ -/// @brief Create a preamble. -/// -/// @params[in,out] source_buffer Source buffer stream. 
-//------------------------------------------------------------------------------ - void create_preamble(std::ostringstream &source_buffer) { - source_buffer << "extern \"C\" __global__ "; - } - -//------------------------------------------------------------------------------ -/// @brief Create arg prefix. -/// -/// @params[in,out] source_buffer Source buffer stream. -//------------------------------------------------------------------------------ - void create_argument_prefix(std::ostringstream &source_buffer) {} - -//------------------------------------------------------------------------------ -/// @brief Create arg postfix. -/// -/// @params[in,out] source_buffer Source buffer stream. -/// @params[in] index Argument index. -//------------------------------------------------------------------------------ - void create_argument_postfix(std::ostringstream &source_buffer, - const size_t index) {} - -//------------------------------------------------------------------------------ -/// @brief Create index argument. -/// -/// @params[in,out] source_buffer Source buffer stream. -//------------------------------------------------------------------------------ - void create_index_argument(std::ostringstream &source_buffer) {} - -//------------------------------------------------------------------------------ -/// @brief Create index. -/// -/// @params[in,out] source_buffer Source buffer stream. -//------------------------------------------------------------------------------ - void create_index(std::ostringstream &source_buffer) { - source_buffer << "blockIdx.x*blockDim.x + threadIdx.x;"; - } - //------------------------------------------------------------------------------ /// @brief Get the buffer for a node. 
/// diff --git a/graph_framework/jit.hpp b/graph_framework/jit.hpp index cfb3e29ae3cea8fd5ef37a76796dbfc485be376f..ba092c569a8150998e37356d75e5edd734eec006 100644 --- a/graph_framework/jit.hpp +++ b/graph_framework/jit.hpp @@ -8,6 +8,10 @@ #ifndef jit_h #define jit_h +#include +#include +#include + #ifdef USE_METAL #include "metal_context.hpp" #elif defined(USE_CUDA) @@ -39,6 +43,10 @@ namespace jit { register_map registers; /// Kernel names. std::vector kernel_names; +/// Kernel textures. + std::map kernel_1dtextures; +/// Kernel textures. + std::map kernel_2dtextures; /// Type for the GPU context. using gpu_context_type = typename std::conditional (), @@ -94,39 +102,66 @@ namespace jit { graph::output_nodes outputs, graph::map_nodes setters) { kernel_names.push_back(name); - + const size_t size = inputs[0]->size(); + std::vector is_constant(inputs.size(), true); visiter_map visited; + register_usage usage; + kernel_1dtextures[name] = texture1d_list(); + kernel_2dtextures[name] = texture2d_list(); for (auto &[out, in] : setters) { - out->compile_preamble(source_buffer, registers, visited); + auto found = std::distance(inputs.begin(), + std::find(inputs.begin(), + inputs.end(), in)); + if (found < is_constant.size()) { + is_constant[found] = false; + } + out->compile_preamble(source_buffer, registers, + visited, usage, + kernel_1dtextures[name], + kernel_2dtextures[name], + gpu_context.remaining_const_memory); } for (auto &out : outputs) { - out->compile_preamble(source_buffer, registers, visited); + out->compile_preamble(source_buffer, registers, + visited, usage, + kernel_1dtextures[name], + kernel_2dtextures[name], + gpu_context.remaining_const_memory); + } + + for (auto &in : inputs) { + if (usage.find(in.get()) == usage.end()) { + usage[in.get()] == 0; + } } gpu_context.create_kernel_prefix(source_buffer, - name, inputs, outputs, size, - registers); + name, inputs, outputs, + size, is_constant, + registers, usage, + kernel_1dtextures[name], + 
kernel_2dtextures[name]); for (auto &[out, in] : setters) { - out->compile(source_buffer, registers); + out->compile(source_buffer, registers, usage); } for (auto &out : outputs) { - out->compile(source_buffer, registers); + out->compile(source_buffer, registers, usage); } gpu_context.create_kernel_postfix(source_buffer, outputs, - setters, registers); + setters, registers, usage); -// Delete the registers so that can be used again in other kernels. +// Delete the registers so that they can be used again in other kernels. std::vector removed_elements; for (auto &[key, value] : registers) { if (value[0] == 'r') { removed_elements.push_back(key); } } - + for (auto &key : removed_elements) { registers.erase(key); } @@ -148,6 +183,26 @@ namespace jit { std::cout << std::endl << source_buffer.str() << std::endl; } +//------------------------------------------------------------------------------ +/// @brief Save the kernel source code. +//------------------------------------------------------------------------------ + void save_source() { + std::string source = source_buffer.str(); + std::ostringstream filename; + filename << std::hash {} (source) + << std::hash{}(std::this_thread::get_id()); + if constexpr (use_cuda()) { + filename << ".cu"; + } else if constexpr (use_metal ()) { + filename << ".metal"; + } else { + filename << ".cpp"; + } + + std::ofstream outFile(filename.str()); + outFile << source; + } + //------------------------------------------------------------------------------ /// @brief Compile the kernel. /// @@ -155,6 +210,9 @@ namespace jit { /// kernel. 
//------------------------------------------------------------------------------ void compile(const bool add_reduction=false) { +#ifdef SAVE_KERNEL_SOURCE + save_source(); +#endif gpu_context.compile(source_buffer.str(), kernel_names, add_reduction); @@ -173,8 +231,9 @@ namespace jit { graph::input_nodes inputs, graph::output_nodes outputs, const size_t num_rays) { - return gpu_context.create_kernel_call(kernel_name, inputs, outputs, - num_rays); + return gpu_context.create_kernel_call(kernel_name, inputs, outputs, num_rays, + kernel_1dtextures[kernel_name], + kernel_2dtextures[kernel_name]); } //------------------------------------------------------------------------------ diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index 60456ec0a87bf455573aa42e9803fdcd0b79236b..dca53cbaaee04a023295c316264f28e2d4e2b94d 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -3,7 +3,6 @@ /// @brief Defined basic math functions. //------------------------------------------------------------------------------ - #ifndef math_h #define math_h @@ -99,14 +98,10 @@ namespace graph { // sqrt((x^a)*y). auto am = multiply_cast(this->arg); if (am.get()) { - if (pow_cast(am->get_left()).get() || - constant_cast(am->get_left()).get() || - piecewise_1D_cast(am->get_left()).get() || - piecewise_2D_cast(am->get_left()).get() || - pow_cast(am->get_right()).get() || - constant_cast(am->get_right()).get() || - piecewise_1D_cast(am->get_right()).get() || - piecewise_2D_cast(am->get_right()).get()) { + if (pow_cast(am->get_left()).get() || + am->get_left()->is_constant() || + pow_cast(am->get_right()).get() || + am->get_right()->is_constant()) { return sqrt(am->get_left()) * sqrt(am->get_right()); } @@ -116,14 +111,10 @@ namespace graph { // where c is a constant. 
auto ad = divide_cast(this->arg); if (ad.get()) { - if (pow_cast(ad->get_left()).get() || - constant_cast(ad->get_left()).get() || - piecewise_1D_cast(ad->get_left()).get() || - piecewise_2D_cast(ad->get_left()).get() || - pow_cast(ad->get_right()).get() || - constant_cast(ad->get_right()).get() || - piecewise_1D_cast(ad->get_right()).get() || - piecewise_2D_cast(ad->get_right()).get()) { + if (pow_cast(ad->get_left()).get() || + ad->get_left()->is_constant() || + pow_cast(ad->get_right()).get() || + ad->get_right()->is_constant()) { return sqrt(ad->get_left()) / sqrt(ad->get_right()); } @@ -158,21 +149,24 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { shared_leaf a = this->arg->compile(stream, - registers); + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = sqrt(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -416,13 +410,17 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -454,7 +452,7 @@ namespace graph { stream << ")"; } } - stream << ";" << std::endl; + stream << "; // used " << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -671,20 +669,24 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = log(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -854,39 +856,77 @@ namespace graph { /// @returns A reduced power node. 
//------------------------------------------------------------------------------ virtual shared_leaf reduce() { + auto lc = constant_cast(this->left); auto rc = constant_cast(this->right); - if (rc.get()) { - if (rc->is(0)) { - return one (); - } else if (rc->is(1)) { - return this->left; - } else if (rc->is(0.5)) { - return sqrt(this->left); - } else if (rc->is(2)){ - auto sq = sqrt_cast(this->left); - if (sq.get()) { - return sq->get_arg(); - } - } - - if (constant_cast(this->left).get()) { - return constant (this->evaluate()); + if (rc.get() && rc->is(0)) { + return one (); + } else if (rc.get() && rc->is(1)) { + return this->left; + } else if (rc.get() && rc->is(0.5)) { + return sqrt(this->left); + } else if (rc.get() && rc->is(2)){ + auto sq = sqrt_cast(this->left); + if (sq.get()) { + return sq->get_arg(); } + } - auto pl1 = piecewise_1D_cast(this->left); - if (pl1.get()) { - return piecewise_1D(this->evaluate(), - pl1->get_arg()); - } + if (lc.get() && rc.get()) { + return constant (this->evaluate()); + } - auto pl2 = piecewise_2D_cast(this->left); - if (pl2.get()) { - return piecewise_2D(this->evaluate(), - pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); - } + auto pl1 = piecewise_1D_cast(this->left); + auto pr1 = piecewise_1D_cast(this->right); + if (pl1.get() && (rc.get() || pl1->is_arg_match(this->right))) { + return piecewise_1D(this->evaluate(), pl1->get_arg()); + } else if (pr1.get() && (lc.get() || pr1->is_arg_match(this->left))) { + return piecewise_1D(this->evaluate(), pr1->get_arg()); + } + + auto pl2 = piecewise_2D_cast(this->left); + auto pr2 = piecewise_2D_cast(this->right); + if (pl2.get() && (rc.get() || pl2->is_arg_match(this->right))) { + return piecewise_2D(this->evaluate(), + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pr2.get() && (lc.get() || pr2->is_arg_match(this->left))) { + return piecewise_2D(this->evaluate(), + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } + 
+// Combine 2D and 1D piecewise constants if a row or column matches. + if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.pow_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.pow_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.pow_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.pow_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); } auto lp = pow_cast(this->left); @@ -898,15 +938,11 @@ namespace graph { // Handle cases where (c*x)^a, (x*c)^a, (a*sqrt(b))^c and (a*b^c)^2. auto lm = multiply_cast(this->left); if (lm.get()) { - if (constant_cast(lm->get_left()).get() || - constant_cast(lm->get_right()).get() || - piecewise_1D_cast(lm->get_left()).get() || - piecewise_1D_cast(lm->get_right()).get() || - piecewise_2D_cast(lm->get_left()).get() || - piecewise_2D_cast(lm->get_right()).get() || - sqrt_cast(lm->get_left()).get() || - sqrt_cast(lm->get_right()).get() || - pow_cast(lm->get_left()).get() || + if (lm->get_left()->is_constant() || + lm->get_right()->is_constant() || + sqrt_cast(lm->get_left()).get() || + sqrt_cast(lm->get_right()).get() || + pow_cast(lm->get_left()).get() || pow_cast(lm->get_right()).get()) { return pow(lm->get_left(), this->right) * pow(lm->get_right(), this->right); @@ -916,15 +952,11 @@ namespace graph { // Handle cases where (c/x)^a, (x/c)^a, (a/sqrt(b))^c and (a/b^c)^2. 
auto ld = divide_cast(this->left); if (ld.get()) { - if (constant_cast(ld->get_left()).get() || - constant_cast(ld->get_right()).get() || - piecewise_1D_cast(ld->get_left()).get() || - piecewise_1D_cast(ld->get_right()).get() || - piecewise_2D_cast(ld->get_left()).get() || - piecewise_2D_cast(ld->get_right()).get() || - sqrt_cast(ld->get_left()).get() || - sqrt_cast(ld->get_right()).get() || - pow_cast(ld->get_left()).get() || + if (ld->get_left()->is_constant() || + ld->get_right()->is_constant() || + sqrt_cast(ld->get_left()).get() || + sqrt_cast(ld->get_right()).get() || + pow_cast(ld->get_left()).get() || pow_cast(ld->get_right()).get()) { return pow(ld->get_left(), this->right) / pow(ld->get_right(), this->right); @@ -975,17 +1007,21 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); shared_leaf r; auto temp = constant_cast(this->right); if (!temp.get() || !temp->is_integer()) { - r = this->right->compile(stream, registers); + r = this->right->compile(stream, registers, usage); } registers[this] = jit::to_string('r', this); @@ -1004,7 +1040,7 @@ namespace graph { << registers[l.get()] << ", " << registers[r.get()] << ");"; } - stream << std::endl; + stream << " // used " << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -1266,20 +1302,24 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. 
+/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = special::erfi(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 73d649e61f8f0030fef8a18c4b77c80e3ef2fd9c..f46f2be9ce25f1bceaeb5fe5af02b89e8f489550 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -27,14 +27,21 @@ namespace gpu { id queue; /// Argument map. std::map *, id> kernel_arguments; +/// Textures. + std::map> texture_arguments; /// Max Buffer. id result; /// Metal command buffer. id command_buffer; /// Metal library. id library; +/// Buffer mutability discriptor. + std::map> bufferMutability; public: +/// Remaining constant memory in bytes. NOT USED. + int remaining_const_memory; + //------------------------------------------------------------------------------ /// @brief Get the maximum number of concurrent instances. /// @@ -75,7 +82,7 @@ namespace gpu { encoding:NSUTF8StringEncoding] options:compile_options() error:&error]; - + if (error) { NSLog(@"%@", error); } @@ -88,16 +95,20 @@ namespace gpu { //------------------------------------------------------------------------------ /// @brief Create a kernel calling function. /// -/// @params[in] kernel_name Name of the kernel for later reference. 
-/// @params[in] inputs Input nodes of the kernel. -/// @params[in] outputs Output nodes of the kernel. -/// @params[in] num_rays Number of rays to trace. +/// @params[in] kernel_name Name of the kernel for later reference. +/// @params[in] inputs Input nodes of the kernel. +/// @params[in] outputs Output nodes of the kernel. +/// @params[in] num_rays Number of rays to trace. +/// @params[in] tex1d_list List of 1D textures. +/// @params[in] tex2d_list List of 1D textures. /// @returns A lambda function to run the kernel. //------------------------------------------------------------------------------ std::function create_kernel_call(const std::string kernel_name, graph::input_nodes inputs, graph::output_nodes outputs, - const size_t num_rays) { + const size_t num_rays, + const jit::texture1d_list &tex1d_list, + const jit::texture2d_list &tex2d_list) { NSError *error; id function = [library newFunctionWithName:[NSString stringWithCString:kernel_name.c_str() @@ -106,6 +117,10 @@ namespace gpu { MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new]; compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES; compute.computeFunction = function; + compute.maxTotalThreadsPerThreadgroup = 1024; + for (size_t i = 0, ie = bufferMutability[kernel_name].size(); i < ie; i++) { + compute.buffers[i].mutability = bufferMutability[kernel_name][i]; + } id state = [device newComputePipelineStateWithDescriptor:compute options:MTLPipelineOptionNone @@ -136,20 +151,70 @@ namespace gpu { buffers.push_back(kernel_arguments[output.get()]); } + std::vector> textures; + command_buffer = [queue commandBuffer]; + id encoder = [command_buffer blitCommandEncoder]; + for (auto &[data, size] : tex1d_list) { + if (!texture_arguments.contains(data)) { + MTLTextureDescriptor *discriptor = [MTLTextureDescriptor new]; + discriptor.textureType = MTLTextureType1D; + discriptor.pixelFormat = MTLPixelFormatR32Float; + discriptor.width = size; + discriptor.storageMode = 
MTLStorageModeManaged; + discriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined; + discriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked; + discriptor.usage = MTLTextureUsageShaderRead; + texture_arguments[data] = [device newTextureWithDescriptor:discriptor]; + [texture_arguments[data] replaceRegion:MTLRegionMake1D(0, size) + mipmapLevel:0 + withBytes:reinterpret_cast (data) + bytesPerRow:4*size]; + + [encoder optimizeContentsForGPUAccess:texture_arguments[data]]; + } + textures.push_back(texture_arguments[data]); + } + for (auto &[data, size] : tex2d_list) { + if (!texture_arguments.contains(data)) { + MTLTextureDescriptor *discriptor = [MTLTextureDescriptor new]; + discriptor.textureType = MTLTextureType2D; + discriptor.pixelFormat = MTLPixelFormatR32Float; + discriptor.width = size[1]; + discriptor.height = size[0]; + discriptor.storageMode = MTLStorageModeManaged; + discriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined; + discriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked; + discriptor.usage = MTLTextureUsageShaderRead; + texture_arguments[data] = [device newTextureWithDescriptor:discriptor]; + [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, size[1], size[0]) + mipmapLevel:0 + withBytes:reinterpret_cast (data) + bytesPerRow:4*size[1]]; + + [encoder optimizeContentsForGPUAccess:texture_arguments[data]]; + } + textures.push_back(texture_arguments[data]); + } + [encoder endEncoding]; + [command_buffer commit]; + std::vector offsets(buffers.size(), 0); NSRange range = NSMakeRange(0, buffers.size()); + NSRange tex_range = NSMakeRange(0, textures.size()); NSUInteger threads_per_group = state.maxTotalThreadsPerThreadgroup; + NSUInteger thread_width = state.threadExecutionWidth; NSUInteger thread_groups = num_rays/threads_per_group + (num_rays%threads_per_group ? 
1 : 0); if (jit::verbose) { std::cout << " Kernel name : " << kernel_name << std::endl; - std::cout << " Threads per group : " << threads_per_group << std::endl; - std::cout << " Number of groups : " << thread_groups << std::endl; - std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; + std::cout << " Thread execution width : " << thread_width << std::endl; + std::cout << " Threads per group : " << threads_per_group << std::endl; + std::cout << " Number of groups : " << thread_groups << std::endl; + std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; } - return [this, state, buffers, offsets, range, thread_groups, threads_per_group] () mutable { + return [this, state, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] () mutable { command_buffer = [queue commandBuffer]; id encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial]; @@ -157,6 +222,8 @@ namespace gpu { [encoder setBuffers:buffers.data() offsets:offsets.data() withRange:range]; + [encoder setTextures:textures.data() + withRange:tex_range]; [encoder dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)]; @@ -178,13 +245,14 @@ namespace gpu { MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new]; compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES; compute.computeFunction = [library newFunctionWithName:@"max_reduction"]; + compute.maxTotalThreadsPerThreadgroup = 1024; + compute.buffers[0].mutability = MTLMutabilityImmutable; NSError *error; id max_state = [device newComputePipelineStateWithDescriptor:compute options:MTLPipelineOptionNone reflection:NULL error:&error]; - if (error) { NSLog(@"%@", error); } @@ -194,6 +262,16 @@ namespace gpu { id buffer = kernel_arguments[argument.get()]; + NSUInteger threads_per_group = max_state.maxTotalThreadsPerThreadgroup; + NSUInteger thread_width = 
max_state.threadExecutionWidth; + if (jit::verbose) { + std::cout << " Kernel name : max_reduction" << std::endl; + std::cout << " Thread execution width : " << thread_width << std::endl; + std::cout << " Threads per group : " << threads_per_group << std::endl; + std::cout << " Number of groups : " << 1 << std::endl; + std::cout << " Total problem size : " << threads_per_group*1 << std::endl; + } + return [this, run, buffer, result, max_state] () mutable { run(); command_buffer = [queue commandBuffer]; @@ -324,40 +402,77 @@ namespace gpu { /// @params[in] inputs Input variables of the kernel. /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] size Size of the input buffer. +/// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. +/// @params[in] textures1d List of 1D kernel textures. +/// @params[in] textures2d List of 2D kernel textures. //------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes &inputs, graph::output_nodes &outputs, - const size_t size, - jit::register_map ®isters) { + const size_t size, + const std::vector &is_constant, + jit::register_map ®isters, + const jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { source_buffer << std::endl; source_buffer << "kernel void " << name << "(" << std::endl; - + + bufferMutability[name] = std::vector (); + for (size_t i = 0, ie = inputs.size(); i < ie; i++) { - source_buffer << " device float *" + bufferMutability[name].push_back(is_constant[i] ? MTLMutabilityMutable : MTLMutabilityImmutable); + source_buffer << " " << (is_constant[i] ? 
"constant" : "device") + << " float *" << jit::to_string('v', inputs[i].get()) - << " [[buffer(" << i << ")]]," << std::endl; + << " [[buffer(" << i << ")]], // " + << inputs[i]->get_symbol() +#ifndef USE_INPUT_CACHE + << " used " << usage.at(inputs[i].get()) +#endif + << std::endl; } - for (size_t i = 0, ie = outputs.size(); i < ie; i++) { + bufferMutability[name].push_back(MTLMutabilityMutable); source_buffer << " device float *" << jit::to_string('o', outputs[i].get()) << " [[buffer(" << i + inputs.size() << ")]]," << std::endl; } - + size_t index = 0; + for (auto &[key, value] : textures1d) { + source_buffer << " const texture1d " + << jit::to_string('a', key) + << " [[texture(" << index++ << ")]]," + << std::endl; + } + for (auto &[key, value] : textures2d) { + source_buffer << " const texture2d " + << jit::to_string('a', key) + << " [[texture(" << index++ << ")]]," + << std::endl; + } + source_buffer << " uint index [[thread_position_in_grid]]) {" << std::endl; source_buffer << " if (index < " << size << ") {" << std::endl; for (auto &input : inputs) { - registers[input.get()] = jit::to_string('r', input.get()); - source_buffer << " const "; - jit::add_type (source_buffer); - source_buffer << " " << registers[input.get()] << " = " - << jit::to_string('v', input.get()) << "[index];" - << std::endl; +#ifdef USE_INPUT_CACHE + if (usage.at(input.get())) { + registers[input.get()] = jit::to_string('r', input.get()); + source_buffer << " const "; + jit::add_type (source_buffer); + source_buffer << " " << registers[input.get()] << " = " + << jit::to_string('v', input.get()) + << "[index]; // " << input->get_symbol() + << " used " << usage.at(input.get()) << std::endl; + } +#else + registers[input.get()] = jit::to_string('v', input.get()) + "[index]"; +#endif } } @@ -368,13 +483,17 @@ namespace gpu { /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] setters Map outputs back to input values. 
/// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. //------------------------------------------------------------------------------ void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes &outputs, graph::map_nodes &setters, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('v', in.get()) << "[index] = "; if constexpr (SAFE_MATH) { @@ -385,7 +504,9 @@ namespace gpu { } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('o', out.get()) << "[index] = "; if constexpr (SAFE_MATH) { diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 360986591889d99fa56d1671e35ba34c9b5786c4..ef5c140da0c806ba1440a3950b63ea730df046fc 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -88,24 +88,40 @@ namespace graph { /// Some nodes require additions to the preamble however most don't so define a /// generic method that does nothing. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) {} + jit::visiter_map &visited, + jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d, + int &avail_const_mem) { + if (usage.find(this) == usage.end()) { + usage[this] = 1; + } else { + ++usage[this]; + } + } //------------------------------------------------------------------------------ /// @brief Compile the node. /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual std::shared_ptr> compile(std::ostringstream &stream, - jit::register_map ®isters) = 0; + jit::register_map ®isters, + const jit::register_usage &usage) = 0; //------------------------------------------------------------------------------ /// @brief Querey if the nodes match. @@ -171,11 +187,31 @@ namespace graph { jit::register_map ®isters) = 0; //------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. +/// @brief Test if node is a constant. +/// +/// @returns True if the node is like a constant. +//------------------------------------------------------------------------------ + virtual bool is_constant() const { + return false; + } + +//------------------------------------------------------------------------------ +/// @brief Test the constant node has a zero. +/// +/// @returns True the node has a zero constant value. 
+//------------------------------------------------------------------------------ + virtual bool has_constant_zero() const { + return false; + } + +//------------------------------------------------------------------------------ +/// @brief Test if the result is normal. /// -/// @returns True if the node acts like a constant. +/// @returns True if the node is normal. //------------------------------------------------------------------------------ - virtual bool is_constant_like() const = 0; + bool is_normal() { + return this->evaluate().is_normal(); + } //------------------------------------------------------------------------------ /// @brief Test if all the subnodes terminate in variables. @@ -187,7 +223,7 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Test if the node acts like a power of variable. /// -/// Most notes are not so default to false. +/// Most nodes are not so default to false. /// /// @returns True the node is power like and false otherwise. //------------------------------------------------------------------------------ @@ -258,7 +294,7 @@ namespace graph { /// Cache for the backend buffers. inline thread_local static std::map> backend_cache; - + /// Type def to retrieve the backend type. typedef T base; }; @@ -367,11 +403,13 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { registers[this] = jit::to_string('r', this); stream << " const "; @@ -382,7 +420,8 @@ namespace graph { if constexpr (jit::is_complex ()) { jit::add_type (stream); } - stream << temp << ";" << std::endl; + stream << temp << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -451,14 +490,23 @@ namespace graph { } //------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. +/// @brief Test if node is a constant. /// -/// @returns True if the node acts like a constant. +/// @returns True if the is a constant. //------------------------------------------------------------------------------ - virtual bool is_constant_like() const { + virtual bool is_constant() const { return true; } +//------------------------------------------------------------------------------ +/// @brief Test the constant node has a zero. +/// +/// @returns True the node has a zero constant value. +//------------------------------------------------------------------------------ + virtual bool has_constant_zero() const { + return data.has_zero(); + } + //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// @@ -700,15 +748,30 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. 
+/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { - this->arg->compile_preamble(stream, registers, visited); - visited[this] = 0; + this->arg->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); + visited.insert(this); + usage[this] = 1; + } else { + ++usage[this]; } } @@ -717,12 +780,14 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { - return this->arg->compile(stream, registers); + jit::register_map ®isters, + const jit::register_usage &usage) { + return this->arg->compile(stream, registers, usage); } //------------------------------------------------------------------------------ @@ -732,15 +797,6 @@ namespace graph { return this->arg; } -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. 
-//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return this->arg->is_constant_like(); - } - //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// @@ -816,17 +872,34 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { - this->left->compile_preamble(stream, registers, visited); - this->right->compile_preamble(stream, registers, visited); - visited[this] = 0; + this->left->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); + this->right->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); + visited.insert(this); + usage[this] = 1; + } else { + ++usage[this]; } } @@ -844,16 +917,6 @@ namespace graph { return this->right; } -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. -//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return this->left->is_constant_like() && - this->right->is_constant_like(); - } - //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// @@ -919,18 +982,38 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. 
+/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { - this->left->compile_preamble(stream, registers, visited); - this->middle->compile_preamble(stream, registers, visited); - this->right->compile_preamble(stream, registers, visited); - visited[this] = 0; + this->left->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); + this->middle->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); + this->right->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); + visited.insert(this); + usage[this] = 1; + } else { + ++usage[this]; } } @@ -941,17 +1024,6 @@ namespace graph { return this->middle; } -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. -//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return this->left->is_constant_like() && - this->middle->is_constant_like() && - this->right->is_constant_like(); - } - //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. 
/// @@ -1013,7 +1085,9 @@ namespace graph { variable_node(const size_t s, const T d, const std::string &symbol) : leaf_node (variable_node::to_string(this), 1, false), - buffer(s, d), symbol(symbol) {} + buffer(s, d), symbol(symbol) { + assert(buffer.is_normal() && "NaN or Inf value."); + } //------------------------------------------------------------------------------ /// @brief Construct a variable node from a vector. @@ -1024,7 +1098,9 @@ namespace graph { variable_node(const std::vector &d, const std::string &symbol) : leaf_node (variable_node::to_string(this), 1, false), - buffer(d), symbol(symbol) {} + buffer(d), symbol(symbol) { + assert(buffer.is_normal() && "NaN or Inf value."); + } //------------------------------------------------------------------------------ /// @brief Construct a variable node from backend buffer. @@ -1035,7 +1111,9 @@ namespace graph { variable_node(const backend::buffer &d, const std::string &symbol) : leaf_node (variable_node::to_string(this), 1, false), - buffer(d), symbol(symbol) {} + buffer(d), symbol(symbol) { + assert(buffer.is_normal() && "NaN or Inf value."); + } //------------------------------------------------------------------------------ /// @brief Evaluate method. @@ -1072,11 +1150,13 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { return this->shared_from_this(); } @@ -1176,15 +1256,6 @@ namespace graph { T *data() { return buffer.data(); } - -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. 
-/// -/// @returns True if the node acts like a constant. -//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return false; - } //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. @@ -1384,15 +1455,6 @@ namespace graph { std::cout << "\\right)"; } -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. -//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return false; - } - //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index f8a6069fa6ef8d494875aecf524bc7f18654444f..ac610b9c7c3608ee0c220656575a7427501fd63b 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -11,6 +11,30 @@ #include "node.hpp" namespace graph { +//------------------------------------------------------------------------------ +/// @brief Compile an index. +/// +/// @tparam T Base type of the calculation. +/// +/// @params[in,out] stream String buffer stream. +/// @params[in] register_name Reister for the argument. +/// @params[in] length Dimension length of argument. 
+//------------------------------------------------------------------------------ +template +void compile_index(std::ostringstream &stream, + const std::string ®ister_name, + const size_t length) { + stream << "min(max((unsigned int)"; + if constexpr (jit::is_complex ()) { + stream << "real("; + } + stream << register_name; + if constexpr (jit::is_complex ()) { + stream << ")"; + } + stream << ",0u)," << length - 1 << "u)"; +} + //****************************************************************************** // 1D Piecewise node. //****************************************************************************** @@ -155,38 +179,80 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { + this->arg->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); if (registers.find(leaf_node::backend_cache[data_hash].data()) == registers.end()) { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); + const size_t length = leaf_node::backend_cache[data_hash].size(); if constexpr (jit::use_metal ()) { - stream << "constant "; - } - stream << "const "; - jit::add_type (stream); - stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; - if constexpr (jit::is_complex ()) { + textures1d.try_emplace(leaf_node::backend_cache[data_hash].data(), + length); +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { + textures1d.try_emplace(leaf_node::backend_cache[data_hash].data(), + length); +#endif + } else { + if constexpr (jit::use_cuda()) { + const int buffer_size = length*sizeof(T); + if (avail_const_mem - buffer_size > 0) { + avail_const_mem -= buffer_size; + stream << "__constant__ "; + } + } + stream << "const "; jit::add_type (stream); - } - stream << leaf_node::backend_cache[data_hash][0]; - for (size_t i = 1, ie = leaf_node::backend_cache[data_hash].size(); - i < ie; i++) { - stream << ", "; + stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; if constexpr (jit::is_complex ()) { jit::add_type (stream); } - stream << leaf_node::backend_cache[data_hash][i]; + stream << leaf_node::backend_cache[data_hash][0]; + for (size_t i = 1; i < length; i++) { + stream << ", "; + if constexpr 
(jit::is_complex ()) { + jit::add_type (stream); + } + stream << leaf_node::backend_cache[data_hash][i]; + } + stream << "};" << std::endl; + } + } else { +// When using textures, the register can be defined in a previous kernel. We +// need to add the textures again. + const size_t length = leaf_node::backend_cache[data_hash].size(); + if constexpr (jit::use_metal ()) { + textures1d.try_emplace(leaf_node::backend_cache[data_hash].data(), + length); +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { + textures1d.try_emplace(leaf_node::backend_cache[data_hash].data(), + length); +#endif } - stream << "};" << std::endl; - visited[this] = 0; } + visited.insert(this); + usage[this] = 1; + } else { + ++usage[this]; } } @@ -207,29 +273,55 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); - stream << " " << registers[this] << " = " - << registers[leaf_node::backend_cache[data_hash].data()]; - stream << "[max(min((int)"; - if constexpr (jit::is_complex ()) { - stream << "real("; + stream << " " << registers[this] << " = "; +#ifdef USE_CUDA_TEXTURES + if constexpr (jit::use_cuda()) { + if constexpr (jit::is_float () && !jit::is_complex ()) { + stream << "tex1D ("; + } else if constexpr (jit::is_double () && !jit::is_complex ()) { + stream << "to_double(tex1D ("; + } else if constexpr (jit::is_float ()) { + 
stream << "to_cmp_float(tex1D ("; + } else { + stream << "to_cmp_double(tex1D ("; + } } - stream << registers[a.get()]; - if constexpr (jit::is_complex ()) { - stream << ")"; +#endif + stream << registers[leaf_node::backend_cache[data_hash].data()]; + const size_t length = leaf_node::backend_cache[data_hash].size(); + if constexpr (jit::use_metal ()) { + stream << ".read("; + compile_index (stream, registers[a.get()], length); + stream << ").r;"; +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { + stream << ", "; + compile_index (stream, registers[a.get()], length); + if constexpr (jit::is_complex () || jit::is_double ()) { + stream << ")"; + } + stream << ");"; +#endif + } else { + stream << "["; + compile_index (stream, registers[a.get()], length); + stream << "];"; } - stream << ", " - << leaf_node::backend_cache[data_hash].size() - 1 << "), 0)];" - << std::endl; + stream << " // used " << usage.at(this) <shared_from_this(); @@ -286,14 +378,23 @@ namespace graph { } //------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. +/// @brief Test if node is a constant. /// -/// @returns True if the node acts like a constant. +/// @returns True if the node is a constant. //------------------------------------------------------------------------------ - virtual bool is_constant_like() const { + virtual bool is_constant() const { return true; } +//------------------------------------------------------------------------------ +/// @brief Test the constant node has a zero. +/// +/// @returns True the node has a zero constant value. +//------------------------------------------------------------------------------ + virtual bool has_constant_zero() const { + return leaf_node::backend_cache[data_hash].has_zero(); + } + //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. 
/// @@ -338,7 +439,18 @@ namespace graph { //------------------------------------------------------------------------------ bool is_arg_match(shared_leaf x) { auto temp = piecewise_1D_cast(x); - return temp.get() && this->arg->is_match(temp->get_arg()); + return temp.get() && + this->arg->is_match(temp->get_arg()) && + (temp->get_size() == this->get_size()); + } + +//------------------------------------------------------------------------------ +/// @brief Get the size of the buffer. +/// +/// @returns The size of the buffer. +//------------------------------------------------------------------------------ + size_t get_size() const { + return leaf_node::backend_cache[data_hash].size(); } }; @@ -506,7 +618,7 @@ namespace graph { branch_node (x, y, piecewise_2D_node::to_string(d, x, y)), data_hash(piecewise_2D_node::hash_data(d)), num_columns(n) { - assert(d.size()/n && + assert(d.size()%n == 0 && "Expected the data buffer to be a multiple of the number of columns."); } @@ -519,6 +631,16 @@ namespace graph { return num_columns; } +//------------------------------------------------------------------------------ +/// @brief Get the number of columns. +/// +/// @returns The number of columns in the constant. +//------------------------------------------------------------------------------ + size_t get_num_rows() const { + return leaf_node::backend_cache[data_hash].size() / + num_columns; + } + //------------------------------------------------------------------------------ /// @brief Evaluate the results of the piecewise constant. /// @@ -561,36 +683,84 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. 
+/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { + this->left->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); + this->right->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d, + avail_const_mem); if (registers.find(leaf_node::backend_cache[data_hash].data()) == registers.end()) { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); + const size_t length = leaf_node::backend_cache[data_hash].size(); if constexpr (jit::use_metal ()) { - stream << "constant "; - } - stream << "const "; - jit::add_type (stream); - stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; - if constexpr (jit::is_complex ()) { + textures2d.try_emplace(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { + textures2d.try_emplace(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); +#endif + } else { + if constexpr (jit::use_cuda()) { + const int buffer_size = length*sizeof(T); + if (avail_const_mem - buffer_size > 0) { + avail_const_mem -= buffer_size; + stream << "__constant__ "; + } + } + stream << "const "; jit::add_type (stream); - } - stream << 
leaf_node::backend_cache[data_hash][0]; - for (size_t i = 1, ie = leaf_node::backend_cache[data_hash].size(); i < ie; i++) { - stream << ", "; + stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; if constexpr (jit::is_complex ()) { jit::add_type (stream); } - stream << leaf_node::backend_cache[data_hash][i]; + stream << leaf_node::backend_cache[data_hash][0]; + for (size_t i = 1; i < length; i++) { + stream << ", "; + if constexpr (jit::is_complex ()) { + jit::add_type (stream); + } + stream << leaf_node::backend_cache[data_hash][i]; + } + stream << "};" << std::endl; + } + } else { +// When using textures, the register can be defined in a previous kernel. We +// need to add the textures again. + const size_t length = leaf_node::backend_cache[data_hash].size(); + if constexpr (jit::use_metal ()) { + textures2d.try_emplace(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { + textures2d.try_emplace(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); +#endif } - stream << "};" << std::endl; } + visited.insert(this); + usage[this] = 1; + } else { + ++usage[this]; } } @@ -624,38 +794,77 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf x = this->left->compile(stream, registers); - shared_leaf y = this->right->compile(stream, registers); + shared_leaf x = this->left->compile(stream, + registers, + usage); + shared_leaf y = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); - stream << " " << registers[this] << " = " - << registers[leaf_node::backend_cache[data_hash].data()]; - stream << "[max(min((int)"; - if constexpr (jit::is_complex ()) { - stream << "real("; - } - stream << registers[x.get()]; - if constexpr (jit::is_complex ()) { - stream << ")"; - } - stream << "*" << num_columns << " + (int)"; - if constexpr (jit::is_complex ()) { - stream << "real("; + stream << " " << registers[this] << " = "; +#ifdef USE_CUDA_TEXTURES + if constexpr (jit::use_cuda()) { + if constexpr (jit::is_float () && !jit::is_complex ()) { + stream << "tex2D ("; + } else if constexpr (jit::is_double () && !jit::is_complex ()) { + stream << "to_double(tex2D ("; + } else if constexpr (jit::is_float ()) { + stream << "to_cmp_float(tex2D ("; + } else { + stream << "to_cmp_double(tex2D ("; + } } - stream << registers[y.get()]; - if constexpr (jit::is_complex ()) { - stream << ")"; +#endif + stream << registers[leaf_node::backend_cache[data_hash].data()]; + const size_t length = leaf_node::backend_cache[data_hash].size(); + const size_t num_rows = length/num_columns; + if constexpr (jit::use_metal ()) { + stream << ".read(uint2("; + compile_index (stream, registers[y.get()], num_columns); + stream << ","; + compile_index (stream, registers[x.get()], num_rows); + stream << ")).r;"; +#ifdef USE_CUDA_TEXTURES + } else if constexpr 
(jit::use_cuda()) { + stream << ", "; + compile_index (stream, registers[y.get()], num_columns); + stream << ", "; + compile_index (stream, registers[x.get()], num_rows); + if constexpr (jit::is_complex () || jit::is_double ()) { + stream << ")"; + } + stream << ");"; +#endif + } else { + stream << "[min(max((int)"; + if constexpr (jit::is_complex ()) { + stream << "real("; + } + stream << registers[x.get()]; + if constexpr (jit::is_complex ()) { + stream << ")"; + } + stream << "*" << num_columns << " + (int)"; + if constexpr (jit::is_complex ()) { + stream << "real("; + } + stream << registers[y.get()]; + if constexpr (jit::is_complex ()) { + stream << ")"; + } + stream << ",0), " << length - 1 << ")];"; } - stream << ", " - << leaf_node::backend_cache[data_hash].size() - 1 << "), 0)];" - << std::endl; + stream << " // used " << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -716,14 +925,23 @@ namespace graph { } //------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. +/// @brief Test if node is a constant. /// -/// @returns True if the node acts like a constant. +/// @returns True if the node is a constant. //------------------------------------------------------------------------------ - virtual bool is_constant_like() const { + virtual bool is_constant() const { return true; } +//------------------------------------------------------------------------------ +/// @brief Test the constant node has a zero. +/// +/// @returns True the node has a zero constant value. +//------------------------------------------------------------------------------ + virtual bool has_constant_zero() const { + return leaf_node::backend_cache[data_hash].has_zero(); + } + //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. 
/// @@ -763,15 +981,44 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Check if the args match. /// -/// @param[in] x Node to match. +/// @params[in] x Node to match. /// @returns True if the arguments match. //------------------------------------------------------------------------------ bool is_arg_match(shared_leaf x) { auto temp = piecewise_2D_cast(x); - return temp.get() && - this->left->is_match(temp->get_left()) && - this->right->is_match(temp->get_right()) && - (num_columns == this->get_num_columns()); + return temp.get() && + this->left->is_match(temp->get_left()) && + this->right->is_match(temp->get_right()) && + (temp->get_num_rows() == this->get_num_rows()) && + (temp->get_num_columns() == this->get_num_columns()); + } + +//------------------------------------------------------------------------------ +/// @brief Do the rows match. +/// +/// @params[in] x Node to match. +/// @returns True if the row arguments match. +//------------------------------------------------------------------------------ + bool is_row_match(shared_leaf x) { + auto temp = piecewise_1D_cast(x); + return temp.get() && + this->left->is_match(temp->get_arg()) && + (temp->get_size() == this->get_num_rows()); + } + +//------------------------------------------------------------------------------ +/// @brief Do the columns match. +/// +/// The number of rows is the column dimension. +/// +/// @params[in] x Node to match. +/// @returns True if the column arguments match. 
+//------------------------------------------------------------------------------ + bool is_col_match(shared_leaf x) { + auto temp = piecewise_1D_cast(x); + return temp.get() && + this->right->is_match(temp->get_arg()) && + (temp->get_size() == this->get_num_columns()); } }; diff --git a/graph_framework/register.hpp b/graph_framework/register.hpp index 8a3feb656a4018512b87bcccf73265ffd7ab080f..41395512525e393f8d0a77b88853ce68e93cc781 100644 --- a/graph_framework/register.hpp +++ b/graph_framework/register.hpp @@ -1,10 +1,7 @@ -// -// register.hpp -// graph_framework -// -// Created by Cianciosa, Mark on 12/8/22. -// Copyright © 2022 Cianciosa, Mark R. All rights reserved. -// +//------------------------------------------------------------------------------ +/// @file register.hpp +/// @brief Utilities for writting jit source code. +//------------------------------------------------------------------------------ #ifndef register_h #define register_h @@ -12,12 +9,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include namespace jit { /// Complex scalar concept. @@ -242,8 +241,14 @@ namespace jit { /// Type alias for mapping node pointers to register names. typedef std::map register_map; +/// Type alias for counting register usage. + typedef std::map register_usage; /// Type alias for listing visited nodes. - typedef std::map visiter_map; + typedef std::set visiter_map; +/// Type alias for indexing 1D textures. + typedef std::map texture1d_list; +/// Type alias for indexing 2D textures. + typedef std::map> texture2d_list; //------------------------------------------------------------------------------ /// @brief Define a custom comparitor class. 
diff --git a/graph_framework/trigonometry.hpp b/graph_framework/trigonometry.hpp index eefd020ddc50e7f526677cf48f3567b90ce58adf..8a5672d922dc37822f98039c71c72911becca647 100644 --- a/graph_framework/trigonometry.hpp +++ b/graph_framework/trigonometry.hpp @@ -126,19 +126,23 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = sin(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -364,20 +368,24 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = cos(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -569,6 +577,61 @@ namespace graph { return constant (this->evaluate()); } + auto pl1 = piecewise_1D_cast(this->left); + auto pr1 = piecewise_1D_cast(this->right); + + if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) { + return piecewise_1D(this->evaluate(), pl1->get_arg()); + } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) { + return piecewise_1D(this->evaluate(), pr1->get_arg()); + } + + auto pl2 = piecewise_2D_cast(this->left); + auto pr2 = piecewise_2D_cast(this->right); + + if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) { + return piecewise_2D(this->evaluate(), + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) { + return piecewise_2D(this->evaluate(), + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } + +// Combine 2D and 1D piecewise constants if a row or column matches. 
+ if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.atan_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.atan_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.atan_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.atan_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } + return this->shared_from_this(); } @@ -600,14 +663,20 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -621,7 +690,7 @@ namespace graph { << registers[r.get()] << "," << registers[l.get()]; } - stream << ");" << std::endl; + stream << "); // used " << usage.at(this) << std::endl; } return this->shared_from_this(); diff --git a/graph_framework/vector.hpp b/graph_framework/vector.hpp index edafc7e9d8a41f0deeea9bb017ecff303891d811..893ab00852c3761960873e5279a4cfebde126580 100644 --- a/graph_framework/vector.hpp +++ b/graph_framework/vector.hpp @@ -1,9 +1,6 @@ //------------------------------------------------------------------------------ -/// vector.hpp -/// graph_framework -/// -/// Created by Cianciosa, Mark R. on 3/31/22. -/// Copyright © 2022 Cianciosa, Mark R. All rights reserved. +/// @file vector.hpp +/// @brief Defines vectors of graphs. 
//------------------------------------------------------------------------------ #ifndef vector_h diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 124df716a5a776938bfd829a4943e9ab041960eb..b4e5e78722a0588292cb9f4da3104c859e405d8d 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -150,7 +150,7 @@ template void test_add() { // (c1*v1 + c2) + (c3*v1 + c4) -> c5*v1 + c6 auto var_e = graph::variable (1, ""); auto addfma1 = graph::fma(var_b, var_a, var_d) - + graph::fma(var_c, var_a, var_e); + + graph::fma(var_c, var_a, var_e); assert(graph::fma_cast(addfma1).get() && "Expected fused multiply add node."); // (v1*c1 + c2) + (v1*c3 + c4) -> c5*v1 + c6 @@ -250,20 +250,20 @@ template void test_add() { assert(muliply_divide_factor_cast4.get() && "Expected divide node."); // Test node properties. - assert(three->is_constant_like() && "Expected a constant."); + assert(three->is_constant() && "Expected a constant."); assert(!three->is_all_variables() && "Did not expect a variable."); assert(three->is_power_like() && "Expected a power like."); auto constant_add = three + graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), var_a); - assert(constant_add->is_constant_like() && "Expected a constant."); + assert(constant_add->is_constant() && "Expected a constant."); assert(!constant_add->is_all_variables() && "Did not expect a variable."); assert(constant_add->is_power_like() && "Expected a power like."); auto constant_var_add = three + var_a; - assert(!constant_var_add->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_add->is_constant() && "Did not expect a constant."); assert(!constant_var_add->is_all_variables() && "Did not expect a variable."); assert(!constant_var_add->is_power_like() && "Did not expect a power like."); auto var_var_add = var_a + variable; - assert(!var_var_add->is_constant_like() && "Did not expect a constant."); + 
assert(!var_var_add->is_constant() && "Did not expect a constant."); assert(var_var_add->is_all_variables() && "Expected a variable."); assert(!var_var_add->is_power_like() && "Did not expect a power like."); } @@ -553,20 +553,20 @@ template void test_subtract() { "Expected a fused multiply add node on the left."); // Test node properties. - assert(zero->is_constant_like() && "Expected a constant."); + assert(zero->is_constant() && "Expected a constant."); assert(!zero->is_all_variables() && "Did not expect a variable."); assert(zero->is_power_like() && "Expected a power like."); auto constant_sub = one - graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), var_a); - assert(constant_sub->is_constant_like() && "Expected a constant."); + assert(constant_sub->is_constant() && "Expected a constant."); assert(!constant_sub->is_all_variables() && "Did not expect a variable."); assert(constant_sub->is_power_like() && "Expected a power like."); auto constant_var_sub = one - var_a; - assert(!constant_var_sub->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_sub->is_constant() && "Did not expect a constant."); assert(!constant_var_sub->is_all_variables() && "Did not expect a variable."); assert(!constant_var_sub->is_power_like() && "Did not expect a power like."); auto var_var_sub = var_a - var_b; - assert(!var_var_sub->is_constant_like() && "Did not expect a constant."); + assert(!var_var_sub->is_constant() && "Did not expect a constant."); assert(var_var_sub->is_all_variables() && "Expected a variable."); assert(!var_var_sub->is_power_like() && "Did not expect a power like."); @@ -600,6 +600,11 @@ template void test_subtract() { auto factor4 = var_b - (var_b*var_a); assert(graph::multiply_cast(factor4).get() && "Expected a multiply node."); + +// -1*a - b -> -1*(a + b) + auto neg_vara_minus_varb = (graph::none ()*var_a) - var_b; + assert(graph::multiply_cast(neg_vara_minus_varb).get() && + "Expected a multiply node."); } 
//------------------------------------------------------------------------------ @@ -1029,20 +1034,20 @@ template void test_multiply() { "Expected a divide node."); // Test node properties. - assert(two_times_three->is_constant_like() && "Expected a constant."); + assert(two_times_three->is_constant() && "Expected a constant."); assert(!two_times_three->is_all_variables() && "Did not expect a variable."); assert(two_times_three->is_power_like() && "Expected a power like."); auto constant_mul = three*graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), variable); - assert(constant_mul->is_constant_like() && "Expected a constant."); + assert(constant_mul->is_constant() && "Expected a constant."); assert(!constant_mul->is_all_variables() && "Did not expect a variable."); assert(constant_mul->is_power_like() && "Expected a power like."); auto constant_var_mul = three*variable; - assert(!constant_var_mul->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_mul->is_constant() && "Did not expect a constant."); assert(!constant_var_mul->is_all_variables() && "Did not expect a variable."); assert(!constant_var_mul->is_power_like() && "Did not expect a power like."); auto var_var_mul = variable*a; - assert(!var_var_mul->is_constant_like() && "Did not expect a constant."); + assert(!var_var_mul->is_constant() && "Did not expect a constant."); assert(var_var_mul->is_all_variables() && "Expected a variable."); assert(!var_var_mul->is_power_like() && "Did not expect a power like."); @@ -1781,8 +1786,8 @@ template void test_divide() { assert(fma_divide_cast2.get() && "Expected an fma node."); // fma(d,a,c*d)/d -> a + c auto fma_divide3 = graph::fma(a, - graph::variable (1, ""), - graph::variable (1, "")*a)/a; + graph::variable (1, ""), + graph::variable (1, "")*a)/a; auto fma_divide_cast3 = graph::add_cast(fma_divide3); assert(fma_divide_cast3.get() && "Expected an add node."); // fma(d,a,c*d)/d -> a + c @@ -1792,6 +1797,15 @@ 
template void test_divide() { auto fma_divide_cast4 = graph::add_cast(fma_divide4); assert(fma_divide_cast4.get() && "Expected an add node."); +// fma(a,b,a)/a -> 1 + b + auto fma_divide5 = graph::fma(a, graph::variable (1, ""), a)/a; + auto fma_divide5_cast = graph::add_cast(fma_divide5); + assert(fma_divide5_cast.get() && "Expected an add node."); +// fma(b,a,a)/a -> 1 + b + auto fma_divide6 = graph::fma(graph::variable (1, ""), a, a)/a; + auto fma_divide6_cast = graph::add_cast(fma_divide6); + assert(fma_divide6_cast.get() && "Expected an add node."); + // (a*b^c)/b^d -> a*b^(c - d) auto common_power = (variable*graph::pow(a, three))/graph::pow(a, two); assert(graph::multiply_cast(common_power).get() && @@ -1802,20 +1816,20 @@ template void test_divide() { "Expected a multiply node."); // Test node properties. - assert(two_divided_three->is_constant_like() && "Expected a constant."); + assert(two_divided_three->is_constant() && "Expected a constant."); assert(!two_divided_three->is_all_variables() && "Did not expect a variable."); assert(two_divided_three->is_power_like() && "Expected a power like."); auto constant_div = two_divided_three/graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), variable); - assert(constant_div->is_constant_like() && "Expected a constant."); + assert(constant_div->is_constant() && "Expected a constant."); assert(!constant_div->is_all_variables() && "Did not expect a variable."); assert(constant_div->is_power_like() && "Expected a power like."); auto constant_var_div = two_divided_three/variable; - assert(!constant_var_div->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_div->is_constant() && "Did not expect a constant."); assert(!constant_var_div->is_all_variables() && "Did not expect a variable."); assert(!constant_var_div->is_power_like() && "Did not expect a power like."); auto var_var_div = variable/a; - assert(!var_var_div->is_constant_like() && "Did not expect a constant."); + 
assert(!var_var_div->is_constant() && "Did not expect a constant."); assert(var_var_div->is_all_variables() && "Expected a variable."); assert(!var_var_div->is_power_like() && "Did not expect a power like."); @@ -2030,6 +2044,28 @@ template void test_fma() { auto var_b = graph::variable (1, ""); auto var_c = graph::variable (1, ""); +// fma(1,a,b) = a + b + auto one_times_vara_plus_varb = graph::fma(one, var_a, var_b); + auto one_times_vara_plus_varb_cast = + graph::add_cast(one_times_vara_plus_varb); + assert(one_times_vara_plus_varb_cast.get() && "Expected an add node."); + +// fma(a,1,b) = a + b + auto vara_times_one_plus_varb = graph::fma(var_a, one, var_b); + auto vara_times_one_plus_varb_cast = + graph::add_cast(vara_times_one_plus_varb); + assert(vara_times_one_plus_varb_cast.get() && "Expected an add node."); + +// fma(b,a,a) = a*(1 + b) + auto common1 = graph::fma(var_a, var_b, var_a); + auto common1_cast = graph::multiply_cast(common1); + assert(common1_cast.get() && "Expected multiply node."); +// fma(b,a,a) = a*(1 + b) + auto common2 = graph::fma(var_b, var_a, var_a); + auto common2_cast = graph::multiply_cast(common2); + assert(common2_cast.get() && "Expected multiply node."); + assert(common1->is_match(common2) && "Expected same graph"); + auto reduce1 = graph::fma(var_a, var_b, var_a*var_c); auto reduce1_cast = graph::multiply_cast(reduce1); assert(reduce1_cast.get() && "Expected multiply node."); @@ -2057,6 +2093,264 @@ template void test_fma() { assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() && "Expected multiply node."); +// fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) + auto var_d = graph::variable (1, ""); + auto match1 = graph::fma(var_b, var_a + var_c, var_d); + auto nested_fma1 = graph::fma(var_a, var_b, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma1->is_match(match1) && "Expected match."); +// fma(b, a, fma(c, b, d)) -> fma(b, a + c, d) + auto nested_fma2 = graph::fma(var_b, var_a, + graph::fma(var_c, var_b, var_d)); 
+ assert(nested_fma2->is_match(match1) && "Expected match."); +// fma(a, b, fma(b, c, d)) -> fma(b, a + c, d) + auto nested_fma3 = graph::fma(var_a, var_b, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma3->is_match(match1) && "Expected match."); +// fma(b, a, fma(b, c, d)) -> fma(b, a + c, d) + auto nested_fma4 = graph::fma(var_b, var_a, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma4->is_match(match1) && "Expected match."); + +// fma(a, e*b, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto var_e = graph::variable (1, ""); + auto match2 = graph::fma(var_b, graph::fma(var_a, var_e, var_c), var_d); + auto nested_fma5 = graph::fma(var_a, + var_e*var_b, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma5->is_match(match2) && "Expected match."); +// fma(a, b*e, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma6 = graph::fma(var_a, + var_b*var_e, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma6->is_match(match2) && "Expected match."); + // fma(a, e*b, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma7 = graph::fma(var_a, + var_e*var_b, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma7->is_match(match2) && "Expected match."); +// fma(a, b*e, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma8 = graph::fma(var_a, + var_b*var_e, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma8->is_match(match2) && "Expected match."); + +// fma(e*b, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma9 = graph::fma(var_e*var_b, + var_a, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma9->is_match(match2) && "Expected match."); +// fma(b*e, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma10 = graph::fma(var_b*var_e, + var_a, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma10->is_match(match2) && "Expected match."); +// fma(e*b, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma11 = graph::fma(var_e*var_b, + var_a, + graph::fma(var_b, var_c, var_d)); + 
assert(nested_fma11->is_match(match2) && "Expected match."); +// fma(e*d, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma12 = graph::fma(var_a, + var_b*var_e, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma12->is_match(match2) && "Expected match."); + +// fma(a, b, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d) + auto match3 = graph::fma(var_b, graph::fma(var_c, var_e, var_a), var_d); + auto nested_fma13 = graph::fma(var_a, + var_b, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma13->is_match(match3) && "Expected match."); +// fma(b, a, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma14 = graph::fma(var_b, + var_a, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma14->is_match(match3) && "Expected match."); +// fma(a, b, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma15 = graph::fma(var_a, + var_b, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma15->is_match(match3) && "Expected match."); +// fma(b, a, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma16 = graph::fma(var_b, + var_a, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma16->is_match(match3) && "Expected match."); +// fma(a, b, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma17 = graph::fma(var_a, + var_b, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma17->is_match(match3) && "Expected match."); +// fma(b, a, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma18 = graph::fma(var_b, + var_a, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma18->is_match(match3) && "Expected match."); +// fma(a, b, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma19 = graph::fma(var_a, + var_b, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma19->is_match(match3) && "Expected match."); +// fma(b, a, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma20 = graph::fma(var_b, + var_a, + graph::fma(var_b*var_e, var_c, var_d)); + 
assert(nested_fma20->is_match(match3) && "Expected match."); + +// fma(a, f*b, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) + auto var_f = graph::variable (1, ""); + auto match4 = graph::fma(var_b, graph::fma(var_a, var_f, var_c*var_e), var_d); + auto nested_fma21 = graph::fma(var_a, + var_f*var_b, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma21->is_match(match4) && "Expected match."); +// fma(a, b*f, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma22 = graph::fma(var_a, + var_b*var_f, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma22->is_match(match4) && "Expected match."); +// fma(a, f*b, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma23 = graph::fma(var_a, + var_f*var_b, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma23->is_match(match4) && "Expected match."); +// fma(a, b*f, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma24 = graph::fma(var_a, + var_b*var_f, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma24->is_match(match4) && "Expected match."); +// fma(f*b, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma25 = graph::fma(var_f*var_b, + var_a, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma25->is_match(match4) && "Expected match."); +// fma(b*f, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma26 = graph::fma(var_b*var_f, + var_a, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma26->is_match(match4) && "Expected match."); +// fma(f*b, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma27 = graph::fma(var_f*var_b, + var_a, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma27->is_match(match4) && "Expected match."); +// fma(b*f, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma28 = graph::fma(var_b*var_f, + var_a, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma28->is_match(match4) && "Expected match."); +// fma(a, f*b, fma(e*b, c, d)) -> fma(b, 
fma(a, f, c*e), d) + auto nested_fma29 = graph::fma(var_a, + var_f*var_b, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma29->is_match(match4) && "Expected match."); +// fma(a, b*f, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma30 = graph::fma(var_a, + var_b*var_f, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma30->is_match(match4) && "Expected match."); +// fma(a, f*b, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma31= graph::fma(var_a, + var_f*var_b, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma31->is_match(match4) && "Expected match."); +// fma(a, b*f, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma32 = graph::fma(var_a, + var_b*var_f, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma32->is_match(match4) && "Expected match."); +// fma(f*b, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma33 = graph::fma(var_f*var_b, + var_a, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma33->is_match(match4) && "Expected match."); +// fma(b*f, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma34 = graph::fma(var_b*var_f, + var_a, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma34->is_match(match4) && "Expected match."); +// fma(f*b, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma35 = graph::fma(var_f*var_b, + var_a, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma35->is_match(match4) && "Expected match."); +// fma(b*f, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma36 = graph::fma(var_b*var_f, + var_a, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma36->is_match(match4) && "Expected match."); + +// fma(a^b,a^c,d) -> a^(b+c) +d + assert(graph::fma(graph::pow(var_a, var_b), + graph::pow(var_a, var_c), + var_d)->is_match(graph::pow(var_a, + var_b + var_c) + var_d) && + "Expected match"); + +// fma(a,x^b,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + auto 
matchv1 = graph::fma(graph::pow(var_b, two), + fma(var_b, var_a, var_c), + var_d); + auto matchv2 = graph::fma(graph::pow(var_b, two), + fma(var_b, var_c, var_a), + var_d); + auto nested_fmav1 = graph::fma(var_a, + graph::pow(var_b, three), + fma(var_c, + graph::pow(var_b, two), + var_d)); + assert(nested_fmav1->is_match(matchv1) && "Expected match"); +// fma(a,x^b,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + auto nested_fmav2 = graph::fma(var_a, + graph::pow(var_b, two), + fma(var_c, + graph::pow(var_b, three), + var_d)); + assert(nested_fmav2->is_match(matchv2) && "Expected match"); +// fma(x^b,a,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + auto nested_fmav3 = graph::fma(graph::pow(var_b, three), + var_a, + fma(var_c, + graph::pow(var_b, two), + var_d)); + assert(nested_fmav3->is_match(matchv1) && "Expected match"); +// fma(x^b,a,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + auto nested_fmav4 = graph::fma(graph::pow(var_b, two), + var_a, + fma(var_c, + graph::pow(var_b, three), + var_d)); + assert(nested_fmav4->is_match(matchv2) && "Expected match"); +// fma(a,x^b,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + auto nested_fmav5 = graph::fma(var_a, + graph::pow(var_b, three), + fma(graph::pow(var_b, two), + var_c, + var_d)); + assert(nested_fmav5->is_match(matchv1) && "Expected match"); +// fma(a,x^b,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + auto nested_fmav6 = graph::fma(var_a, + graph::pow(var_b, two), + fma(graph::pow(var_b, three), + var_c, + var_d)); + assert(nested_fmav6->is_match(matchv2) && "Expected match"); +// fma(x^b,a,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + auto nested_fmav7 = graph::fma(graph::pow(var_b, three), + var_a, + fma(graph::pow(var_b, two), + var_c, + var_d)); + assert(nested_fmav7->is_match(matchv1) && "Expected match"); +// fma(x^b,a,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + auto nested_fmav8 = graph::fma(graph::pow(var_b, two), + var_a, + 
fma(graph::pow(var_b, three), + var_c, + var_d)); + assert(nested_fmav8->is_match(matchv2) && "Expected match"); + // fma(a, b, a*b) -> 2*a*b // fma(b, a, a*b) -> 2*a*b // fma(a, b, b*a) -> 2*a*b @@ -2109,8 +2403,6 @@ template void test_fma() { "Expected constant node."); // fma(a,b/c,fma(d,e/c,g)) -> (a*b + d*e)/c + g - auto var_d = graph::variable (1, ""); - auto var_e = graph::variable (1, ""); auto chained_fma3 = fma(var_a, var_b/var_c, fma(var_d, var_e/var_c, var)); assert(add_cast(chained_fma3).get() && "expected add node."); // fma(a,b/c,fma(e/c,f,g)) -> (a*b + e*f)/c + g @@ -2153,7 +2445,7 @@ template void test_fma() { "Expetced a divide node."); // Test node properties. - assert(one_two_three->is_constant_like() && "Expected a constant."); + assert(one_two_three->is_constant() && "Expected a constant."); assert(!one_two_three->is_all_variables() && "Did not expect a variable."); assert(one_two_three->is_power_like() && "Expected a power like."); auto constant_fma = graph::fma(one_two_three, @@ -2164,11 +2456,11 @@ template void test_fma() { assert(!constant_fma->is_all_variables() && "Did not expect a variable."); assert(constant_fma->is_power_like() && "Expected a power like."); auto constant_var_fma = graph::fma(var_a, var_b, one); - assert(!constant_var_fma->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_fma->is_constant() && "Did not expect a constant."); assert(!constant_var_fma->is_all_variables() && "Did not expect a variable."); assert(!constant_var_fma->is_power_like() && "Did not expect a power like."); auto var_var_fma = graph::fma(var_a, var_b, var_c); - assert(!var_var_fma->is_constant_like() && "Did not expect a constant."); + assert(!var_var_fma->is_constant() && "Did not expect a constant."); assert(var_var_fma->is_all_variables() && "Expected a variable."); assert(!var_var_fma->is_power_like() && "Did not expect a power like."); @@ -2254,7 +2546,6 @@ template void test_fma() { // fma(a/c, b, d*((f/c)*e)) -> 
fma(a, b, f*e*d)/c // fma(a, b/c, d*(e*(f/c))) -> fma(a, b, f*e*d)/c // fma(a/c, b, d*(e*(f/c))) -> fma(a, b, f*e*d)/c - auto var_f = graph::variable (1, ""); auto exp_a = (one + var_a); auto exp_b = (one + var_b); auto exp_c = (one + var_c); @@ -2442,6 +2733,23 @@ template void test_fma() { assert(fmaexp21_cast.get() && "Expected an add node."); assert(graph::divide_cast(fmaexp21_cast->get_left()).get() && "Expected a dive node on the left."); + +// fma(p2,p1,a) -> fma(p1,p2,a) + auto p1 = graph::piecewise_1D (std::vector ({static_cast (1.0), + static_cast (2.0)}), + var_a); + auto p2 = graph::piecewise_2D (std::vector ({static_cast (1.0), + static_cast (2.0), + static_cast (3.0), + static_cast (4.0)}), + 2, var_b, var_c); + auto fma_promote = graph::fma(p2, p1, var_a); + auto fma_promote_cast = graph::fma_cast(fma_promote); + assert(fma_promote_cast.get() && "Expected a fma node."); + assert(graph::piecewise_1D_cast(fma_promote_cast->get_left()).get() && + "Expected a piecewise 1d node on the left."); + assert(graph::piecewise_2D_cast(fma_promote_cast->get_middle()).get() && + "Expected a piecewise 2d node in the middle."); } //------------------------------------------------------------------------------ diff --git a/graph_tests/backend_test.cpp b/graph_tests/backend_test.cpp index 79224c2a0451f41f64b5ceb6ee20663fe7833975..b4e303d827b32c28251dd8890b011929c5257e3f 100644 --- a/graph_tests/backend_test.cpp +++ b/graph_tests/backend_test.cpp @@ -543,6 +543,30 @@ template void test_backend() { static_cast (2.0) })); assert(!base_vec.is_negative() && "Expected false."); + + backend::buffer has_zero_vec(std::vector ({ + static_cast (3.0), + static_cast (0.0) + })); + assert(has_zero_vec.has_zero() && "Expected zero."); + backend::buffer has_zero_vec2(std::vector ({ + static_cast (3.0), + static_cast (1.0) + })); + assert(!has_zero_vec2.has_zero() && "Expected zero."); + assert(has_zero_vec2.is_normal() && "Expected normal."); + + backend::buffer inf_vec(std::vector ({ 
+ static_cast (3.0), + static_cast (INFINITY) + })); + assert(!inf_vec.is_normal() && "Expected a inf."); + + backend::buffer nan_vec(std::vector ({ + static_cast (3.0), + static_cast (NAN) + })); + assert(!nan_vec.is_normal() && "Expected a NaN."); } //------------------------------------------------------------------------------ diff --git a/graph_tests/math_test.cpp b/graph_tests/math_test.cpp index 0c5440bb381256d41240059ec0f504fb43ea9bf8..6bca69f6bf8a619e1d0d2ef3a270bfbc5245228d 100644 --- a/graph_tests/math_test.cpp +++ b/graph_tests/math_test.cpp @@ -110,10 +110,10 @@ void test_sqrt() { // Test node properties. auto sqrt_const = graph::sqrt(graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), var)); - assert(sqrt_const->is_constant_like() && "Expected a constant."); + assert(sqrt_const->is_constant() && "Expected a constant."); assert(!sqrt_const->is_all_variables() && "Did not expect a variable."); assert(sqrt_const->is_power_like() && "Expected a power like."); - assert(!sqrt_var->is_constant_like() && "Did not expect a constant."); + assert(!sqrt_var->is_constant() && "Did not expect a constant."); assert(sqrt_var->is_all_variables() && "Expected a variable."); assert(sqrt_var->is_power_like() && "Expected a power like."); } @@ -152,7 +152,7 @@ void test_exp() { assert(dexp_var->evaluate().at(0) == std::exp(static_cast (3.0))); // Test node properties. 
- assert(!exp_var->is_constant_like() && "Did not expect a constant."); + assert(!exp_var->is_constant() && "Did not expect a constant."); assert(exp_var->is_all_variables() && "Expected a variable."); assert(!exp_var->is_power_like() && "Did not expect a power like."); } @@ -239,7 +239,7 @@ void test_pow() { assert(sqrd_neg->evaluate().at(0) == static_cast (non_int_neg*non_int_neg) && "Expected x*x"); - auto three = graph::two (); + auto three = graph::constant (static_cast (3)); auto pow_pow1 = graph::pow(graph::pow(ten, three), two); auto pow_pow2 = graph::pow(ten, three*two); assert(pow_pow1->is_match(pow_pow2) && @@ -273,16 +273,16 @@ void test_pow() { auto var_a = graph::variable (1, ""); auto pow_const = graph::pow(three, graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), var_a)); - assert(pow_const->is_constant_like() && "Expected a constant."); + assert(pow_const->is_constant() && "Expected a constant."); assert(!pow_const->is_all_variables() && "Did not expect a variable."); assert(pow_const->is_power_like() && "Expected a power like."); auto pow_var = graph::pow(var_a, three); - assert(!pow_var->is_constant_like() && "Did not expect a constant."); + assert(!pow_var->is_constant() && "Did not expect a constant."); assert(pow_var->is_all_variables() && "Expected a variable."); assert(pow_var->is_power_like() && "Expected a power like."); auto var_b = graph::variable (1, ""); auto pow_var_var = graph::pow(var_a, var_b); - assert(!pow_var->is_constant_like() && "Did not expect a constant."); + assert(!pow_var->is_constant() && "Did not expect a constant."); assert(pow_var->is_all_variables() && "Expected a variable."); assert(pow_var->is_power_like() && "Expected a power like."); @@ -320,6 +320,10 @@ void test_pow() { auto powexp_float_cast = graph::pow_cast(powexp_float); assert(powexp_float_cast.get() && "Expected power cast."); + +// c1^c2 + assert(graph::constant_cast(graph::pow(two, three)).get() && + "Expected a constant 
node."); } //------------------------------------------------------------------------------ @@ -340,7 +344,7 @@ void test_log() { auto dlogy = logy->df(y); assert(graph::divide_cast(dlogy) && "Expected divide node."); - assert(!logy->is_constant_like() && "Did not expect a constant."); + assert(!logy->is_constant() && "Did not expect a constant."); assert(logy->is_all_variables() && "Expected a variable."); assert(!logy->is_power_like() && "Did not expect a power like."); } @@ -367,7 +371,7 @@ void test_erfi() { "Expected a constant node."); // Test node properties. - assert(!erfi->is_constant_like() && "Did not expect a constant."); + assert(!erfi->is_constant() && "Did not expect a constant."); assert(erfi->is_all_variables() && "Expected a variable."); assert(!erfi->is_power_like() && "Did not expect a power like."); } diff --git a/graph_tests/node_test.cpp b/graph_tests/node_test.cpp index 0eb9a687c5062ae0d61d76d89b73ad90e96fffb2..a0bdc78b613f73c963d4f67a5735446ccf4c4176 100644 --- a/graph_tests/node_test.cpp +++ b/graph_tests/node_test.cpp @@ -60,7 +60,7 @@ void test_constant() { assert(c1->is_match(c2) && "Expected match."); // Test node properties. - assert(c1->is_constant_like() && "Expected a constant."); + assert(c1->is_constant() && "Expected a constant."); assert(!c1->is_all_variables() && "Did not expect a variable."); assert(c1->is_power_like() && "Expected a power like."); } @@ -124,7 +124,7 @@ void test_variable() { assert(!v1->is_match(v2) && "Expected no match."); // Test node properties. - assert(!v1->is_constant_like() && "Did not expect a constant."); + assert(!v1->is_constant() && "Did not expect a constant."); assert(v1->is_all_variables() && "Expected a variable."); assert(v1->is_power_like() && "Expected a power like."); } @@ -157,7 +157,7 @@ void test_pseudo_variable() { "Expected constant node."); // Test node properties. 
- assert(!c->is_constant_like() && "Did not expect a constant."); + assert(!c->is_constant() && "Did not expect a constant."); assert(c->is_all_variables() && "Expected a variable."); assert(c->is_power_like() && "Expected a power like."); } diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index fa8b4dad39442b7140a74bc8022a9374f82c9893..222286d89bfa86e2b6d29cf25c8f56e4afc2dce5 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -8,6 +8,7 @@ #undef NDEBUG #endif +#include #include #include "../graph_framework/arithmetic.hpp" @@ -27,7 +28,7 @@ /// @params[in] tolarance Test tolarance. //------------------------------------------------------------------------------ template void check(const T test, - const T tolarance) { + const T tolarance) { if constexpr (jit::is_complex ()) { assert(std::real(test) <= std::real(tolarance) && "Real GPU and CPU values differ."); @@ -84,6 +85,10 @@ template void piecewise_1D() { auto p2 = graph::piecewise_1D (std::vector ({static_cast (2.0), static_cast (4.0), static_cast (6.0)}), b); + auto p3 = graph::piecewise_1D (std::vector ({static_cast (2.0), + static_cast (4.0), + static_cast (6.0)}), a); + auto zero = graph::zero (); assert(graph::constant_cast(p1*zero).get() && @@ -95,6 +100,8 @@ template void piecewise_1D() { "Expected a piecewise_1D node."); assert(graph::multiply_cast(p1*p2).get() && "Expected a multiply node."); + assert(graph::piecewise_1D_cast(p1*p3).get() && + "Expected a piecewise_1D node."); assert(graph::piecewise_1D_cast(p1 + zero).get() && "Expected a piecewise_1D node."); @@ -102,6 +109,8 @@ template void piecewise_1D() { "Expected a piecewise_1D node."); assert(graph::add_cast(p1 + p2).get() && "Expected an add node."); + assert(graph::piecewise_1D_cast(p1 + p3).get() && + "Expected a piecewise_1D node."); assert(graph::piecewise_1D_cast(p1 - zero).get() && "Expected a piecewise_1D node."); @@ -109,20 +118,31 @@ template void piecewise_1D() { "Expected 
a piecewise_1D node."); assert(graph::subtract_cast(p1 - p2).get() && "Expected a subtract node."); + assert(graph::piecewise_1D_cast(p1 - p3).get() && + "Expected a piecewise_1D node."); assert(graph::constant_cast(zero/p1).get() && "Expected a constant node."); assert(graph::piecewise_1D_cast(p1/two).get() && "Expected a piecewise_1D node."); - assert(graph::divide_cast(p1/p2).get() && - "Expected a divide node."); + assert(graph::multiply_cast(p1/p2).get() && + "Expected a multiply node."); + assert(graph::constant_cast(p1/p3).get() && + "Expected a constant node."); assert(graph::piecewise_1D_cast(graph::fma(p1, two, zero)).get() && "Expected a piecewise_1D node."); assert(graph::add_cast(graph::fma(p1, two, p2)).get() && "Expected an add node."); - assert(graph::fma_cast(graph::fma(p1, p2, two)).get() && - "Expected a fma node."); + auto temp = graph::fma(p1, p2, two); + assert(graph::multiply_cast(graph::fma(p1, p2, two)).get() && + "Expected a multiply node."); + assert(graph::add_cast(graph::fma(p1, p3, p2)).get() && + "Expected an add node."); + assert(graph::piecewise_1D_cast(graph::fma(p1, p3, two)).get() && + "Expected a piecewise_1D node."); + assert(graph::piecewise_1D_cast(graph::fma(p1, p3, p1)).get() && + "Expected a piecewise_1D node."); assert(graph::piecewise_1D_cast(graph::sqrt(p1)).get() && "Expected a piecewise_1D node."); @@ -137,6 +157,8 @@ template void piecewise_1D() { "Expected a piecewise_1D node."); assert(graph::pow_cast(graph::pow(p1, p2)).get() && "Expected a pow constant."); + assert(graph::piecewise_1D_cast(graph::pow(p1, p3)).get() && + "Expected a piecewise_1D node."); assert(graph::piecewise_1D_cast(graph::sin(p1)).get() && "Expected a piecewise_1D node."); @@ -147,10 +169,12 @@ template void piecewise_1D() { assert(graph::piecewise_1D_cast(graph::tan(p1)).get() && "Expected a piecewise_1D node."); - assert(graph::atan_cast(graph::atan(p1, two)).get() && - "Expected an atan node."); + 
assert(graph::piecewise_1D_cast(graph::atan(p1, two)).get() && + "Expected a piecewise_1D node."); assert(graph::atan_cast(graph::atan(p1, p2)).get() && - "Expected a atan constant."); + "Expected an atan node."); + assert(graph::constant_cast(graph::atan(p1, p3)).get() && + "Expected a constant node."); a->set(static_cast (1.5)); compile ({graph::variable_cast(a)}, @@ -166,7 +190,42 @@ template void piecewise_1D() { compile ({graph::variable_cast(a)}, {p1}, {}, static_cast (3.0), 0.0); - + + a->set(static_cast (1.5)); + compile ({graph::variable_cast(a)}, + {p1 + p3}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(a)}, + {p1 - p3}, {}, + static_cast (-2.0), 0.0); + compile ({graph::variable_cast(a)}, + {p1*p3}, {}, + static_cast (8.0), 0.0); + compile ({graph::variable_cast(a)}, + {p1/p3}, {}, + static_cast (0.5), 0.0); + compile ({graph::variable_cast(a), + graph::variable_cast(b)}, + {graph::fma(p1, p3, p2)}, {}, + static_cast (10.0), 0.0); + compile ({graph::variable_cast(a)}, + {graph::pow(p1, p3)}, {}, + static_cast (std::pow(static_cast (2.0), + static_cast (4.0))), 0.0); + if constexpr (jit::is_complex ()) { + compile ({graph::variable_cast(a)}, + {graph::atan(p1, p3)}, {}, + static_cast (std::atan(static_cast (4.0) / + static_cast (2.0))), + 0.0); + } else { + compile ({graph::variable_cast(a)}, + {graph::atan(p1, p3)}, {}, + static_cast (std::atan2(static_cast (4.0), + static_cast (2.0))), + 0.0); + } + auto pc = graph::piecewise_1D (std::vector ({static_cast (10.0), static_cast (10.0), static_cast (10.0)}), a); @@ -184,16 +243,24 @@ template void piecewise_2D() { auto ay = graph::variable (1, ""); auto bx = graph::variable (1, ""); auto by = graph::variable (1, ""); - auto p1 = graph::piecewise_2D (std::vector ({static_cast (1.0), - static_cast (2.0), - static_cast (3.0), - static_cast (4.0)}), - 2, ax, ay); - auto p2 = graph::piecewise_2D (std::vector ({static_cast (2.0), - static_cast (4.0), - static_cast (6.0), - static_cast (10.0)}), 
- 2, bx, by); + auto p1 = graph::piecewise_2D (std::vector ({ + static_cast (1.0), static_cast (2.0), + static_cast (3.0), static_cast (4.0) + }), 2, ax, ay); + auto p2 = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, bx, by); + auto p3 = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, ax, ay); + auto p4 = graph::piecewise_1D (std::vector ({ + static_cast (2.0), static_cast (4.0) + }), ax); + auto p5 = graph::piecewise_1D (std::vector ({ + static_cast (2.0), static_cast (4.0) + }), ay); auto zero = graph::zero (); @@ -206,6 +273,12 @@ template void piecewise_2D() { "Expected a piecewise_2D node."); assert(graph::multiply_cast(p1*p2).get() && "Expected a multiply node."); + assert(graph::piecewise_2D_cast(p1*p3).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1*p4).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1*p5).get() && + "Expected a piecewise_2D node."); assert(graph::piecewise_2D_cast(p1 + zero).get() && "Expected a piecewise_2D node."); @@ -213,6 +286,12 @@ template void piecewise_2D() { "Expected a piecewise_2D node."); assert(graph::add_cast(p1 + p2).get() && "Expected an add node."); + assert(graph::piecewise_2D_cast(p1 + p3).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1 + p4).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1 + p5).get() && + "Expected a piecewise_2D node."); assert(graph::piecewise_2D_cast(p1 - zero).get() && "Expected a piecewise_2D node."); @@ -220,20 +299,42 @@ template void piecewise_2D() { "Expected a piecewise_2D node."); assert(graph::subtract_cast(p1 - p2).get() && "Expected a subtract node."); + assert(graph::piecewise_2D_cast(p1 - p3).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1 - p4).get() && + "Expected a 
piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1 - p5).get() && + "Expected a piecewise_2D node."); assert(graph::constant_cast(zero/p1).get() && "Expected a constant node."); assert(graph::piecewise_2D_cast(p1/two).get() && "Expected a piecewise_2D node."); - assert(graph::divide_cast(p1/p2).get() && - "Expected a divide node."); + assert(graph::multiply_cast(p1/p2).get() && + "Expected a multiply node."); + assert(graph::piecewise_2D_cast(p1/p3).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1/p4).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1/p5).get() && + "Expected a piecewise_2D node."); assert(graph::piecewise_2D_cast(graph::fma(p1, two, zero)).get() && "Expected a piecewise_2D node."); assert(graph::add_cast(graph::fma(p1, two, p2)).get() && "Expected an add node."); - assert(graph::fma_cast(graph::fma(p1, p2, two)).get() && - "Expected a fma node."); + assert(graph::multiply_cast(graph::fma(p1, p2, two)).get() && + "Expected a multiply node."); + assert(graph::add_cast(graph::fma(p1, p3, p2)).get() && + "Expected an add node."); + assert(graph::piecewise_2D_cast(graph::fma(p1, p3, two)).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(graph::fma(p1, p3, p1)).get() && + "Expected a piecewise_2D node."); + assert(graph::add_cast(graph::fma(p1, p4, p2)).get() && + "Expected an add node."); + assert(graph::add_cast(graph::fma(p1, p5, p2)).get() && + "Expected an add node."); assert(graph::piecewise_2D_cast(graph::sqrt(p1)).get() && "Expected a piecewise_2D node."); @@ -247,7 +348,13 @@ template void piecewise_2D() { assert(graph::piecewise_2D_cast(graph::pow(p1, two)).get() && "Expected a piecewise_2D node."); assert(graph::pow_cast(graph::pow(p1, p2)).get() && - "Expected a pow constant."); + "Expected a pow node."); + assert(graph::piecewise_2D_cast(graph::pow(p1, p3)).get() && + "Expected a pow node."); + assert(graph::piecewise_2D_cast(graph::pow(p1, 
p4)).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(graph::pow(p1, p5)).get() && + "Expected a piecewise_2D node."); assert(graph::piecewise_2D_cast(graph::sin(p1)).get() && "Expected a piecewise_2D node."); @@ -258,10 +365,16 @@ template void piecewise_2D() { assert(graph::piecewise_2D_cast(graph::tan(p1)).get() && "Expected a piecewise_2D node."); - assert(graph::atan_cast(graph::atan(p1, two)).get() && - "Expected an atan node."); + assert(graph::piecewise_2D_cast(graph::atan(p1, two)).get() && + "Expected a piecewise_2d node."); assert(graph::atan_cast(graph::atan(p1, p2)).get() && - "Expected a atan constant."); + "Expected an atan node."); + assert(graph::piecewise_2D_cast(graph::atan(p1, p3)).get() && + "Expected a piecewise_2d node."); + assert(graph::piecewise_2D_cast(graph::atan(p1, p4)).get() && + "Expected a piecewise_2d node."); + assert(graph::piecewise_2D_cast(graph::atan(p1, p5)).get() && + "Expected a piecewise_2d node."); ax->set(static_cast (1.5)); ay->set(static_cast (1.5)); @@ -290,7 +403,149 @@ template void piecewise_2D() { graph::variable_cast(ay)}, {p1}, {}, static_cast (3.0), 0.0); - + + ax->set(static_cast (0.5)); + ay->set(static_cast (1.5)); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 + p3}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 - p3}, {}, + static_cast (-2.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1*p3}, {}, + static_cast (8.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1/p3}, {}, + static_cast (0.5), 0.0); + bx->set(static_cast (1.5)); + by->set(static_cast (0.5)); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay), + graph::variable_cast(bx), + graph::variable_cast(by)}, + {graph::fma(p1, p3, p2)}, {}, + static_cast (14.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::pow(p1, p3)}, {}, 
+ static_cast (std::pow(static_cast (2.0), + static_cast (4.0))), 0.0); + if constexpr (jit::is_complex ()) { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p3)}, {}, + static_cast (std::atan(static_cast (4.0) / + static_cast (2.0))), + 0.0); + } else { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p3)}, {}, + static_cast (std::atan2(static_cast (4.0), + static_cast (2.0))), + 0.0); + } + +// Test row combines. + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1}, {}, + static_cast (2.0), 0.0); + compile ({graph::variable_cast(ax)}, + {p4}, {}, + static_cast (2.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 + p4}, {}, + static_cast (4.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 - p4}, {}, + static_cast (0.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1*p4}, {}, + static_cast (4.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1/p4}, {}, + static_cast (1.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay), + graph::variable_cast(bx), + graph::variable_cast(by)}, + {graph::fma(p1, p4, p2)}, {}, + static_cast (10.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::pow(p1, p4)}, {}, + static_cast (std::pow(static_cast (2.0), + static_cast (2.0))), 0.0); + if constexpr (jit::is_complex ()) { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p4)}, {}, + static_cast (std::atan(static_cast (2.0) / + static_cast (2.0))), + 0.0); + } else { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p4)}, {}, + static_cast (std::atan2(static_cast (2.0), + static_cast (2.0))), + 0.0); + } + +// Test column combines. 
+ compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 + p5}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 - p5}, {}, + static_cast (-2.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1*p5}, {}, + static_cast (8.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1/p5}, {}, + static_cast (0.5), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay), + graph::variable_cast(bx), + graph::variable_cast(by)}, + {graph::fma(p1, p5, p2)}, {}, + static_cast (14.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::pow(p1, p5)}, {}, + static_cast (std::pow(static_cast (2.0), + static_cast (4.0))), 0.0); + if constexpr (jit::is_complex ()) { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p5)}, {}, + static_cast (std::atan(static_cast (4.0) / + static_cast (2.0))), + 0.0); + } else { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p5)}, {}, + static_cast (std::atan2(static_cast (4.0), + static_cast (2.0))), + 0.0); + } + auto pc = graph::piecewise_2D (std::vector ({static_cast (10.0), static_cast (10.0), static_cast (10.0), @@ -298,6 +553,51 @@ template void piecewise_2D() { 2, ax, bx); assert(graph::constant_cast(pc).get() && "Expected a constant."); + + auto prc = graph::piecewise_1D (std::vector ({ + static_cast (1.0), + static_cast (2.0), + static_cast (3.0) + }), ax); + auto pcc = graph::piecewise_1D (std::vector ({ + static_cast (1.0), + static_cast (2.0), + static_cast (3.0) + }), ay); + auto p2Dc = graph::piecewise_2D (std::vector ({ + static_cast (1.0), static_cast (2.0), + static_cast (3.0), static_cast (4.0), + static_cast (5.0), static_cast (6.0) + }), 2, ax, ay); + + auto row_test = prc + p2Dc; + auto row_test_cast = graph::piecewise_2D_cast(row_test); + assert(row_test_cast.get() && "Expected a 2D 
piecewise node.."); + + auto col_test = pcc + p2Dc; + auto col_test_cast = graph::add_cast(col_test); + assert(col_test_cast.get() && "Expected an add node."); + + ax->set(static_cast (2.5)); + ay->set(static_cast (1.5)); + compile ({graph::variable_cast(ax)}, + {prc}, {}, + static_cast (3.0), 0.0); + compile ({graph::variable_cast(ay)}, + {pcc}, {}, + static_cast (2.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p2Dc}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {row_test}, {}, + static_cast (9.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {col_test}, {}, + static_cast (8.0), 0.0); } //------------------------------------------------------------------------------