Loading graph_framework.xcodeproj/project.pbxproj +174 −0 Original line number Diff line number Diff line Loading @@ -1294,6 +1294,93 @@ ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 14.5; OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", "-lz", "-lLLVMCoverage", "-lLLVMSupport", "-lLLVMDebugInfoCodeView", "-lLLVMRemarks", "-lLLVMJITLink", "-lLLVMLinker", "-lLLVMTextAPI", "-lLLVMRuntimeDyld", "-lLLVMOrcShared", "-lLLVMOrcDebugging", "-lLLVMOrcTargetProcess", "-lLLVMOrcJIT", "-lLLVMHipStdPar", "-lLLVMAggressiveInstCombine", "-lLLVMVectorize", "-lLLVMAsmParser", "-lLLVMOption", "-lLLVMLTO", "-lLLVMObject", "-lLLVMWindowsDriver", "-lLLVMDemangle", "-lLLVMIRReader", "-lLLVMIRPrinter", "-lLLVMInstCombine", "-lLLVMBinaryFormat", "-lLLVMCoroutines", "-lLLVMBitstreamReader", "-lLLVMBitReader", "-lLLVMBitWriter", "-lLLVMDebugInfoDWARF", "-lLLVMInstrumentation", "-lLLVMCFGuard", "-lLLVMObjCARCOpts", "-lLLVMipo", "-lLLVMGlobalISel", "-lLLVMExecutionEngine", "-lLLVMFrontendDriver", "-lLLVMFrontendHLSL", "-lLLVMFrontendOpenMP", "-lLLVMFrontendOffloading", "-lLLVMSelectionDAG", "-lLLVMProfileData", "-lLLVMAnalysis", "-lLLVMScalarOpts", "-lLLVMCodeGenTypes", "-lLLVMCodeGenData", "-lLLVMCodeGen", "-lLLVMTargetParser", "-lLLVMScalarOpts", "-lLLVMTarget", "-lLLVMTransformUtils", "-lLLVMPasses", "-lLLVMSupport", "-lLLVMMCParser", "-lLLVMMC", "-lLLVMCore", "-lLLVMAsmPrinter", "-lLLVMAArch64Utils", "-lLLVMAArch64Info", "-lLLVMAArch64Desc", "-lLLVMAArch64AsmParser", "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", "-lclangEdit", "-lclangLex", "-lclangDriver", "-lclangSerialization", "-lclangAST", "-lclangSema", "-lclangAnalysis", "-lclangASTMatchers", "-lclangSupport", "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", "-rpath", /usr/local/lib, ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Debug; Loading @@ -1312,6 +1399,93 @@ ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 14.5; OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", "-lz", "-lLLVMCoverage", "-lLLVMSupport", "-lLLVMDebugInfoCodeView", "-lLLVMRemarks", "-lLLVMJITLink", "-lLLVMLinker", "-lLLVMTextAPI", "-lLLVMRuntimeDyld", "-lLLVMOrcShared", "-lLLVMOrcDebugging", "-lLLVMOrcTargetProcess", "-lLLVMOrcJIT", "-lLLVMHipStdPar", "-lLLVMAggressiveInstCombine", "-lLLVMVectorize", "-lLLVMAsmParser", "-lLLVMOption", "-lLLVMLTO", "-lLLVMObject", "-lLLVMWindowsDriver", "-lLLVMDemangle", "-lLLVMIRReader", "-lLLVMIRPrinter", "-lLLVMInstCombine", "-lLLVMBinaryFormat", "-lLLVMCoroutines", "-lLLVMBitstreamReader", "-lLLVMBitReader", "-lLLVMBitWriter", "-lLLVMDebugInfoDWARF", "-lLLVMInstrumentation", "-lLLVMCFGuard", "-lLLVMObjCARCOpts", "-lLLVMipo", "-lLLVMGlobalISel", "-lLLVMExecutionEngine", "-lLLVMFrontendDriver", "-lLLVMFrontendHLSL", "-lLLVMFrontendOpenMP", "-lLLVMFrontendOffloading", "-lLLVMSelectionDAG", "-lLLVMProfileData", "-lLLVMAnalysis", "-lLLVMScalarOpts", "-lLLVMCodeGenTypes", "-lLLVMCodeGenData", "-lLLVMCodeGen", "-lLLVMTargetParser", "-lLLVMScalarOpts", "-lLLVMTarget", "-lLLVMTransformUtils", "-lLLVMPasses", "-lLLVMSupport", "-lLLVMMCParser", "-lLLVMMC", "-lLLVMCore", "-lLLVMAsmPrinter", "-lLLVMAArch64Utils", "-lLLVMAArch64Info", "-lLLVMAArch64Desc", "-lLLVMAArch64AsmParser", "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", "-lclangEdit", "-lclangLex", "-lclangDriver", "-lclangSerialization", "-lclangAST", "-lclangSema", "-lclangAnalysis", "-lclangASTMatchers", "-lclangSupport", "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", "-rpath", /usr/local/lib, ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Release; Loading graph_framework/arithmetic.hpp +72 −0 Original line number Diff line number Diff line Loading @@ -2024,6 +2024,45 @@ namespace graph { return this->right/pow(lp->get_left(), -lp->get_right()); } } // a^b*c^b -> (a*c)^b if (lp.get() && rp.get()) { if (lp->get_right()->is_match(rp->get_right())) { return pow(lp->get_left()*rp->get_left(), lp->get_right()); } } // (a*b^c)*d^c -> a*(b*d)^c // (a^c*b)*d^c -> b*(a*d)^c // a^c*(b*d^c) -> b*(a*d)^c // a^c*(b^c*d) -> d*(a*b)^c if (lm.get() && rp.get()) { auto lmlp = pow_cast(lm->get_left()); auto lmrp = pow_cast(lm->get_right()); if (lmrp.get()) { if (lmrp->get_right()->is_match(rp->get_right())) { return lm->get_left()*pow(lmrp->get_left()*rp->get_left(), rp->get_right()); } } else if (lmlp.get()) { if (lmlp->get_right()->is_match(rp->get_right())) { return lm->get_right()*pow(lmlp->get_left()*rp->get_left(), rp->get_right()); } } } else if (rm.get() && lp.get()) { auto rmlp = pow_cast(rm->get_left()); auto rmrp = pow_cast(rm->get_right()); if (rmrp.get()) { if (rmrp->get_right()->is_match(lp->get_right())) { return rm->get_left()*pow(lp->get_left()*rmrp->get_left(), lp->get_right()); } } else if (rmlp.get()) { if (rmlp->get_right()->is_match(lp->get_right())) { return rm->get_right()*pow(lp->get_left()*rmlp->get_left(), lp->get_right()); } } } // (b*a)^c*a^d -> b^c*a^(c + d) // (a*b)^c*a^d -> b^c*a^(c + d) Loading Loading @@ -2931,6 +2970,14 @@ namespace graph { } } } // a^b/c^b -> (a/c)^b if (lp.get() && rp.get()) { if (lp->get_right()->is_match(rp->get_right())) { return pow(lp->get_left()/rp->get_left(), lp->get_right()); } } // (a*b)^c/((a^d)*e) = a^(c - d)*b^c/e // (b*a)^c/((a^d)*e) = a^(c - d)*b^c/e // (a*b)^c/(e*(a^d)) = a^(c - d)*b^c/e Loading Loading @@ -4320,6 +4367,31 @@ namespace graph { } } // a^b*c^b + d -> (a*c)^b + d if (lp.get() && mp.get()) { if (lp->get_right()->is_match(mp->get_right())) { return pow(lp->get_left()*mp->get_left(), lp->get_right()) + this->right; } } // fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) if (rm.get() && mp.get()) { auto mplm = multiply_cast(mp->get_left()); if (mplm.get()) { if (is_variable_combineable(mplm->get_left(), rm->get_left())) { return pow(mplm->get_left(), mp->get_right()) * fma(this->left, pow(mplm->get_right(), mp->get_right()), this->right/mplm->get_left()); } } } // fma(a,b/c,b/d) -> b*(a/c + 1/d) // fma(a,c/b,d/b) -> (a*c + d)/b if (md.get() && rd.get()) { Loading graph_framework/math.hpp +64 −2 Original line number Diff line number Diff line Loading @@ -957,8 +957,9 @@ namespace graph { } // Handle cases where (c*x)^a, (x*c)^a, (a*sqrt(b))^c and (a*b^c)^2. // These reductions only make sense if the power is constant. auto lm = multiply_cast(this->left); if (lm.get()) { if (lm.get() && rc.get()) { if (lm->get_left()->is_constant() || lm->get_right()->is_constant() || sqrt_cast(lm->get_left()).get() || Loading Loading @@ -986,8 +987,9 @@ namespace graph { } } // These reductions only make sense if the power is constant. auto ld = divide_cast(this->left); if (ld.get()) { if (ld.get() && rc.get()) { // For even exponents e. // (-a/b)^e -> (a/b)^e auto ldlm = multiply_cast(ld->get_left()); Loading @@ -1010,6 +1012,36 @@ namespace graph { pow(ldlm->get_right(), this->right)/ pow(ld->get_right(), this->right); } auto ldlmlm = multiply_cast(ldlm->get_left()); if (ldlmlm.get()) { if (ldlmlm->get_left()->is_constant() || ldlmlm->get_right()->is_constant() || sqrt_cast(ldlmlm->get_left()).get() || sqrt_cast(ldlmlm->get_right()).get() || pow_cast(ldlmlm->get_left()).get() || pow_cast(ldlmlm->get_right()).get()) { return (pow(ldlmlm->get_left(), this->right) * pow(ldlmlm->get_right(), this->right) * pow(ldlm->get_right(), this->right)) / pow(ld->get_right(), this->right); } } auto ldlmrm = multiply_cast(ldlm->get_right()); if (ldlmrm.get()) { if (ldlmrm->get_left()->is_constant() || ldlmrm->get_right()->is_constant() || sqrt_cast(ldlmrm->get_left()).get() || sqrt_cast(ldlmrm->get_right()).get() || pow_cast(ldlmrm->get_left()).get() || pow_cast(ldlmrm->get_right()).get()) { return (pow(ldlmrm->get_left(), this->right) * pow(ldlmrm->get_right(), this->right) * pow(ldlm->get_left(), this->right)) / pow(ld->get_right(), this->right); } } } // Handle cases where (c/x)^a, (x/c)^a, (a/sqrt(b))^c and (a/b^c)^2. Loading @@ -1036,6 +1068,36 @@ namespace graph { (pow(ldrm->get_left(), this->right) * pow(ldrm->get_right(), this->right)); } auto ldrmlm = multiply_cast(ldrm->get_left()); if (ldrmlm.get()) { if (ldrmlm->get_left()->is_constant() || ldrmlm->get_right()->is_constant() || sqrt_cast(ldrmlm->get_left()).get() || sqrt_cast(ldrmlm->get_right()).get() || pow_cast(ldrmlm->get_left()).get() || pow_cast(ldrmlm->get_right()).get()) { return pow(ld->get_left(), this->right) / (pow(ldrmlm->get_left(), this->right) * pow(ldrmlm->get_right(), this->right) * pow(ldrm->get_right(), this->right)); } } auto ldrmrm = multiply_cast(ldrm->get_right()); if (ldrmrm.get()) { if (ldrmrm->get_left()->is_constant() || ldrmrm->get_right()->is_constant() || sqrt_cast(ldrmrm->get_left()).get() || sqrt_cast(ldrmrm->get_right()).get() || pow_cast(ldrmrm->get_left()).get() || pow_cast(ldrmrm->get_right()).get()) { return pow(ld->get_left(), this->right) / (pow(ldrmrm->get_left(), this->right) * pow(ldrmrm->get_right(), this->right) * pow(ldrm->get_left(), this->right)); } } } if (is_variable_combineable(ld->get_left(), Loading graph_tests/arithmetic_test.cpp +83 −12 Original line number Diff line number Diff line Loading @@ -537,8 +537,8 @@ template<jit::float_scalar T> void test_subtract() { "Expected to reduce to a constant minus one."); // Test common factors. auto var_a = graph::variable<T> (1, ""); auto var_b = graph::variable<T> (1, ""); auto var_a = graph::variable<T> (1, "a"); auto var_b = graph::variable<T> (1, "b"); auto var_c = graph::variable<T> (1, ""); auto common_a = var_a*var_b - var_a*var_c; assert(graph::add_cast(common_a).get() == nullptr && Loading Loading @@ -1850,34 +1850,81 @@ template<jit::float_scalar T> void test_multiply() { auto common_pow2 = graph::pow(var_b*var_a, 2.0)*graph::pow(var_a, 2.0); auto common_pow2_cast = graph::multiply_cast(common_pow2); assert(common_pow2_cast.get() && "Expected a multiply node."); assert(common_pow2_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && assert(common_pow2_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); assert(common_pow2_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && assert(common_pow2_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // (a*b)^2*a^2 -> b^2*a^4 auto common_pow3 = graph::pow(var_a*var_b, 2.0)*graph::pow(var_a, 2.0); auto common_pow3_cast = graph::multiply_cast(common_pow3); assert(common_pow3_cast.get() && "Expected a multiply node."); assert(common_pow3_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && assert(common_pow3_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); assert(common_pow3_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && assert(common_pow3_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // a^2*(b*a)^2 -> b^2*a^4 auto common_pow4 = graph::pow(var_a, 2.0)*graph::pow(var_b*var_a, 2.0); auto common_pow4_cast = graph::multiply_cast(common_pow4); assert(common_pow4_cast.get() && "Expected a multiply node."); assert(common_pow4_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && assert(common_pow4_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); assert(common_pow4_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && assert(common_pow4_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // a^2*(b*a)^2 -> b^2*a^4 auto common_pow5 = graph::pow(var_a, 2.0)*graph::pow(var_a*var_b, 2.0); auto common_pow5_cast = graph::multiply_cast(common_pow5); assert(common_pow5_cast.get() && "Expected a multiply node."); assert(common_pow5_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && assert(common_pow5_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); assert(common_pow5_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && assert(common_pow5_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // a^b*c^b -> (a*c)^b auto gather_power = graph::pow(var_a, var_b)*graph::pow(var_c, var_b); auto gather_power_cast = graph::pow_cast(gather_power); assert(gather_power_cast.get() && "Expected a power node."); assert(gather_power_cast->get_left()->is_match(var_a*var_c) && "Expected a*c."); assert(gather_power_cast->get_right()->is_match(var_b) && "Expected b."); // (a*b^c)*d^c -> a*(b*d)^c auto var_d = (3.0 + graph::variable<T> (1, "")); auto gather_power2 = (var_a*graph::pow(var_b, var_c))*graph::pow(var_d, var_c); auto gather_power2_cast = graph::multiply_cast(gather_power2); assert(gather_power2_cast.get() && "Expected a mutliply node."); assert(gather_power2_cast->get_left()->is_match(var_a) && "Expected a."); assert(gather_power2_cast->get_right()->is_match(graph::pow(var_b*var_d, var_c)) && "Expected (b*d)^c."); // (a^c*b)*d^c -> b*(a*d)^c auto gather_power3 = (graph::pow(var_a, var_c)*var_b)*graph::pow(var_d, var_c); auto gather_power3_cast = graph::multiply_cast(gather_power3); assert(gather_power3_cast.get() && "Expected a mutliply node."); assert(gather_power3_cast->get_left()->is_match(var_b) && "Expected b."); assert(gather_power3_cast->get_right()->is_match(graph::pow(var_a*var_d, var_c)) && "Expected (a*d)^c."); // a^c*(b*d^c) -> b*(a*d)^c auto gather_power4 = graph::pow(var_a, var_c)*(var_b*graph::pow(var_d, var_c)); auto gather_power4_cast = graph::multiply_cast(gather_power4); assert(gather_power4_cast.get() && "Expected a mutliply node."); assert(gather_power4_cast->get_left()->is_match(var_b) && "Expected b."); assert(gather_power4_cast->get_right()->is_match(graph::pow(var_a*var_d, var_c)) && "Expected (a*d)^c."); // a^c*(b^c*d) -> d*(a*b)^c auto gather_power5 = graph::pow(var_a, var_c)*(graph::pow(var_b, var_c)*var_d); auto gather_power5_cast = graph::multiply_cast(gather_power5); assert(gather_power5_cast.get() && "Expected a mutliply node."); assert(gather_power5_cast->get_left()->is_match(var_d) && "Expected d."); assert(gather_power5_cast->get_right()->is_match(graph::pow(var_a*var_b, var_c)) && "Expected (a*b)^c."); } //------------------------------------------------------------------------------ Loading Loading @@ -2600,6 +2647,14 @@ template<jit::float_scalar T> void test_divide() { // a/(b/c) -> a*c/b assert((a/(b/c))->is_match(a*c/b) && "Expected a*b/c"); // a^b/c^b -> (a/c)^b auto gather_power = graph::pow(a, b)/graph::pow(c, b); auto gather_power_cast = graph::pow_cast(gather_power); assert(gather_power_cast.get() && "Expected a power node."); assert(gather_power_cast->get_left()->is_match(a/c) && "Expected a*c."); assert(gather_power_cast->get_right()->is_match(b) && "Expected b."); // (a*b*c)^2/a^2 -> (b*c)^2 // (a*b*c)^2/(a^2*d) -> (b*c)^2/d // (e*(a*b*c)^2)/(a^2*d) -> e*(b*c)^2/d Loading Loading @@ -2716,8 +2771,8 @@ template<jit::float_scalar T> void test_fma() { "Expected a value of one."); // Test reduction. auto var_a = graph::variable<T> (1, ""); auto var_b = graph::variable<T> (1, ""); auto var_a = graph::variable<T> (1, "a"); auto var_b = graph::variable<T> (1, "b"); auto var_c = graph::variable<T> (1, ""); // fma(1,a,b) = a + b Loading Loading @@ -3509,6 +3564,22 @@ template<jit::float_scalar T> void test_fma() { var_d*var_b, var_e)) && "Expected fma(a,d*b,e) as numerator."); // a^b*c^b + d -> (a*c)^b + d auto gather_power = graph::fma(graph::pow(var_a, var_b), graph::pow(var_c, var_b), var_d); assert(gather_power->is_match(graph::pow(var_a*var_c, var_b) + var_d) && "Expected a power node."); // fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) auto commom_power = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::pow(var_a, 2.0)*var_b); commom_power->to_latex(); std::cout << std::endl << std::endl; auto commom_power_cast = graph::multiply_cast(commom_power); assert(commom_power_cast.get() && "Expected a multiply node."); // fma(2,(a*b)^2,fma()) } //------------------------------------------------------------------------------ Loading graph_tests/math_test.cpp +124 −45 File changed.Preview size limit exceeded, changes collapsed. Show changes Loading
graph_framework.xcodeproj/project.pbxproj +174 −0 Original line number Diff line number Diff line Loading @@ -1294,6 +1294,93 @@ ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 14.5; OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", "-lz", "-lLLVMCoverage", "-lLLVMSupport", "-lLLVMDebugInfoCodeView", "-lLLVMRemarks", "-lLLVMJITLink", "-lLLVMLinker", "-lLLVMTextAPI", "-lLLVMRuntimeDyld", "-lLLVMOrcShared", "-lLLVMOrcDebugging", "-lLLVMOrcTargetProcess", "-lLLVMOrcJIT", "-lLLVMHipStdPar", "-lLLVMAggressiveInstCombine", "-lLLVMVectorize", "-lLLVMAsmParser", "-lLLVMOption", "-lLLVMLTO", "-lLLVMObject", "-lLLVMWindowsDriver", "-lLLVMDemangle", "-lLLVMIRReader", "-lLLVMIRPrinter", "-lLLVMInstCombine", "-lLLVMBinaryFormat", "-lLLVMCoroutines", "-lLLVMBitstreamReader", "-lLLVMBitReader", "-lLLVMBitWriter", "-lLLVMDebugInfoDWARF", "-lLLVMInstrumentation", "-lLLVMCFGuard", "-lLLVMObjCARCOpts", "-lLLVMipo", "-lLLVMGlobalISel", "-lLLVMExecutionEngine", "-lLLVMFrontendDriver", "-lLLVMFrontendHLSL", "-lLLVMFrontendOpenMP", "-lLLVMFrontendOffloading", "-lLLVMSelectionDAG", "-lLLVMProfileData", "-lLLVMAnalysis", "-lLLVMScalarOpts", "-lLLVMCodeGenTypes", "-lLLVMCodeGenData", "-lLLVMCodeGen", "-lLLVMTargetParser", "-lLLVMScalarOpts", "-lLLVMTarget", "-lLLVMTransformUtils", "-lLLVMPasses", "-lLLVMSupport", "-lLLVMMCParser", "-lLLVMMC", "-lLLVMCore", "-lLLVMAsmPrinter", "-lLLVMAArch64Utils", "-lLLVMAArch64Info", "-lLLVMAArch64Desc", "-lLLVMAArch64AsmParser", "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", "-lclangEdit", "-lclangLex", "-lclangDriver", "-lclangSerialization", "-lclangAST", "-lclangSema", "-lclangAnalysis", "-lclangASTMatchers", "-lclangSupport", "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", "-rpath", /usr/local/lib, ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Debug; Loading @@ -1312,6 +1399,93 @@ ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 14.5; OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", "-lz", "-lLLVMCoverage", "-lLLVMSupport", "-lLLVMDebugInfoCodeView", "-lLLVMRemarks", "-lLLVMJITLink", "-lLLVMLinker", "-lLLVMTextAPI", "-lLLVMRuntimeDyld", "-lLLVMOrcShared", "-lLLVMOrcDebugging", "-lLLVMOrcTargetProcess", "-lLLVMOrcJIT", "-lLLVMHipStdPar", "-lLLVMAggressiveInstCombine", "-lLLVMVectorize", "-lLLVMAsmParser", "-lLLVMOption", "-lLLVMLTO", "-lLLVMObject", "-lLLVMWindowsDriver", "-lLLVMDemangle", "-lLLVMIRReader", "-lLLVMIRPrinter", "-lLLVMInstCombine", "-lLLVMBinaryFormat", "-lLLVMCoroutines", "-lLLVMBitstreamReader", "-lLLVMBitReader", "-lLLVMBitWriter", "-lLLVMDebugInfoDWARF", "-lLLVMInstrumentation", "-lLLVMCFGuard", "-lLLVMObjCARCOpts", "-lLLVMipo", "-lLLVMGlobalISel", "-lLLVMExecutionEngine", "-lLLVMFrontendDriver", "-lLLVMFrontendHLSL", "-lLLVMFrontendOpenMP", "-lLLVMFrontendOffloading", "-lLLVMSelectionDAG", "-lLLVMProfileData", "-lLLVMAnalysis", "-lLLVMScalarOpts", "-lLLVMCodeGenTypes", "-lLLVMCodeGenData", "-lLLVMCodeGen", "-lLLVMTargetParser", "-lLLVMScalarOpts", "-lLLVMTarget", "-lLLVMTransformUtils", "-lLLVMPasses", "-lLLVMSupport", "-lLLVMMCParser", "-lLLVMMC", "-lLLVMCore", "-lLLVMAsmPrinter", "-lLLVMAArch64Utils", "-lLLVMAArch64Info", "-lLLVMAArch64Desc", "-lLLVMAArch64AsmParser", "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", "-lclangEdit", "-lclangLex", "-lclangDriver", "-lclangSerialization", "-lclangAST", "-lclangSema", "-lclangAnalysis", "-lclangASTMatchers", "-lclangSupport", "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", "-rpath", /usr/local/lib, ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Release; Loading
graph_framework/arithmetic.hpp +72 −0 Original line number Diff line number Diff line Loading @@ -2024,6 +2024,45 @@ namespace graph { return this->right/pow(lp->get_left(), -lp->get_right()); } } // a^b*c^b -> (a*c)^b if (lp.get() && rp.get()) { if (lp->get_right()->is_match(rp->get_right())) { return pow(lp->get_left()*rp->get_left(), lp->get_right()); } } // (a*b^c)*d^c -> a*(b*d)^c // (a^c*b)*d^c -> b*(a*d)^c // a^c*(b*d^c) -> b*(a*d)^c // a^c*(b^c*d) -> d*(a*b)^c if (lm.get() && rp.get()) { auto lmlp = pow_cast(lm->get_left()); auto lmrp = pow_cast(lm->get_right()); if (lmrp.get()) { if (lmrp->get_right()->is_match(rp->get_right())) { return lm->get_left()*pow(lmrp->get_left()*rp->get_left(), rp->get_right()); } } else if (lmlp.get()) { if (lmlp->get_right()->is_match(rp->get_right())) { return lm->get_right()*pow(lmlp->get_left()*rp->get_left(), rp->get_right()); } } } else if (rm.get() && lp.get()) { auto rmlp = pow_cast(rm->get_left()); auto rmrp = pow_cast(rm->get_right()); if (rmrp.get()) { if (rmrp->get_right()->is_match(lp->get_right())) { return rm->get_left()*pow(lp->get_left()*rmrp->get_left(), lp->get_right()); } } else if (rmlp.get()) { if (rmlp->get_right()->is_match(lp->get_right())) { return rm->get_right()*pow(lp->get_left()*rmlp->get_left(), lp->get_right()); } } } // (b*a)^c*a^d -> b^c*a^(c + d) // (a*b)^c*a^d -> b^c*a^(c + d) Loading Loading @@ -2931,6 +2970,14 @@ namespace graph { } } } // a^b/c^b -> (a/c)^b if (lp.get() && rp.get()) { if (lp->get_right()->is_match(rp->get_right())) { return pow(lp->get_left()/rp->get_left(), lp->get_right()); } } // (a*b)^c/((a^d)*e) = a^(c - d)*b^c/e // (b*a)^c/((a^d)*e) = a^(c - d)*b^c/e // (a*b)^c/(e*(a^d)) = a^(c - d)*b^c/e Loading Loading @@ -4320,6 +4367,31 @@ namespace graph { } } // a^b*c^b + d -> (a*c)^b + d if (lp.get() && mp.get()) { if (lp->get_right()->is_match(mp->get_right())) { return pow(lp->get_left()*mp->get_left(), lp->get_right()) + this->right; } } // fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) if (rm.get() && mp.get()) { auto mplm = multiply_cast(mp->get_left()); if (mplm.get()) { if (is_variable_combineable(mplm->get_left(), rm->get_left())) { return pow(mplm->get_left(), mp->get_right()) * fma(this->left, pow(mplm->get_right(), mp->get_right()), this->right/mplm->get_left()); } } } // fma(a,b/c,b/d) -> b*(a/c + 1/d) // fma(a,c/b,d/b) -> (a*c + d)/b if (md.get() && rd.get()) { Loading
graph_framework/math.hpp +64 −2 Original line number Diff line number Diff line Loading @@ -957,8 +957,9 @@ namespace graph { } // Handle cases where (c*x)^a, (x*c)^a, (a*sqrt(b))^c and (a*b^c)^2. // These reductions only make sense if the power is constant. auto lm = multiply_cast(this->left); if (lm.get()) { if (lm.get() && rc.get()) { if (lm->get_left()->is_constant() || lm->get_right()->is_constant() || sqrt_cast(lm->get_left()).get() || Loading Loading @@ -986,8 +987,9 @@ namespace graph { } } // These reductions only make sense if the power is constant. auto ld = divide_cast(this->left); if (ld.get()) { if (ld.get() && rc.get()) { // For even exponents e. // (-a/b)^e -> (a/b)^e auto ldlm = multiply_cast(ld->get_left()); Loading @@ -1010,6 +1012,36 @@ namespace graph { pow(ldlm->get_right(), this->right)/ pow(ld->get_right(), this->right); } auto ldlmlm = multiply_cast(ldlm->get_left()); if (ldlmlm.get()) { if (ldlmlm->get_left()->is_constant() || ldlmlm->get_right()->is_constant() || sqrt_cast(ldlmlm->get_left()).get() || sqrt_cast(ldlmlm->get_right()).get() || pow_cast(ldlmlm->get_left()).get() || pow_cast(ldlmlm->get_right()).get()) { return (pow(ldlmlm->get_left(), this->right) * pow(ldlmlm->get_right(), this->right) * pow(ldlm->get_right(), this->right)) / pow(ld->get_right(), this->right); } } auto ldlmrm = multiply_cast(ldlm->get_right()); if (ldlmrm.get()) { if (ldlmrm->get_left()->is_constant() || ldlmrm->get_right()->is_constant() || sqrt_cast(ldlmrm->get_left()).get() || sqrt_cast(ldlmrm->get_right()).get() || pow_cast(ldlmrm->get_left()).get() || pow_cast(ldlmrm->get_right()).get()) { return (pow(ldlmrm->get_left(), this->right) * pow(ldlmrm->get_right(), this->right) * pow(ldlm->get_left(), this->right)) / pow(ld->get_right(), this->right); } } } // Handle cases where (c/x)^a, (x/c)^a, (a/sqrt(b))^c and (a/b^c)^2. Loading @@ -1036,6 +1068,36 @@ namespace graph { (pow(ldrm->get_left(), this->right) * pow(ldrm->get_right(), this->right)); } auto ldrmlm = multiply_cast(ldrm->get_left()); if (ldrmlm.get()) { if (ldrmlm->get_left()->is_constant() || ldrmlm->get_right()->is_constant() || sqrt_cast(ldrmlm->get_left()).get() || sqrt_cast(ldrmlm->get_right()).get() || pow_cast(ldrmlm->get_left()).get() || pow_cast(ldrmlm->get_right()).get()) { return pow(ld->get_left(), this->right) / (pow(ldrmlm->get_left(), this->right) * pow(ldrmlm->get_right(), this->right) * pow(ldrm->get_right(), this->right)); } } auto ldrmrm = multiply_cast(ldrm->get_right()); if (ldrmrm.get()) { if (ldrmrm->get_left()->is_constant() || ldrmrm->get_right()->is_constant() || sqrt_cast(ldrmrm->get_left()).get() || sqrt_cast(ldrmrm->get_right()).get() || pow_cast(ldrmrm->get_left()).get() || pow_cast(ldrmrm->get_right()).get()) { return pow(ld->get_left(), this->right) / (pow(ldrmrm->get_left(), this->right) * pow(ldrmrm->get_right(), this->right) * pow(ldrm->get_left(), this->right)); } } } if (is_variable_combineable(ld->get_left(), Loading
graph_tests/arithmetic_test.cpp +83 −12 Original line number Diff line number Diff line Loading @@ -537,8 +537,8 @@ template<jit::float_scalar T> void test_subtract() { "Expected to reduce to a constant minus one."); // Test common factors. auto var_a = graph::variable<T> (1, ""); auto var_b = graph::variable<T> (1, ""); auto var_a = graph::variable<T> (1, "a"); auto var_b = graph::variable<T> (1, "b"); auto var_c = graph::variable<T> (1, ""); auto common_a = var_a*var_b - var_a*var_c; assert(graph::add_cast(common_a).get() == nullptr && Loading Loading @@ -1850,34 +1850,81 @@ template<jit::float_scalar T> void test_multiply() { auto common_pow2 = graph::pow(var_b*var_a, 2.0)*graph::pow(var_a, 2.0); auto common_pow2_cast = graph::multiply_cast(common_pow2); assert(common_pow2_cast.get() && "Expected a multiply node."); assert(common_pow2_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && assert(common_pow2_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); assert(common_pow2_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && assert(common_pow2_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // (a*b)^2*a^2 -> b^2*a^4 auto common_pow3 = graph::pow(var_a*var_b, 2.0)*graph::pow(var_a, 2.0); auto common_pow3_cast = graph::multiply_cast(common_pow3); assert(common_pow3_cast.get() && "Expected a multiply node."); assert(common_pow3_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && assert(common_pow3_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); assert(common_pow3_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && assert(common_pow3_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // a^2*(b*a)^2 -> b^2*a^4 auto common_pow4 = graph::pow(var_a, 2.0)*graph::pow(var_b*var_a, 2.0); auto common_pow4_cast = graph::multiply_cast(common_pow4); assert(common_pow4_cast.get() && "Expected a multiply node."); assert(common_pow4_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && assert(common_pow4_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); assert(common_pow4_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && assert(common_pow4_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // a^2*(b*a)^2 -> b^2*a^4 auto common_pow5 = graph::pow(var_a, 2.0)*graph::pow(var_a*var_b, 2.0); auto common_pow5_cast = graph::multiply_cast(common_pow5); assert(common_pow5_cast.get() && "Expected a multiply node."); assert(common_pow5_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && assert(common_pow5_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); assert(common_pow5_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && assert(common_pow5_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // a^b*c^b -> (a*c)^b auto gather_power = graph::pow(var_a, var_b)*graph::pow(var_c, var_b); auto gather_power_cast = graph::pow_cast(gather_power); assert(gather_power_cast.get() && "Expected a power node."); assert(gather_power_cast->get_left()->is_match(var_a*var_c) && "Expected a*c."); assert(gather_power_cast->get_right()->is_match(var_b) && "Expected b."); // (a*b^c)*d^c -> a*(b*d)^c auto var_d = (3.0 + graph::variable<T> (1, "")); auto gather_power2 = (var_a*graph::pow(var_b, var_c))*graph::pow(var_d, var_c); auto gather_power2_cast = graph::multiply_cast(gather_power2); assert(gather_power2_cast.get() && "Expected a mutliply node."); assert(gather_power2_cast->get_left()->is_match(var_a) && "Expected a."); assert(gather_power2_cast->get_right()->is_match(graph::pow(var_b*var_d, var_c)) && "Expected (b*d)^c."); // (a^c*b)*d^c -> b*(a*d)^c auto gather_power3 = (graph::pow(var_a, var_c)*var_b)*graph::pow(var_d, var_c); auto gather_power3_cast = graph::multiply_cast(gather_power3); assert(gather_power3_cast.get() && "Expected a mutliply node."); assert(gather_power3_cast->get_left()->is_match(var_b) && "Expected b."); assert(gather_power3_cast->get_right()->is_match(graph::pow(var_a*var_d, var_c)) && "Expected (a*d)^c."); // a^c*(b*d^c) -> b*(a*d)^c auto gather_power4 = graph::pow(var_a, var_c)*(var_b*graph::pow(var_d, var_c)); auto gather_power4_cast = graph::multiply_cast(gather_power4); assert(gather_power4_cast.get() && "Expected a mutliply node."); assert(gather_power4_cast->get_left()->is_match(var_b) && "Expected b."); assert(gather_power4_cast->get_right()->is_match(graph::pow(var_a*var_d, var_c)) && "Expected (a*d)^c."); // a^c*(b^c*d) -> d*(a*b)^c auto gather_power5 = graph::pow(var_a, var_c)*(graph::pow(var_b, var_c)*var_d); auto gather_power5_cast = graph::multiply_cast(gather_power5); assert(gather_power5_cast.get() && "Expected a mutliply node."); assert(gather_power5_cast->get_left()->is_match(var_d) && "Expected d."); assert(gather_power5_cast->get_right()->is_match(graph::pow(var_a*var_b, var_c)) && "Expected (a*b)^c."); } //------------------------------------------------------------------------------ Loading Loading @@ -2600,6 +2647,14 @@ template<jit::float_scalar T> void test_divide() { // a/(b/c) -> a*c/b assert((a/(b/c))->is_match(a*c/b) && "Expected a*b/c"); // a^b/c^b -> (a/c)^b auto gather_power = graph::pow(a, b)/graph::pow(c, b); auto gather_power_cast = graph::pow_cast(gather_power); assert(gather_power_cast.get() && "Expected a power node."); assert(gather_power_cast->get_left()->is_match(a/c) && "Expected a*c."); assert(gather_power_cast->get_right()->is_match(b) && "Expected b."); // (a*b*c)^2/a^2 -> (b*c)^2 // (a*b*c)^2/(a^2*d) -> (b*c)^2/d // (e*(a*b*c)^2)/(a^2*d) -> e*(b*c)^2/d Loading Loading @@ -2716,8 +2771,8 @@ template<jit::float_scalar T> void test_fma() { "Expected a value of one."); // Test reduction. auto var_a = graph::variable<T> (1, ""); auto var_b = graph::variable<T> (1, ""); auto var_a = graph::variable<T> (1, "a"); auto var_b = graph::variable<T> (1, "b"); auto var_c = graph::variable<T> (1, ""); // fma(1,a,b) = a + b Loading Loading @@ -3509,6 +3564,22 @@ template<jit::float_scalar T> void test_fma() { var_d*var_b, var_e)) && "Expected fma(a,d*b,e) as numerator."); // a^b*c^b + d -> (a*c)^b + d auto gather_power = graph::fma(graph::pow(var_a, var_b), graph::pow(var_c, var_b), var_d); assert(gather_power->is_match(graph::pow(var_a*var_c, var_b) + var_d) && "Expected a power node."); // fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) auto commom_power = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::pow(var_a, 2.0)*var_b); commom_power->to_latex(); std::cout << std::endl << std::endl; auto commom_power_cast = graph::multiply_cast(commom_power); assert(commom_power_cast.get() && "Expected a multiply node."); // fma(2,(a*b)^2,fma()) } //------------------------------------------------------------------------------ Loading
graph_tests/math_test.cpp +124 −45 File changed.Preview size limit exceeded, changes collapsed. Show changes