Loading graph_framework.xcodeproj/project.pbxproj +177 −1 Original line number Diff line number Diff line Loading @@ -1429,6 +1429,94 @@ "$(inherited)", ); MACOSX_DEPLOYMENT_TARGET = 13.3; OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", "-lz", "-lLLVMCoverage", "-lLLVMSupport", "-lLLVMDebugInfoCodeView", "-lLLVMRemarks", "-lLLVMJITLink", "-lLLVMLinker", "-lLLVMTextAPI", "-lLLVMRuntimeDyld", "-lLLVMOrcShared", "-lLLVMOrcDebugging", "-lLLVMOrcTargetProcess", "-lLLVMOrcJIT", "-lLLVMHipStdPar", "-lLLVMAggressiveInstCombine", "-lLLVMVectorize", "-lLLVMAsmParser", "-lLLVMOption", "-lLLVMLTO", "-lLLVMObject", "-lLLVMWindowsDriver", "-lLLVMDemangle", "-lLLVMIRReader", "-lLLVMIRPrinter", "-lLLVMInstCombine", "-lLLVMBinaryFormat", "-lLLVMCoroutines", "-lLLVMBitstreamReader", "-lLLVMBitReader", "-lLLVMBitWriter", "-lLLVMDebugInfoDWARF", "-lLLVMInstrumentation", "-lLLVMCFGuard", "-lLLVMObjCARCOpts", "-lLLVMipo", "-lLLVMGlobalISel", "-lLLVMExecutionEngine", "-lLLVMFrontendDriver", "-lLLVMFrontendHLSL", "-lLLVMFrontendOpenMP", "-lLLVMFrontendOffloading", "-lLLVMSelectionDAG", "-lLLVMProfileData", "-lLLVMAnalysis", "-lLLVMScalarOpts", "-lLLVMCodeGenTypes", "-lLLVMCodeGenData", "-lLLVMCodeGen", "-lLLVMTargetParser", "-lLLVMScalarOpts", "-lLLVMTarget", "-lLLVMTransformUtils", "-lLLVMPasses", "-lLLVMSupport", "-lLLVMMCParser", "-lLLVMMC", "-lLLVMCore", "-lLLVMAsmPrinter", "-lLLVMAArch64Utils", "-lLLVMAArch64Info", "-lLLVMAArch64Desc", "-lLLVMAArch64AsmParser", "-lLLVMAArch64CodeGen", "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", "-lclangEdit", "-lclangLex", "-lclangDriver", "-lclangSerialization", "-lclangAST", "-lclangSema", "-lclangAnalysis", "-lclangASTMatchers", "-lclangSupport", "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", "-rpath", /usr/local/lib, ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Debug; Loading @@ -1440,6 +1528,94 @@ "CODE_SIGN_IDENTITY[sdk=macosx*]" = "-"; CODE_SIGN_STYLE = Automatic; MACOSX_DEPLOYMENT_TARGET = 13.3; OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", "-lz", "-lLLVMCoverage", "-lLLVMSupport", "-lLLVMDebugInfoCodeView", "-lLLVMRemarks", "-lLLVMJITLink", "-lLLVMLinker", "-lLLVMTextAPI", "-lLLVMRuntimeDyld", "-lLLVMOrcShared", "-lLLVMOrcDebugging", "-lLLVMOrcTargetProcess", "-lLLVMOrcJIT", "-lLLVMHipStdPar", "-lLLVMAggressiveInstCombine", "-lLLVMVectorize", "-lLLVMAsmParser", "-lLLVMOption", "-lLLVMLTO", "-lLLVMObject", "-lLLVMWindowsDriver", "-lLLVMDemangle", "-lLLVMIRReader", "-lLLVMIRPrinter", "-lLLVMInstCombine", "-lLLVMBinaryFormat", "-lLLVMCoroutines", "-lLLVMBitstreamReader", "-lLLVMBitReader", "-lLLVMBitWriter", "-lLLVMDebugInfoDWARF", "-lLLVMInstrumentation", "-lLLVMCFGuard", "-lLLVMObjCARCOpts", "-lLLVMipo", "-lLLVMGlobalISel", "-lLLVMExecutionEngine", "-lLLVMFrontendDriver", "-lLLVMFrontendHLSL", "-lLLVMFrontendOpenMP", "-lLLVMFrontendOffloading", "-lLLVMSelectionDAG", "-lLLVMProfileData", "-lLLVMAnalysis", "-lLLVMScalarOpts", "-lLLVMCodeGenTypes", "-lLLVMCodeGenData", "-lLLVMCodeGen", "-lLLVMTargetParser", "-lLLVMScalarOpts", "-lLLVMTarget", "-lLLVMTransformUtils", "-lLLVMPasses", "-lLLVMSupport", "-lLLVMMCParser", "-lLLVMMC", "-lLLVMCore", "-lLLVMAsmPrinter", "-lLLVMAArch64Utils", "-lLLVMAArch64Info", "-lLLVMAArch64Desc", "-lLLVMAArch64AsmParser", "-lLLVMAArch64CodeGen", "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", "-lclangEdit", "-lclangLex", "-lclangDriver", "-lclangSerialization", "-lclangAST", "-lclangSema", "-lclangAnalysis", "-lclangASTMatchers", "-lclangSupport", "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", "-rpath", /usr/local/lib, ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Release; Loading Loading @@ -1526,7 +1702,7 @@ "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", "VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"", USE_METAL, "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"", "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"", STATIC, "DEBUG=1", "$(inherited)", Loading graph_framework/arithmetic.hpp +47 −9 Original line number Diff line number Diff line Loading @@ -925,6 +925,32 @@ namespace graph { auto lm = multiply_cast(this->left); auto rm = multiply_cast(this->right); // c1*(c2 + a) - c3 -> fma(c1,a,c4) if (lm.get()) { auto lmra = add_cast(lm->get_right()); if (lmra.get()) { if (is_constant_combineable(lm->get_left(), lmra->get_left()) && is_constant_combineable(lm->get_left(), this->right)) { return fma(lm->get_left(), lmra->get_right(), lm->get_left()*lmra->get_left() - this->right); } } auto lmrs = subtract_cast(lm->get_right()); if (lmrs.get()) { if (is_constant_combineable(lm->get_left(), lmrs->get_left()) && is_constant_combineable(lm->get_left(), this->right)) { return lm->get_left()*lmrs->get_left() - this->right - lm->get_left()*lmrs->get_right(); } } } // Assume constants are on the left. // v1 - -c*v2 -> v1 + c*v2 if (rm.get() && Loading Loading @@ -3491,6 +3517,27 @@ namespace graph { return this->middle*(1.0 + this->left); } // fma(c1,c2 + a,c3) -> fma(c4,a,c5) auto ma = add_cast(this->middle); if (ma.get()) { if (is_constant_combineable(this->left, ma->get_left()) && is_constant_combineable(this->left, this->right)) { return fma(this->left, ma->get_right(), fma(this->left, ma->get_left(), this->right)); } } // fma(c1,c2 - a,c3) -> c4 - c5*a auto ms = subtract_cast(this->middle); if (ms.get()) { if (is_constant_combineable(this->left, ms->get_left()) && is_constant_combineable(this->left, this->right)) { return fma(this->left, ms->get_left(), this->right) - this->left*ms->get_right(); } } // Common factor reduction. If the left and right are both multiply nodes check // for a common factor. So you can change a*b + (a*c) -> a*(b + c). auto lm = multiply_cast(this->left); Loading Loading @@ -4149,15 +4196,6 @@ namespace graph { } } // Promote constants out to the left. if (is_constant_combineable(this->left, this->right) && !this->left->has_constant_zero()) { auto temp = this->right/this->left; if (temp->is_normal()) { return this->left*(this->middle + temp); } } // Change negative exponents to divide so that can be factored out. // fma(a,b^-c,d) = a/b^c + d // fma(b^-c,a,d) = a/b^c + d Loading graph_tests/arithmetic_test.cpp +2 −4 Original line number Diff line number Diff line Loading @@ -538,7 +538,8 @@ template<jit::float_scalar T> void test_subtract() { // (c1*v1 + c2) - (c3*v1 + c4) -> c5*(v1 - c6) auto subfma = graph::fma(3.0, var_a, 2.0) - graph::fma(2.0, var_a, 3.0); assert(graph::multiply_cast(subfma).get() && "Expected a multiply node."); // -1 + a assert(graph::add_cast(subfma).get() && "Expected an add node."); // Test cases like // (c1 + c2/x) - c3/x -> c1 + c4/x Loading Loading @@ -2665,9 +2666,6 @@ template<jit::float_scalar T> void test_fma() { assert(reduce4_cast->get_right()->is_match(var_b) && "Expected common var_b"); assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() && "Expected multiply node."); // fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) auto var_d = graph::variable<T> (1, ""); auto match1 = graph::fma(var_b, var_a + var_c, var_d); Loading graph_tests/piecewise_test.cpp +74 −5 Original line number Diff line number Diff line Loading @@ -129,9 +129,6 @@ template<jit::float_scalar T> void piecewise_1D() { "Expected a piecewise_1D node."); assert(graph::add_cast(graph::fma(p1, 2.0, p2)).get() && "Expected an add node."); auto temp = graph::fma(p1, p2, 2.0); assert(graph::multiply_cast(graph::fma(p1, p2, 2.0)).get() && "Expected a multiply node."); assert(graph::add_cast(graph::fma(p1, p3, p2)).get() && "Expected an add node."); assert(graph::piecewise_1D_cast(graph::fma(p1, p3, 2.0)).get() && Loading Loading @@ -231,6 +228,43 @@ template<jit::float_scalar T> void piecewise_1D() { static_cast<T> (10.0)}), a); assert(graph::constant_cast(pc).get() && "Expected a constant."); // fma(p1,c1 + a,p2) -> fma(p1,a,p3) auto fma_combine = fma(p1,1.0 + a,p3); auto fma_combine_cast = graph::fma_cast(fma_combine); assert(fma_combine_cast.get() && "Expected an fma node."); assert(fma_combine_cast->get_middle()->is_match(a) && "Expected a in the middle."); assert(fma_combine_cast->get_left()->is_match(p1) && "Expected p1 on the left."); assert(fma_combine_cast->get_right()->is_match(p1 + p3) && "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - a,p3); auto fma_combine2_cast = graph::subtract_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected an subtract node."); assert(fma_combine2_cast->get_right()->is_match(p1*a) && "Expected p1*a on the right."); assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && "Expected p1 + p3 on the left."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + a) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); assert(fma_combine3_cast.get() && "Expected a fma node."); assert(fma_combine3_cast->get_middle()->is_match(a) && "Expected a in the middle."); assert(fma_combine3_cast->get_left()->is_match(p1) && "Expected p1 on the left."); assert(fma_combine3_cast->get_right()->is_match(p1 - p3) && "Expected p1 - p3 on the right."); // p1*(c1 - a) - p2 -> p3 - p1*a auto fma_combine4 = p1*(1.0 - a) - p3; auto fma_combine4_cast = graph::subtract_cast(fma_combine4); assert(fma_combine4_cast.get() && "Expected an subtract node."); assert(fma_combine4_cast->get_right()->is_match(p1*a) && "Expected p1*a on the right."); assert(fma_combine4_cast->get_left()->is_match(p1 - p3) && "Expected p1 - p3 on the left."); } //------------------------------------------------------------------------------ Loading Loading @@ -319,8 +353,6 @@ template<jit::float_scalar T> void piecewise_2D() { "Expected a piecewise_2D node."); assert(graph::add_cast(graph::fma(p1, 2.0, p2)).get() && "Expected an add node."); assert(graph::multiply_cast(graph::fma(p1, p2, 2.0)).get() && "Expected a multiply node."); assert(graph::add_cast(graph::fma(p1, p3, p2)).get() && "Expected an add node."); assert(graph::piecewise_2D_cast(graph::fma(p1, p3, 2.0)).get() && Loading Loading @@ -606,6 +638,43 @@ template<jit::float_scalar T> void piecewise_2D() { graph::variable_cast(ay)}, {col_test}, {}, static_cast<T> (8.0), 0.0); // fma(p1,c1 + a,p2) -> fma(p1,a,p3) auto fma_combine = fma(p1,1.0 + ax,p3); auto fma_combine_cast = graph::fma_cast(fma_combine); assert(fma_combine_cast.get() && "Expected an fma node."); assert(fma_combine_cast->get_middle()->is_match(ax) && "Expected a in the middle."); assert(fma_combine_cast->get_left()->is_match(p1) && "Expected p1 on the left."); assert(fma_combine_cast->get_right()->is_match(p1 + p3) && "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - ax,p3); auto fma_combine2_cast = graph::subtract_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected an subtract node."); assert(fma_combine2_cast->get_right()->is_match(p1*ax) && "Expected p1*a on the right."); assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && "Expected p1 + p3 on the left."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + ax) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); assert(fma_combine3_cast.get() && "Expected a fma node."); assert(fma_combine3_cast->get_middle()->is_match(ax) && "Expected a in the middle."); assert(fma_combine3_cast->get_left()->is_match(p1) && "Expected p1 on the left."); assert(fma_combine3_cast->get_right()->is_match(p1 - p3) && "Expected p1 - p3 on the right."); // p1*(c1 - a) - p2 -> p3 - p1*a auto fma_combine4 = p1*(1.0 - ax) - p3; auto fma_combine4_cast = graph::subtract_cast(fma_combine4); assert(fma_combine4_cast.get() && "Expected an subtract node."); assert(fma_combine4_cast->get_right()->is_match(p1*ax) && "Expected p1*a on the right."); assert(fma_combine4_cast->get_left()->is_match(p1 - p3) && "Expected p1 - p3 on the left."); } //------------------------------------------------------------------------------ Loading Loading
graph_framework.xcodeproj/project.pbxproj +177 −1 Original line number Diff line number Diff line Loading @@ -1429,6 +1429,94 @@ "$(inherited)", ); MACOSX_DEPLOYMENT_TARGET = 13.3; OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", "-lz", "-lLLVMCoverage", "-lLLVMSupport", "-lLLVMDebugInfoCodeView", "-lLLVMRemarks", "-lLLVMJITLink", "-lLLVMLinker", "-lLLVMTextAPI", "-lLLVMRuntimeDyld", "-lLLVMOrcShared", "-lLLVMOrcDebugging", "-lLLVMOrcTargetProcess", "-lLLVMOrcJIT", "-lLLVMHipStdPar", "-lLLVMAggressiveInstCombine", "-lLLVMVectorize", "-lLLVMAsmParser", "-lLLVMOption", "-lLLVMLTO", "-lLLVMObject", "-lLLVMWindowsDriver", "-lLLVMDemangle", "-lLLVMIRReader", "-lLLVMIRPrinter", "-lLLVMInstCombine", "-lLLVMBinaryFormat", "-lLLVMCoroutines", "-lLLVMBitstreamReader", "-lLLVMBitReader", "-lLLVMBitWriter", "-lLLVMDebugInfoDWARF", "-lLLVMInstrumentation", "-lLLVMCFGuard", "-lLLVMObjCARCOpts", "-lLLVMipo", "-lLLVMGlobalISel", "-lLLVMExecutionEngine", "-lLLVMFrontendDriver", "-lLLVMFrontendHLSL", "-lLLVMFrontendOpenMP", "-lLLVMFrontendOffloading", "-lLLVMSelectionDAG", "-lLLVMProfileData", "-lLLVMAnalysis", "-lLLVMScalarOpts", "-lLLVMCodeGenTypes", "-lLLVMCodeGenData", "-lLLVMCodeGen", "-lLLVMTargetParser", "-lLLVMScalarOpts", "-lLLVMTarget", "-lLLVMTransformUtils", "-lLLVMPasses", "-lLLVMSupport", "-lLLVMMCParser", "-lLLVMMC", "-lLLVMCore", "-lLLVMAsmPrinter", "-lLLVMAArch64Utils", "-lLLVMAArch64Info", "-lLLVMAArch64Desc", "-lLLVMAArch64AsmParser", "-lLLVMAArch64CodeGen", "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", "-lclangEdit", "-lclangLex", "-lclangDriver", "-lclangSerialization", "-lclangAST", "-lclangSema", "-lclangAnalysis", "-lclangASTMatchers", "-lclangSupport", "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", "-rpath", /usr/local/lib, ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Debug; Loading @@ -1440,6 +1528,94 @@ "CODE_SIGN_IDENTITY[sdk=macosx*]" = "-"; CODE_SIGN_STYLE = Automatic; MACOSX_DEPLOYMENT_TARGET = 13.3; OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", "-lz", "-lLLVMCoverage", "-lLLVMSupport", "-lLLVMDebugInfoCodeView", "-lLLVMRemarks", "-lLLVMJITLink", "-lLLVMLinker", "-lLLVMTextAPI", "-lLLVMRuntimeDyld", "-lLLVMOrcShared", "-lLLVMOrcDebugging", "-lLLVMOrcTargetProcess", "-lLLVMOrcJIT", "-lLLVMHipStdPar", "-lLLVMAggressiveInstCombine", "-lLLVMVectorize", "-lLLVMAsmParser", "-lLLVMOption", "-lLLVMLTO", "-lLLVMObject", "-lLLVMWindowsDriver", "-lLLVMDemangle", "-lLLVMIRReader", "-lLLVMIRPrinter", "-lLLVMInstCombine", "-lLLVMBinaryFormat", "-lLLVMCoroutines", "-lLLVMBitstreamReader", "-lLLVMBitReader", "-lLLVMBitWriter", "-lLLVMDebugInfoDWARF", "-lLLVMInstrumentation", "-lLLVMCFGuard", "-lLLVMObjCARCOpts", "-lLLVMipo", "-lLLVMGlobalISel", "-lLLVMExecutionEngine", "-lLLVMFrontendDriver", "-lLLVMFrontendHLSL", "-lLLVMFrontendOpenMP", "-lLLVMFrontendOffloading", "-lLLVMSelectionDAG", "-lLLVMProfileData", "-lLLVMAnalysis", "-lLLVMScalarOpts", "-lLLVMCodeGenTypes", "-lLLVMCodeGenData", "-lLLVMCodeGen", "-lLLVMTargetParser", "-lLLVMScalarOpts", "-lLLVMTarget", "-lLLVMTransformUtils", "-lLLVMPasses", "-lLLVMSupport", "-lLLVMMCParser", "-lLLVMMC", "-lLLVMCore", "-lLLVMAsmPrinter", "-lLLVMAArch64Utils", "-lLLVMAArch64Info", "-lLLVMAArch64Desc", "-lLLVMAArch64AsmParser", "-lLLVMAArch64CodeGen", "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", "-lclangEdit", "-lclangLex", "-lclangDriver", "-lclangSerialization", "-lclangAST", "-lclangSema", "-lclangAnalysis", "-lclangASTMatchers", "-lclangSupport", "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", "-rpath", /usr/local/lib, ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Release; Loading Loading @@ -1526,7 +1702,7 @@ "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", "VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"", USE_METAL, "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"", "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"", STATIC, "DEBUG=1", "$(inherited)", Loading
graph_framework/arithmetic.hpp +47 −9 Original line number Diff line number Diff line Loading @@ -925,6 +925,32 @@ namespace graph { auto lm = multiply_cast(this->left); auto rm = multiply_cast(this->right); // c1*(c2 + a) - c3 -> fma(c1,a,c4) if (lm.get()) { auto lmra = add_cast(lm->get_right()); if (lmra.get()) { if (is_constant_combineable(lm->get_left(), lmra->get_left()) && is_constant_combineable(lm->get_left(), this->right)) { return fma(lm->get_left(), lmra->get_right(), lm->get_left()*lmra->get_left() - this->right); } } auto lmrs = subtract_cast(lm->get_right()); if (lmrs.get()) { if (is_constant_combineable(lm->get_left(), lmrs->get_left()) && is_constant_combineable(lm->get_left(), this->right)) { return lm->get_left()*lmrs->get_left() - this->right - lm->get_left()*lmrs->get_right(); } } } // Assume constants are on the left. // v1 - -c*v2 -> v1 + c*v2 if (rm.get() && Loading Loading @@ -3491,6 +3517,27 @@ namespace graph { return this->middle*(1.0 + this->left); } // fma(c1,c2 + a,c3) -> fma(c4,a,c5) auto ma = add_cast(this->middle); if (ma.get()) { if (is_constant_combineable(this->left, ma->get_left()) && is_constant_combineable(this->left, this->right)) { return fma(this->left, ma->get_right(), fma(this->left, ma->get_left(), this->right)); } } // fma(c1,c2 - a,c3) -> c4 - c5*a auto ms = subtract_cast(this->middle); if (ms.get()) { if (is_constant_combineable(this->left, ms->get_left()) && is_constant_combineable(this->left, this->right)) { return fma(this->left, ms->get_left(), this->right) - this->left*ms->get_right(); } } // Common factor reduction. If the left and right are both multiply nodes check // for a common factor. So you can change a*b + (a*c) -> a*(b + c). auto lm = multiply_cast(this->left); Loading Loading @@ -4149,15 +4196,6 @@ namespace graph { } } // Promote constants out to the left. if (is_constant_combineable(this->left, this->right) && !this->left->has_constant_zero()) { auto temp = this->right/this->left; if (temp->is_normal()) { return this->left*(this->middle + temp); } } // Change negative exponents to divide so that can be factored out. // fma(a,b^-c,d) = a/b^c + d // fma(b^-c,a,d) = a/b^c + d Loading
graph_tests/arithmetic_test.cpp +2 −4 Original line number Diff line number Diff line Loading @@ -538,7 +538,8 @@ template<jit::float_scalar T> void test_subtract() { // (c1*v1 + c2) - (c3*v1 + c4) -> c5*(v1 - c6) auto subfma = graph::fma(3.0, var_a, 2.0) - graph::fma(2.0, var_a, 3.0); assert(graph::multiply_cast(subfma).get() && "Expected a multiply node."); // -1 + a assert(graph::add_cast(subfma).get() && "Expected an add node."); // Test cases like // (c1 + c2/x) - c3/x -> c1 + c4/x Loading Loading @@ -2665,9 +2666,6 @@ template<jit::float_scalar T> void test_fma() { assert(reduce4_cast->get_right()->is_match(var_b) && "Expected common var_b"); assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() && "Expected multiply node."); // fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) auto var_d = graph::variable<T> (1, ""); auto match1 = graph::fma(var_b, var_a + var_c, var_d); Loading
graph_tests/piecewise_test.cpp +74 −5 Original line number Diff line number Diff line Loading @@ -129,9 +129,6 @@ template<jit::float_scalar T> void piecewise_1D() { "Expected a piecewise_1D node."); assert(graph::add_cast(graph::fma(p1, 2.0, p2)).get() && "Expected an add node."); auto temp = graph::fma(p1, p2, 2.0); assert(graph::multiply_cast(graph::fma(p1, p2, 2.0)).get() && "Expected a multiply node."); assert(graph::add_cast(graph::fma(p1, p3, p2)).get() && "Expected an add node."); assert(graph::piecewise_1D_cast(graph::fma(p1, p3, 2.0)).get() && Loading Loading @@ -231,6 +228,43 @@ template<jit::float_scalar T> void piecewise_1D() { static_cast<T> (10.0)}), a); assert(graph::constant_cast(pc).get() && "Expected a constant."); // fma(p1,c1 + a,p2) -> fma(p1,a,p3) auto fma_combine = fma(p1,1.0 + a,p3); auto fma_combine_cast = graph::fma_cast(fma_combine); assert(fma_combine_cast.get() && "Expected an fma node."); assert(fma_combine_cast->get_middle()->is_match(a) && "Expected a in the middle."); assert(fma_combine_cast->get_left()->is_match(p1) && "Expected p1 on the left."); assert(fma_combine_cast->get_right()->is_match(p1 + p3) && "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - a,p3); auto fma_combine2_cast = graph::subtract_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected an subtract node."); assert(fma_combine2_cast->get_right()->is_match(p1*a) && "Expected p1*a on the right."); assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && "Expected p1 + p3 on the left."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + a) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); assert(fma_combine3_cast.get() && "Expected a fma node."); assert(fma_combine3_cast->get_middle()->is_match(a) && "Expected a in the middle."); assert(fma_combine3_cast->get_left()->is_match(p1) && "Expected p1 on the left."); assert(fma_combine3_cast->get_right()->is_match(p1 - p3) && "Expected p1 - p3 on the right."); // p1*(c1 - a) - p2 -> p3 - p1*a auto fma_combine4 = p1*(1.0 - a) - p3; auto fma_combine4_cast = graph::subtract_cast(fma_combine4); assert(fma_combine4_cast.get() && "Expected an subtract node."); assert(fma_combine4_cast->get_right()->is_match(p1*a) && "Expected p1*a on the right."); assert(fma_combine4_cast->get_left()->is_match(p1 - p3) && "Expected p1 - p3 on the left."); } //------------------------------------------------------------------------------ Loading Loading @@ -319,8 +353,6 @@ template<jit::float_scalar T> void piecewise_2D() { "Expected a piecewise_2D node."); assert(graph::add_cast(graph::fma(p1, 2.0, p2)).get() && "Expected an add node."); assert(graph::multiply_cast(graph::fma(p1, p2, 2.0)).get() && "Expected a multiply node."); assert(graph::add_cast(graph::fma(p1, p3, p2)).get() && "Expected an add node."); assert(graph::piecewise_2D_cast(graph::fma(p1, p3, 2.0)).get() && Loading Loading @@ -606,6 +638,43 @@ template<jit::float_scalar T> void piecewise_2D() { graph::variable_cast(ay)}, {col_test}, {}, static_cast<T> (8.0), 0.0); // fma(p1,c1 + a,p2) -> fma(p1,a,p3) auto fma_combine = fma(p1,1.0 + ax,p3); auto fma_combine_cast = graph::fma_cast(fma_combine); assert(fma_combine_cast.get() && "Expected an fma node."); assert(fma_combine_cast->get_middle()->is_match(ax) && "Expected a in the middle."); assert(fma_combine_cast->get_left()->is_match(p1) && "Expected p1 on the left."); assert(fma_combine_cast->get_right()->is_match(p1 + p3) && "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - ax,p3); auto fma_combine2_cast = graph::subtract_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected an subtract node."); assert(fma_combine2_cast->get_right()->is_match(p1*ax) && "Expected p1*a on the right."); assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && "Expected p1 + p3 on the left."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + ax) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); assert(fma_combine3_cast.get() && "Expected a fma node."); assert(fma_combine3_cast->get_middle()->is_match(ax) && "Expected a in the middle."); assert(fma_combine3_cast->get_left()->is_match(p1) && "Expected p1 on the left."); assert(fma_combine3_cast->get_right()->is_match(p1 - p3) && "Expected p1 - p3 on the right."); // p1*(c1 - a) - p2 -> p3 - p1*a auto fma_combine4 = p1*(1.0 - ax) - p3; auto fma_combine4_cast = graph::subtract_cast(fma_combine4); assert(fma_combine4_cast.get() && "Expected an subtract node."); assert(fma_combine4_cast->get_right()->is_match(p1*ax) && "Expected p1*a on the right."); assert(fma_combine4_cast->get_left()->is_match(p1 - p3) && "Expected p1 - p3 on the left."); } //------------------------------------------------------------------------------ Loading