From 1e3d6113c14a3bb293b6704bdde8e049b35f6299 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 30 Jan 2025 17:50:31 -0500 Subject: [PATCH 1/4] Commit changes to power reductions. This is WIP becuase there are several regessions. --- graph_framework.xcodeproj/project.pbxproj | 174 ++++++++++++++++++++++ graph_framework/arithmetic.hpp | 72 +++++++++ graph_framework/math.hpp | 66 +++++++- graph_tests/arithmetic_test.cpp | 95 ++++++++++-- graph_tests/math_test.cpp | 169 +++++++++++++++------ 5 files changed, 517 insertions(+), 59 deletions(-) diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index 6a8b85a..794b6ae 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -1294,6 +1294,93 @@ ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 14.5; + OTHER_LDFLAGS = ( + "-lnetcdf", + "-ld_classic", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGenData", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", + "-lLLVMCGData", + "-lLLVMSandboxIR", + "-lLLVMFrontendAtomic", + "-lclangFrontend", + "-lclangBasic", + "-lclangEdit", + "-lclangLex", + "-lclangDriver", + "-lclangSerialization", + "-lclangAST", + "-lclangSema", + "-lclangAnalysis", + "-lclangASTMatchers", + "-lclangSupport", + "-lclangParse", + "-lclangAPINotes", + "-lclangCodeGen", + "-rpath", + /usr/local/lib, + ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Debug; @@ -1312,6 +1399,93 @@ ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 14.5; + OTHER_LDFLAGS = ( + "-lnetcdf", + "-ld_classic", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGenData", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", + "-lLLVMCGData", + "-lLLVMSandboxIR", + "-lLLVMFrontendAtomic", + "-lclangFrontend", + "-lclangBasic", + "-lclangEdit", + "-lclangLex", + "-lclangDriver", + "-lclangSerialization", + "-lclangAST", + "-lclangSema", + "-lclangAnalysis", + "-lclangASTMatchers", + "-lclangSupport", + "-lclangParse", + "-lclangAPINotes", + "-lclangCodeGen", + "-rpath", + /usr/local/lib, + ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Release; diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 6be028d..bfd42cb 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -2024,6 +2024,45 @@ namespace graph { return this->right/pow(lp->get_left(), -lp->get_right()); } } +// a^b*c^b -> (a*c)^b + if (lp.get() && rp.get()) { + if (lp->get_right()->is_match(rp->get_right())) { + return pow(lp->get_left()*rp->get_left(), lp->get_right()); + } + } +// (a*b^c)*d^c -> a*(b*d)^c +// (a^c*b)*d^c -> b*(a*d)^c +// a^c*(b*d^c) -> b*(a*d)^c +// a^c*(b^c*d) -> d*(a*b)^c + if (lm.get() && rp.get()) { + auto lmlp = pow_cast(lm->get_left()); + auto lmrp = pow_cast(lm->get_right()); + if (lmrp.get()) { + if (lmrp->get_right()->is_match(rp->get_right())) { + return lm->get_left()*pow(lmrp->get_left()*rp->get_left(), + rp->get_right()); + } + } else if (lmlp.get()) { + if (lmlp->get_right()->is_match(rp->get_right())) { + return lm->get_right()*pow(lmlp->get_left()*rp->get_left(), + rp->get_right()); + } + } + } else if (rm.get() && lp.get()) { + auto rmlp = pow_cast(rm->get_left()); + auto rmrp = pow_cast(rm->get_right()); + if (rmrp.get()) { + if (rmrp->get_right()->is_match(lp->get_right())) { + return rm->get_left()*pow(lp->get_left()*rmrp->get_left(), + lp->get_right()); + } + } else if (rmlp.get()) { + if (rmlp->get_right()->is_match(lp->get_right())) { + return rm->get_right()*pow(lp->get_left()*rmlp->get_left(), + lp->get_right()); + } + } + } // (b*a)^c*a^d -> b^c*a^(c + d) // (a*b)^c*a^d -> b^c*a^(c + d) @@ -2931,6 +2970,14 @@ namespace graph { } } } + +// a^b/c^b -> (a/c)^b + if (lp.get() && rp.get()) { + if (lp->get_right()->is_match(rp->get_right())) { + return pow(lp->get_left()/rp->get_left(), lp->get_right()); + } + } + // (a*b)^c/((a^d)*e) = a^(c - d)*b^c/e // (b*a)^c/((a^d)*e) = a^(c - d)*b^c/e // (a*b)^c/(e*(a^d)) = a^(c - d)*b^c/e @@ -4320,6 +4367,31 @@ namespace graph { } } +// a^b*c^b + d -> (a*c)^b + d + if (lp.get() && mp.get()) { + if (lp->get_right()->is_match(mp->get_right())) { + return pow(lp->get_left()*mp->get_left(), + lp->get_right()) + + this->right; + } + } + +// fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) + if (rm.get() && mp.get()) { + auto mplm = multiply_cast(mp->get_left()); + if (mplm.get()) { + if (is_variable_combineable(mplm->get_left(), + rm->get_left())) { + return pow(mplm->get_left(), + mp->get_right()) * + fma(this->left, + pow(mplm->get_right(), + mp->get_right()), + this->right/mplm->get_left()); + } + } + } + // fma(a,b/c,b/d) -> b*(a/c + 1/d) // fma(a,c/b,d/b) -> (a*c + d)/b if (md.get() && rd.get()) { diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index 7e58fa1..bda00d4 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -957,8 +957,9 @@ namespace graph { } // Handle cases where (c*x)^a, (x*c)^a, (a*sqrt(b))^c and (a*b^c)^2. +// These reductions only make sense if the power is constant. auto lm = multiply_cast(this->left); - if (lm.get()) { + if (lm.get() && rc.get()) { if (lm->get_left()->is_constant() || lm->get_right()->is_constant() || sqrt_cast(lm->get_left()).get() || @@ -986,8 +987,9 @@ namespace graph { } } +// These reductions only make sense if the power is constant. auto ld = divide_cast(this->left); - if (ld.get()) { + if (ld.get() && rc.get()) { // For even exponents e. // (-a/b)^e -> (a/b)^e auto ldlm = multiply_cast(ld->get_left()); @@ -1010,6 +1012,36 @@ namespace graph { pow(ldlm->get_right(), this->right)/ pow(ld->get_right(), this->right); } + + auto ldlmlm = multiply_cast(ldlm->get_left()); + if (ldlmlm.get()) { + if (ldlmlm->get_left()->is_constant() || + ldlmlm->get_right()->is_constant() || + sqrt_cast(ldlmlm->get_left()).get() || + sqrt_cast(ldlmlm->get_right()).get() || + pow_cast(ldlmlm->get_left()).get() || + pow_cast(ldlmlm->get_right()).get()) { + return (pow(ldlmlm->get_left(), this->right) * + pow(ldlmlm->get_right(), this->right) * + pow(ldlm->get_right(), this->right)) / + pow(ld->get_right(), this->right); + } + } + + auto ldlmrm = multiply_cast(ldlm->get_right()); + if (ldlmrm.get()) { + if (ldlmrm->get_left()->is_constant() || + ldlmrm->get_right()->is_constant() || + sqrt_cast(ldlmrm->get_left()).get() || + sqrt_cast(ldlmrm->get_right()).get() || + pow_cast(ldlmrm->get_left()).get() || + pow_cast(ldlmrm->get_right()).get()) { + return (pow(ldlmrm->get_left(), this->right) * + pow(ldlmrm->get_right(), this->right) * + pow(ldlm->get_left(), this->right)) / + pow(ld->get_right(), this->right); + } + } } // Handle cases where (c/x)^a, (x/c)^a, (a/sqrt(b))^c and (a/b^c)^2. @@ -1036,6 +1068,36 @@ namespace graph { (pow(ldrm->get_left(), this->right) * pow(ldrm->get_right(), this->right)); } + + auto ldrmlm = multiply_cast(ldrm->get_left()); + if (ldrmlm.get()) { + if (ldrmlm->get_left()->is_constant() || + ldrmlm->get_right()->is_constant() || + sqrt_cast(ldrmlm->get_left()).get() || + sqrt_cast(ldrmlm->get_right()).get() || + pow_cast(ldrmlm->get_left()).get() || + pow_cast(ldrmlm->get_right()).get()) { + return pow(ld->get_left(), this->right) / + (pow(ldrmlm->get_left(), this->right) * + pow(ldrmlm->get_right(), this->right) * + pow(ldrm->get_right(), this->right)); + } + } + + auto ldrmrm = multiply_cast(ldrm->get_right()); + if (ldrmrm.get()) { + if (ldrmrm->get_left()->is_constant() || + ldrmrm->get_right()->is_constant() || + sqrt_cast(ldrmrm->get_left()).get() || + sqrt_cast(ldrmrm->get_right()).get() || + pow_cast(ldrmrm->get_left()).get() || + pow_cast(ldrmrm->get_right()).get()) { + return pow(ld->get_left(), this->right) / + (pow(ldrmrm->get_left(), this->right) * + pow(ldrmrm->get_right(), this->right) * + pow(ldrm->get_left(), this->right)); + } + } } if (is_variable_combineable(ld->get_left(), diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 5577346..d161514 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -537,8 +537,8 @@ template void test_subtract() { "Expected to reduce to a constant minus one."); // Test common factors. - auto var_a = graph::variable (1, ""); - auto var_b = graph::variable (1, ""); + auto var_a = graph::variable (1, "a"); + auto var_b = graph::variable (1, "b"); auto var_c = graph::variable (1, ""); auto common_a = var_a*var_b - var_a*var_c; assert(graph::add_cast(common_a).get() == nullptr && @@ -1850,34 +1850,81 @@ template void test_multiply() { auto common_pow2 = graph::pow(var_b*var_a, 2.0)*graph::pow(var_a, 2.0); auto common_pow2_cast = graph::multiply_cast(common_pow2); assert(common_pow2_cast.get() && "Expected a multiply node."); - assert(common_pow2_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && + assert(common_pow2_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); - assert(common_pow2_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && + assert(common_pow2_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // (a*b)^2*a^2 -> b^2*a^4 auto common_pow3 = graph::pow(var_a*var_b, 2.0)*graph::pow(var_a, 2.0); auto common_pow3_cast = graph::multiply_cast(common_pow3); assert(common_pow3_cast.get() && "Expected a multiply node."); - assert(common_pow3_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && + assert(common_pow3_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); - assert(common_pow3_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && + assert(common_pow3_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // a^2*(b*a)^2 -> b^2*a^4 auto common_pow4 = graph::pow(var_a, 2.0)*graph::pow(var_b*var_a, 2.0); auto common_pow4_cast = graph::multiply_cast(common_pow4); assert(common_pow4_cast.get() && "Expected a multiply node."); - assert(common_pow4_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && + assert(common_pow4_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); - assert(common_pow4_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && + assert(common_pow4_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); // a^2*(b*a)^2 -> b^2*a^4 auto common_pow5 = graph::pow(var_a, 2.0)*graph::pow(var_a*var_b, 2.0); auto common_pow5_cast = graph::multiply_cast(common_pow5); assert(common_pow5_cast.get() && "Expected a multiply node."); - assert(common_pow5_cast->get_left()->is_match(graph::pow(var_b, 2.0)) && + assert(common_pow5_cast->get_right()->is_match(graph::pow(var_b, 2.0)) && "Expected b^2."); - assert(common_pow5_cast->get_right()->is_match(graph::pow(var_a, 4.0)) && + assert(common_pow5_cast->get_left()->is_match(graph::pow(var_a, 4.0)) && "Expected a^4."); + +// a^b*c^b -> (a*c)^b + auto gather_power = graph::pow(var_a, var_b)*graph::pow(var_c, var_b); + auto gather_power_cast = graph::pow_cast(gather_power); + assert(gather_power_cast.get() && "Expected a power node."); + assert(gather_power_cast->get_left()->is_match(var_a*var_c) && + "Expected a*c."); + assert(gather_power_cast->get_right()->is_match(var_b) && + "Expected b."); + +// (a*b^c)*d^c -> a*(b*d)^c + auto var_d = (3.0 + graph::variable (1, "")); + auto gather_power2 = (var_a*graph::pow(var_b, var_c))*graph::pow(var_d, var_c); + auto gather_power2_cast = graph::multiply_cast(gather_power2); + assert(gather_power2_cast.get() && "Expected a mutliply node."); + assert(gather_power2_cast->get_left()->is_match(var_a) && + "Expected a."); + assert(gather_power2_cast->get_right()->is_match(graph::pow(var_b*var_d, + var_c)) && + "Expected (b*d)^c."); +// (a^c*b)*d^c -> b*(a*d)^c + auto gather_power3 = (graph::pow(var_a, var_c)*var_b)*graph::pow(var_d, var_c); + auto gather_power3_cast = graph::multiply_cast(gather_power3); + assert(gather_power3_cast.get() && "Expected a mutliply node."); + assert(gather_power3_cast->get_left()->is_match(var_b) && + "Expected b."); + assert(gather_power3_cast->get_right()->is_match(graph::pow(var_a*var_d, + var_c)) && + "Expected (a*d)^c."); +// a^c*(b*d^c) -> b*(a*d)^c + auto gather_power4 = graph::pow(var_a, var_c)*(var_b*graph::pow(var_d, var_c)); + auto gather_power4_cast = graph::multiply_cast(gather_power4); + assert(gather_power4_cast.get() && "Expected a mutliply node."); + assert(gather_power4_cast->get_left()->is_match(var_b) && + "Expected b."); + assert(gather_power4_cast->get_right()->is_match(graph::pow(var_a*var_d, + var_c)) && + "Expected (a*d)^c."); +// a^c*(b^c*d) -> d*(a*b)^c + auto gather_power5 = graph::pow(var_a, var_c)*(graph::pow(var_b, var_c)*var_d); + auto gather_power5_cast = graph::multiply_cast(gather_power5); + assert(gather_power5_cast.get() && "Expected a mutliply node."); + assert(gather_power5_cast->get_left()->is_match(var_d) && + "Expected d."); + assert(gather_power5_cast->get_right()->is_match(graph::pow(var_a*var_b, + var_c)) && + "Expected (a*b)^c."); } //------------------------------------------------------------------------------ @@ -2600,6 +2647,14 @@ template void test_divide() { // a/(b/c) -> a*c/b assert((a/(b/c))->is_match(a*c/b) && "Expected a*b/c"); +// a^b/c^b -> (a/c)^b + auto gather_power = graph::pow(a, b)/graph::pow(c, b); + auto gather_power_cast = graph::pow_cast(gather_power); + assert(gather_power_cast.get() && "Expected a power node."); + assert(gather_power_cast->get_left()->is_match(a/c) && + "Expected a*c."); + assert(gather_power_cast->get_right()->is_match(b) && + "Expected b."); // (a*b*c)^2/a^2 -> (b*c)^2 // (a*b*c)^2/(a^2*d) -> (b*c)^2/d // (e*(a*b*c)^2)/(a^2*d) -> e*(b*c)^2/d @@ -2716,8 +2771,8 @@ template void test_fma() { "Expected a value of one."); // Test reduction. - auto var_a = graph::variable (1, ""); - auto var_b = graph::variable (1, ""); + auto var_a = graph::variable (1, "a"); + auto var_b = graph::variable (1, "b"); auto var_c = graph::variable (1, ""); // fma(1,a,b) = a + b @@ -3509,6 +3564,22 @@ template void test_fma() { var_d*var_b, var_e)) && "Expected fma(a,d*b,e) as numerator."); + +// a^b*c^b + d -> (a*c)^b + d + auto gather_power = graph::fma(graph::pow(var_a, var_b), + graph::pow(var_c, var_b), + var_d); + assert(gather_power->is_match(graph::pow(var_a*var_c, + var_b) + var_d) && + "Expected a power node."); + +// fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) + auto commom_power = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::pow(var_a, 2.0)*var_b); + commom_power->to_latex(); + std::cout << std::endl << std::endl; + auto commom_power_cast = graph::multiply_cast(commom_power); + assert(commom_power_cast.get() && "Expected a multiply node."); +// fma(2,(a*b)^2,fma()) } //------------------------------------------------------------------------------ diff --git a/graph_tests/math_test.cpp b/graph_tests/math_test.cpp index 47849cc..972ff51 100644 --- a/graph_tests/math_test.cpp +++ b/graph_tests/math_test.cpp @@ -67,26 +67,20 @@ void test_sqrt() { auto x2_sqrt = graph::sqrt(x_var*x_var); assert(x2_sqrt.get() != x_var.get() && "Expected not to reduce to x_var."); -// Reduction Sqrt(x*y*x*y) = x*y +// Reduction Sqrt(x*y*x*y) = sqrt((x*y)^2) auto y_var = graph::variable (1, "y"); auto x2y2_sqrt = graph::sqrt(x_var*y_var*x_var*y_var); - auto x2y2_sqrt_cast = graph::multiply_cast(x2y2_sqrt); - assert(x2y2_sqrt_cast.get() && "Expected multiply node"); - assert((x2y2_sqrt_cast->get_left().get() != x_var.get() || - x2y2_sqrt_cast->get_left().get() != y_var.get()) && - "Expected x_var or y_var."); - assert((x2y2_sqrt_cast->get_right().get() != x_var.get() || - x2y2_sqrt_cast->get_right().get() != y_var.get()) && - "Expected x_var or y_var."); + auto x2y2_sqrt_cast = graph::sqrt_cast(x2y2_sqrt); + assert(x2y2_sqrt_cast.get() && "Expected sqrt node"); + assert(x2y2_sqrt_cast->get_arg()->is_match(graph::pow(y_var*x_var, 2.0)) && + "Expected (x_var*y_var)^2."); // Reduction Sqrt(x*x/y*y); auto sq_reduce = graph::sqrt((x_var*x_var)/(y_var*y_var)); - auto sq_reduce_cast = graph::divide_cast(sq_reduce); - assert(sq_reduce_cast.get() && "Expected divide node."); - assert(sq_reduce_cast->get_left().get() != x_var.get() && - "Expected x_var."); - assert(sq_reduce_cast->get_right().get() != y_var.get() && - "Expected y_var."); + auto sq_reduce_cast = graph::sqrt_cast(sq_reduce); + assert(sq_reduce_cast.get() && "Expected sqrt node."); + assert(sq_reduce_cast->get_arg()->is_match(graph::pow(x_var/y_var, 2.0)) && + "Expected (x_var/y_var)^2."); // Reduction Sqrt(c*x/b*y) = d*Sqrt(x/y) auto cxby_sqrt = graph::sqrt(x1/x2); @@ -195,32 +189,32 @@ void test_pow() { assert(constant_cast.get() && "Expected constant node on the left."); assert(constant_cast->is(0.5) && "Expected a value of 0.5"); -// (c*Sqrt(b))^a -> c^a*b^a/2 - assert(graph::multiply_cast(graph::pow(2.0*graph::sqrt(ten), ten)).get() && +// (c1*Sqrt(b))^c2 -> c3*b^c2/2 + assert(graph::multiply_cast(graph::pow(2.0*graph::sqrt(ten), 10.0)).get() && "Expected multiply node."); // (Sqrt(b)*c)^a -> c^a*b^a/2 - assert(graph::multiply_cast(graph::pow(graph::sqrt(ten)*2.0, ten)).get() && + assert(graph::multiply_cast(graph::pow(graph::sqrt(ten)*2.0, 10.0)).get() && "Expected multiply node."); // (c*b^d)^a -> c^a*b^(a*d) - assert(graph::multiply_cast(graph::pow(2.0*graph::pow(ten, 2.0), ten)).get() && + assert(graph::multiply_cast(graph::pow(2.0*graph::pow(ten, 2.0), 10.0)).get() && "Expected multiply node."); // ((b^d)*c)^a -> b^(a*d)*c^a - assert(graph::multiply_cast(graph::pow(graph::pow(ten, 2.0)*2.0, ten)).get() && + assert(graph::multiply_cast(graph::pow(graph::pow(ten, 2.0)*2.0, 10.0)).get() && "Expected multiply node."); // (c/Sqrt(b))^a -> c^a/b^a/2 - assert(graph::divide_cast(graph::pow(2.0/graph::sqrt(ten), ten)).get() && + assert(graph::divide_cast(graph::pow(2.0/graph::sqrt(ten), 10.0)).get() && "Expected divide node."); // (Sqrt(b)/c)^a -> (b^a/2)/c^a -> c2*b^a - assert(graph::multiply_cast(graph::pow(graph::sqrt(ten)/2.0, ten)).get() && + assert(graph::multiply_cast(graph::pow(graph::sqrt(ten)/2.0, 10.0)).get() && "Expected multiply node."); // (c/(b^d))^a -> c^a/(b^(a*d)) - assert(graph::divide_cast(graph::pow(2.0/graph::pow(ten, 2.0), ten)).get() && + assert(graph::divide_cast(graph::pow(2.0/graph::pow(ten, 2.0), 10.0)).get() && "Expected divide node."); // ((b^d)/c))^a -> (b^(a*d))/c^a -> c2*b^a - assert(graph::multiply_cast(graph::pow(graph::pow(ten, 2.0)/2.0, ten)).get() && + assert(graph::multiply_cast(graph::pow(graph::pow(ten, 2.0)/2.0, 10.0)).get() && "Expected multiply node."); // a^1/2 -> sqrt(a); @@ -377,38 +371,41 @@ void test_pow() { auto pow_combine3 = graph::pow(expr_b/(expr_c*graph::sqrt(expr_a*expr_a)*expr_a), 2.0); auto pow_combine3_cast = graph::divide_cast(pow_combine3); assert(pow_combine3_cast.get() && "Expected a divide node."); - assert(pow_combine3_cast->get_left()->is_match(graph::pow(expr_b/expr_c, 2.0)) && - "Expected (b/c)^2."); - assert(pow_combine3_cast->get_right()->is_match(graph::pow(expr_a, 4.0)) && - "Expected (b/c)^2."); + assert(pow_combine3_cast->get_left()->is_match(graph::pow(expr_b, 2.0)) && + "Expected b^2."); + assert(pow_combine3_cast->get_right()->is_match(graph::pow(expr_a, 4.0) * + graph::pow(expr_c, 2.0)) && + "Expected a^4*c^2."); auto pow_combine4 = graph::pow(expr_b/(graph::sqrt(expr_a*expr_a)*expr_c*expr_a), 2.0); auto pow_combine4_cast = graph::divide_cast(pow_combine4); assert(pow_combine4_cast.get() && "Expected a divide node."); - assert(pow_combine4_cast->get_left()->is_match(graph::pow(expr_b/expr_c, 2.0)) && - "Expected (b/c)^2."); - assert(pow_combine4_cast->get_right()->is_match(graph::pow(expr_a, 4.0)) && - "Expected (b/c)^2."); + assert(pow_combine4_cast->get_left()->is_match(graph::pow(expr_b, 2.0)) && + "Expected b^2."); + assert(pow_combine4_cast->get_right()->is_match(graph::pow(expr_a, 4.0) * + graph::pow(expr_c, 2.0)) && + "Expected a^4*c^2."); auto pow_combine5 = graph::pow(expr_b/(expr_a*graph::sqrt(expr_a*expr_a)*expr_c), 2.0); auto pow_combine5_cast = graph::divide_cast(pow_combine5); assert(pow_combine5_cast.get() && "Expected a divide node."); - assert(pow_combine5_cast->get_left()->is_match(graph::pow(expr_b/expr_c, 2.0)) && - "Expected (b/c)^2."); - assert(pow_combine5_cast->get_right()->is_match(graph::pow(expr_a, 4.0)) && - "Expected (b/c)^2."); + assert(pow_combine5_cast->get_left()->is_match(graph::pow(expr_b, 2.0)) && + "Expected b^2."); + assert(pow_combine5_cast->get_right()->is_match(graph::pow(expr_a, 4.0) * + graph::pow(expr_c, 2.0)) && + "Expected a^4*c^2."); auto pow_combine6 = graph::pow(expr_b/(graph::sqrt(expr_a*expr_a)*expr_a*expr_c), 2.0); auto pow_combine6_cast = graph::divide_cast(pow_combine6); assert(pow_combine6_cast.get() && "Expected a divide node."); - assert(pow_combine6_cast->get_left()->is_match(graph::pow(expr_b/expr_c, 2.0)) && - "Expected (b/c)^2."); - assert(pow_combine6_cast->get_right()->is_match(graph::pow(expr_a, 4.0)) && - "Expected (b/c)^2."); + assert(pow_combine6_cast->get_left()->is_match(graph::pow(expr_b, 2.0)) && + "Expected b^2."); + assert(pow_combine6_cast->get_right()->is_match(graph::pow(expr_a, 4.0) * + graph::pow(expr_c, 2.0)) && + "Expected a^4*c^2."); // (Sqrt(a)*b*c)^d -> a^(d/2)*(b*c)^d - auto sqrtpow = graph::pow(var_c*var_d*graph::sqrt(var_a), var_b); - assert(sqrtpow.get()->is_match(graph::pow(var_a, var_b/2.0) * - graph::pow(var_c, var_b) * - graph::pow(var_d, var_b)) && - "Expected a^(d/2)*b^2*c^d."); + auto sqrtpow = graph::pow(var_c*var_d*graph::sqrt(var_a), 10.0); + assert(sqrtpow.get()->is_match(graph::pow(var_a, 5.0) * + graph::pow(var_c*var_d, 10.0)) && + "Expected a^(d/2)*(c*b)^d."); auto factorconst = graph::pow(-0.5*var_a/var_b, 2.0); auto factorconst_cast = graph::multiply_cast(factorconst); @@ -417,6 +414,88 @@ void test_pow() { "Expected 0.25 on the left."); assert(factorconst_cast->get_right()->is_match(graph::pow(var_a/var_b, 2.0)) && "Expected (a/b)^2 on the right."); + +// pow(a/((sqrt(b)*c)*d), 2) -> pow(a,2)/(b*pow(c,2)*pow(d,2)) + auto divid_sqrt = graph::pow(var_a/((graph::sqrt(var_b)*var_c)*var_d), 2.0); + auto divid_sqrt_cast = graph::divide_cast(divid_sqrt); + assert(divid_sqrt_cast.get() && "Expected a divide node."); + assert(divid_sqrt_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2."); + assert(divid_sqrt_cast->get_right()->is_match(graph::pow(var_c, 2.0) * + graph::pow(var_d, 2.0) * + var_b) && + "Expected c^2*d^2*b."); +// pow(a/((c*sqrt(b))*d), 2) -> pow(a,2)/(b*pow(c,2)*pow(d,2)) + auto divid_sqrt2 = graph::pow(var_a/((var_c*graph::sqrt(var_b))*var_d), 2.0); + auto divid_sqrt2_cast = graph::divide_cast(divid_sqrt2); + assert(divid_sqrt2_cast.get() && "Expected a divide node."); + assert(divid_sqrt2_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2."); + assert(divid_sqrt2_cast->get_right()->is_match(graph::pow(var_c, 2.0) * + graph::pow(var_d, 2.0) * + var_b) && + "Expected c^2*d^2*b."); +// pow(((sqrt(b)*c)*d)/a, 2) -> (b*pow(c,2)*pow(d,2))/pow(a,2) + auto divid_sqrt3 = graph::pow(((graph::sqrt(var_b)*var_c)*var_d)/var_a, 2.0); + auto divid_sqrt3_cast = graph::divide_cast(divid_sqrt3); + assert(divid_sqrt3_cast.get() && "Expected a divide node."); + assert(divid_sqrt3_cast->get_right()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2."); + assert(divid_sqrt3_cast->get_left()->is_match(graph::pow(var_c, 2.0) * + graph::pow(var_d, 2.0) * + var_b) && + "Expected c^2*d^2*b."); +// pow(((sqrt(b)*c)*d)/a, 2) -> (b*pow(c,2)*pow(d,2))/pow(a,2) + auto divid_sqrt4 = graph::pow(((var_c*graph::sqrt(var_b))*var_d)/var_a, 2.0); + auto divid_sqrt4_cast = graph::divide_cast(divid_sqrt4); + assert(divid_sqrt4_cast.get() && "Expected a divide node."); + assert(divid_sqrt4_cast->get_right()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2."); + assert(divid_sqrt4_cast->get_left()->is_match(graph::pow(var_c, 2.0) * + graph::pow(var_d, 2.0) * + var_b) && + "Expected c^2*d^2*b."); + +// pow(a/(c*(sqrt(b)*d)), 2) -> pow(a,2)/(b*pow(c,2)*pow(d,2)) + auto divid_sqrt5 = graph::pow(var_a/(expr_c*(graph::sqrt(expr_b)*expr_a)), 2.0); + auto divid_sqrt5_cast = graph::divide_cast(divid_sqrt5); + assert(divid_sqrt5_cast.get() && "Expected a divide node."); + assert(divid_sqrt5_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2."); + assert(divid_sqrt5_cast->get_right()->is_match(expr_b* + graph::pow(expr_a, 2.0) * + graph::pow(expr_c, 2.0)) && + "Expected b*c^2*d^2."); +// pow(a/(c(d**sqrt(b))), 2) -> pow(a,2)/(b*pow(c,2)*pow(d,2)) + auto divid_sqrt6 = graph::pow(var_a/(expr_c*(expr_a*graph::sqrt(expr_b))), 2.0); + auto divid_sqrt6_cast = graph::divide_cast(divid_sqrt6); + assert(divid_sqrt6_cast.get() && "Expected a divide node."); + assert(divid_sqrt6_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2."); + assert(divid_sqrt6_cast->get_right()->is_match(expr_b* + graph::pow(expr_a, 2.0) * + graph::pow(expr_c, 2.0)) && + "Expected b*c^2*d^2."); +// pow(((c*sqrt(b))*d), 2)/a -> pow(a,2)/(b*pow(c,2)*pow(d,2)) + auto divid_sqrt7 = graph::pow((expr_c*(graph::sqrt(expr_b)*expr_a)/var_a), 2.0); + auto divid_sqrt7_cast = graph::divide_cast(divid_sqrt7); + assert(divid_sqrt7_cast.get() && "Expected a divide node."); + assert(divid_sqrt7_cast->get_right()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2."); + assert(divid_sqrt7_cast->get_left()->is_match(expr_b* + graph::pow(expr_a, 2.0) * + graph::pow(expr_c, 2.0)) && + "Expected b*c^2*d^2."); +// pow(((c*sqrt(b))*d), 2)/a -> pow(a,2)/(b*pow(c,2)*pow(d,2)) + auto divid_sqrt8 = graph::pow((expr_c*(expr_a*graph::sqrt(expr_b))/var_a), 2.0); + auto divid_sqrt8_cast = graph::divide_cast(divid_sqrt8); + assert(divid_sqrt8_cast.get() && "Expected a divide node."); + assert(divid_sqrt8_cast->get_right()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2."); + assert(divid_sqrt8_cast->get_left()->is_match(expr_b* + graph::pow(expr_a, 2.0) * + graph::pow(expr_c, 2.0)) && + "Expected b*c^2*d^2."); } //------------------------------------------------------------------------------ -- GitLab From 8574c2eead122a382df73a26577bee3efbccb41d Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 31 Jan 2025 16:55:02 -0500 Subject: [PATCH 2/4] Save work in progress. --- graph_framework/arithmetic.hpp | 62 ++++++++++++++++++++++++++++----- graph_framework/node.hpp | 13 +++++++ graph_tests/arithmetic_test.cpp | 59 ++++++++++++++++++++++++++++--- graph_tests/efit_test.cpp | 2 ++ 4 files changed, 123 insertions(+), 13 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index bfd42cb..7c1002e 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -4341,8 +4341,10 @@ namespace graph { } } else if (this->middle->is_all_variables()) { auto rdm = this->right/this->middle; - if (rdm->get_complexity() < this->middle->get_complexity() + - this->right->get_complexity()) { + auto rdmc = constant_cast(rdm->get_power_exponent()); + if ((rdm->get_complexity() < this->middle->get_complexity() + + this->right->get_complexity()) && + !(rdmc.get() && rdmc->evaluate().is_negative())) { return (this->left + rdm)*this->middle; } } @@ -4365,6 +4367,19 @@ namespace graph { return this->left/pow(mp->get_left(), -mp->get_right()) + this->right; } + +// fma(2,a^2,a) -> a*fma(2,a,1) +// Note this case is handled eailer. fma(2,a,a^2) -> a*fma(2,1,a) + if (is_variable_combineable(this->middle, + this->right)) { + auto temp = this->right/this->middle; + auto temp_exponent = constant_cast(temp->get_power_exponent()); + if (temp_exponent.get() && temp_exponent->evaluate().is_negative()) { + return this->right*fma(this->left, + this->middle/this->right, + 1.0); + } + } } // a^b*c^b + d -> (a*c)^b + d @@ -4382,12 +4397,43 @@ namespace graph { if (mplm.get()) { if (is_variable_combineable(mplm->get_left(), rm->get_left())) { - return pow(mplm->get_left(), - mp->get_right()) * - fma(this->left, - pow(mplm->get_right(), - mp->get_right()), - this->right/mplm->get_left()); + auto temp = pow(mplm->get_left(), + mp->get_right()); + return temp*fma(this->left, + this->middle/temp, + this->right/temp); + } else if (is_variable_combineable(mplm->get_right(), + rm->get_left())) { + auto temp = pow(mplm->get_right(), + mp->get_right()); + return temp*fma(this->left, + this->middle/temp, + this->right/temp); + } + } + } +// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) + if (rfma.get() && mp.get()) { + auto mplm = multiply_cast(mp->get_left()); + if (mplm.get()) { + if (is_variable_combineable(mplm->get_left(), + rfma->get_left())) { + auto temp = pow(mplm->get_left(), + mp->get_right()); + return fma(temp, + fma(this->left, + this->middle/temp, + rfma->get_middle()), + rfma->get_right()); + } else if (is_variable_combineable(mplm->get_right(), + rfma->get_left())) { + auto temp = pow(mplm->get_right(), + mp->get_right()); + return fma(temp, + fma(this->left, + this->middle/temp, + rfma->get_middle()), + rfma->get_right()); } } } diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 948f529..87b9d85 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -660,6 +660,19 @@ namespace graph { return constant (static_cast (1.0)); } +//------------------------------------------------------------------------------ +/// @brief Create a one constant. +/// +/// @tparam T Base type of the calculation. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @returns A one constant. +//------------------------------------------------------------------------------ + template + constexpr shared_leaf none() { + return constant (static_cast (-1.0)); + } + /// Convinece type for imaginary constant. template constexpr T i = T(0.0, 1.0); diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index d161514..0e74e2d 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -2773,7 +2773,7 @@ template void test_fma() { // Test reduction. auto var_a = graph::variable (1, "a"); auto var_b = graph::variable (1, "b"); - auto var_c = graph::variable (1, ""); + auto var_c = graph::variable (1, "c"); // fma(1,a,b) = a + b auto one_times_vara_plus_varb = graph::fma(one, var_a, var_b); @@ -3573,13 +3573,62 @@ template void test_fma() { var_b) + var_d) && "Expected a power node."); -// fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) +// fma(2,(ab)^2,a^2b) -> a^2*b*fma(2, b, 1) auto commom_power = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::pow(var_a, 2.0)*var_b); - commom_power->to_latex(); - std::cout << std::endl << std::endl; auto commom_power_cast = graph::multiply_cast(commom_power); assert(commom_power_cast.get() && "Expected a multiply node."); -// fma(2,(a*b)^2,fma()) + assert(commom_power_cast->get_right()->is_match(var_b) && + "Expced b"); + assert(commom_power_cast->get_left()->is_match(graph::pow(var_a, 2.0) * + graph::fma(2.0, var_b, 1.0)) && + "Expced a^2*fma(2, b, 1)"); +// fma(2,(ba)^2,a^2b) -> a^2*fma(2, b^2, b) + auto commom_power2 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::pow(var_a, 2.0)*var_b); + auto commom_power2_cast = graph::multiply_cast(commom_power2); + assert(commom_power2_cast.get() && "Expected a multiply node."); + assert(commom_power2_cast->get_right()->is_match(var_b) && + "Expced b"); + assert(commom_power2_cast->get_left()->is_match(graph::pow(var_a, 2.0) * + graph::fma(2.0, var_b, 1.0)) && + "Expced a^2*fma(2, b, 1)"); +// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) + auto commom_power3 = graph::fma(2.0, + graph::pow(var_a*var_b, 2.0), + graph::fma(graph::pow(var_a, 2.0), + var_b, + var_c)); + auto commom_power3_cast = graph::fma_cast(commom_power3); + assert(commom_power3_cast.get() && "Expected a fma node."); + assert(commom_power3_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2"); +// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) + auto commom_power4 = graph::fma(2.0, + graph::pow(var_b*var_a, 2.0), + graph::fma(graph::pow(var_a, 2.0), + var_b, + var_c)); + auto commom_power4_cast = graph::fma_cast(commom_power4); + assert(commom_power4_cast.get() && "Expected a fma node."); + assert(commom_power4_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && + "Expected a^2"); + +// fma(2,a^2,a) -> a*fma(2,a,1) + auto common_power5 = graph::fma(2.0,var_a*var_a,var_a); + auto commom_power5_cast = graph::multiply_cast(common_power5); + assert(commom_power5_cast.get() && "Expected a multiply node."); + assert(commom_power5_cast->get_left()->is_match(graph::fma(2.0,var_a,1.0)) && + "Expected fma(2,a,1)."); + assert(commom_power5_cast->get_right()->is_match(var_a) && + "Expected a."); +// fma(2,a,a^2) -> a*(2 + a) + auto temp = var_a*var_a; + auto common_power6 = graph::fma(2.0,var_a,temp); + auto commom_power6_cast = graph::multiply_cast(common_power6); + assert(commom_power6_cast.get() && "Expected a multiply node."); + assert(commom_power6_cast->get_left()->is_match(2.0 + var_a) && + "Expected (2 + a)."); + assert(commom_power6_cast->get_right()->is_match(var_a) && + "Expected a."); } //------------------------------------------------------------------------------ diff --git a/graph_tests/efit_test.cpp b/graph_tests/efit_test.cpp index c8d749c..6b84399 100644 --- a/graph_tests/efit_test.cpp +++ b/graph_tests/efit_test.cpp @@ -139,6 +139,8 @@ void run_test() { auto bvec = eq->get_magnetic_field(x, y, z); auto ne = eq->get_electron_density(x, y, z); + ne->to_latex(); + std::cout << std::endl << std::endl; auto te = eq->get_electron_temperature(x, y, z); workflow::manager work(0); -- GitLab From 3972c494f5d8c241d2cf426821558cbbbb2f6318 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Mon, 3 Feb 2025 18:01:43 -0500 Subject: [PATCH 3/4] Add reductions of sqrts and powers. Refactor equilibrium to splimify expressions. --- graph_framework/arithmetic.hpp | 52 +++++++++++++++++++ graph_framework/equilibrium.hpp | 24 ++++----- graph_korc/xkorc.cpp | 5 +- graph_tests/arithmetic_test.cpp | 92 ++++++++++++++++++++++++++++++--- graph_tests/efit_test.cpp | 10 ++-- 5 files changed, 155 insertions(+), 28 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 7c1002e..efcbb9a 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -3753,6 +3753,30 @@ namespace graph { } } +// fma(a,b*c,b*d) -> b*fma(a,c,d) +// fma(a,c*b,b*d) -> b*fma(a,c,d) +// fma(a,b*c,d*b) -> b*fma(a,c,d) +// fma(a,c*b,d*b) -> b*fma(a,c,d) + if (mm.get()) { + if (mm->get_left()->is_match(rm->get_left())) { + return mm->get_left()*fma(this->left, + mm->get_right(), + rm->get_right()); + } else if (mm->get_left()->is_match(rm->get_right())) { + return mm->get_left()*fma(this->left, + mm->get_right(), + rm->get_left()); + } else if (mm->get_right()->is_match(rm->get_left())) { + return mm->get_right()*fma(this->left, + mm->get_left(), + rm->get_right()); + } else if (mm->get_right()->is_match(rm->get_right())) { + return mm->get_right()*fma(this->left, + mm->get_left(), + rm->get_left()); + } + } + // Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c) // Convert fma(a,b*c,d*e) -> fma(d,e,a*b*c) if ((lm.get() || mm.get()) && @@ -3850,6 +3874,19 @@ namespace graph { } } +// fma(a,b*c,b) -> b*fma(a,c,1) + if (mm.get()) { + if (mm->get_left()->is_match(this->right)) { + return mm->get_left()*fma(this->left, + mm->get_right(), + 1.0); + } else if (mm->get_right()->is_match(this->right)) { + return mm->get_right()*fma(this->left, + mm->get_left(), + 1.0); + } + } + // fma(c1,a,c2/b) -> c1*(a + c3/b) // fma(a,c1,c2/b) -> c1*(a + c3/b) auto rd = divide_cast(this->right); @@ -4435,6 +4472,21 @@ namespace graph { rfma->get_middle()), rfma->get_right()); } + +// fma(2,(a*b)^2,fma(3,a^2*b,c)) -> a^2*fma(2,b^2,fma(3,b,c)) + auto rfmamm = multiply_cast(rfma->get_middle()); + if (rfmamm.get()) { + if (is_variable_combineable(mplm->get_left(), + rfmamm->get_left())) { + auto temp = pow(mplm->get_left(), + mp->get_right()); + return temp*fma(this->left, + this->middle/temp, + fma(rfma->get_left(), + rfma->get_middle()/temp, + rfma->get_right())); + } + } } } diff --git a/graph_framework/equilibrium.hpp b/graph_framework/equilibrium.hpp index e97c5c5..779edda 100644 --- a/graph_framework/equilibrium.hpp +++ b/graph_framework/equilibrium.hpp @@ -1073,18 +1073,18 @@ namespace equilibrium { + c01_temp*z_norm + c02_temp*(z_norm*z_norm) + c03_temp*(z_norm*z_norm*z_norm) - + c10_temp*r_norm - + c11_temp*r_norm*z_norm - + c12_temp*r_norm*(z_norm*z_norm) - + c13_temp*r_norm*(z_norm*z_norm*z_norm) - + c20_temp*(r_norm*r_norm) - + c21_temp*(r_norm*r_norm)*z_norm - + c22_temp*(r_norm*r_norm)*(z_norm*z_norm) - + c23_temp*(r_norm*r_norm)*(z_norm*z_norm*z_norm) - + c30_temp*(r_norm*r_norm*r_norm) - + c31_temp*(r_norm*r_norm*r_norm)*z_norm - + c32_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm) - + c33_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm*z_norm); + + r_norm*(c10_temp + + c11_temp*z_norm + + c12_temp*(z_norm*z_norm) + + c13_temp*(z_norm*z_norm*z_norm)) + + (r_norm*r_norm)*(c20_temp + + c21_temp*z_norm + + c22_temp*(z_norm*z_norm) + + c23_temp*(z_norm*z_norm*z_norm)) + + (r_norm*r_norm*r_norm)*(c30_temp + + c31_temp*z_norm + + c32_temp*(z_norm*z_norm) + + c33_temp*(z_norm*z_norm*z_norm)); } //------------------------------------------------------------------------------ diff --git a/graph_korc/xkorc.cpp b/graph_korc/xkorc.cpp index 9c366e3..3c721ab 100644 --- a/graph_korc/xkorc.cpp +++ b/graph_korc/xkorc.cpp @@ -143,11 +143,11 @@ void run_korc() { const timeing::measure_diagnostic t_run("Run Time"); work.pre_run(); for (size_t i = 0; i < 1000000; i++) { - sync.join(); +/* sync.join(); work.wait(); sync = std::thread([&file, &dataset] () -> void { dataset.write(file); - }); + });*/ work.run(); } @@ -181,6 +181,7 @@ int main(int argc, const char * argv[]) { (void)argv; run_korc (); +// run_korc (); END_GPU } diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 0e74e2d..2565987 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -2822,7 +2822,7 @@ template void test_fma() { "Expected common var_b"); // fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) - auto var_d = graph::variable (1, ""); + auto var_d = graph::variable (1, "d"); auto match1 = graph::fma(var_b, var_a + var_c, var_d); auto nested_fma1 = graph::fma(var_a, var_b, graph::fma(var_c, var_b, var_d)); @@ -3578,11 +3578,11 @@ template void test_fma() { auto commom_power_cast = graph::multiply_cast(commom_power); assert(commom_power_cast.get() && "Expected a multiply node."); assert(commom_power_cast->get_right()->is_match(var_b) && - "Expced b"); + "Expeced b"); assert(commom_power_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && - "Expced a^2*fma(2, b, 1)"); -// fma(2,(ba)^2,a^2b) -> a^2*fma(2, b^2, b) + "Expeced a^2*fma(2, b, 1)"); +// fma(2,(ba)^2,a^2b) -> a^2*b*fma(2, b, 1) auto commom_power2 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::pow(var_a, 2.0)*var_b); auto commom_power2_cast = graph::multiply_cast(commom_power2); assert(commom_power2_cast.get() && "Expected a multiply node."); @@ -3591,7 +3591,7 @@ template void test_fma() { assert(commom_power2_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); -// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) +// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2*b,fma(2,b,1),c) auto commom_power3 = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::fma(graph::pow(var_a, 2.0), @@ -3601,7 +3601,12 @@ template void test_fma() { assert(commom_power3_cast.get() && "Expected a fma node."); assert(commom_power3_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); -// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) + assert(commom_power3_cast->get_middle()->is_match(var_b*graph::fma(2.0, + var_b, + 1.0)) && + "Expected b*fma(2,b,1)"); + assert(commom_power3_cast->get_right()->is_match(var_c) && "Expected c"); +// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2*b,fma(2,b,1),c) auto commom_power4 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::fma(graph::pow(var_a, 2.0), @@ -3611,6 +3616,11 @@ template void test_fma() { assert(commom_power4_cast.get() && "Expected a fma node."); assert(commom_power4_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); + assert(commom_power4_cast->get_middle()->is_match(var_b*graph::fma(2.0, + var_b, + 1.0)) && + "Expected b*fma(2,b,1)"); + assert(commom_power4_cast->get_right()->is_match(var_c) && "Expected c"); // fma(2,a^2,a) -> a*fma(2,a,1) auto common_power5 = graph::fma(2.0,var_a*var_a,var_a); @@ -3621,14 +3631,80 @@ template void test_fma() { assert(commom_power5_cast->get_right()->is_match(var_a) && "Expected a."); // fma(2,a,a^2) -> a*(2 + a) - auto temp = var_a*var_a; - auto common_power6 = graph::fma(2.0,var_a,temp); + auto common_power6 = graph::fma(2.0,var_a,var_a*var_a); auto commom_power6_cast = graph::multiply_cast(common_power6); assert(commom_power6_cast.get() && "Expected a multiply node."); assert(commom_power6_cast->get_left()->is_match(2.0 + var_a) && "Expected (2 + a)."); assert(commom_power6_cast->get_right()->is_match(var_a) && "Expected a."); + +// fma(2,(a*b)^2,fma(3,a^2*b,c)) -> fma(a^2*b,fma(2,b,3),c) + auto common_power7 = graph::fma(2.0, + graph::pow(var_a*var_b, + 2.0), + graph::fma(3.0, + var_a*var_a*var_b, + var_c)); + auto common_power7_cast = graph::multiply_cast(common_power7); + assert(common_power7_cast.get() && "Expected a multiply node."); + assert(common_power7_cast->get_left()->is_match(graph::fma(var_b, + graph::fma(2.0, + var_b, + 3.0), + var_c)) && + "Expected fma(b,fma(2,b,3),c)"); + assert(common_power7_cast->get_right()->is_match(var_a*var_a) && + "Expected a^2"); + +// fma(a,b*c,b) -> b*fma(a,c,1) + auto factorize = graph::fma(var_a,var_b*var_c,var_b); + auto factorize_cast = multiply_cast(factorize); + assert(factorize_cast.get() && "Expected a multiply node."); + assert(factorize_cast->get_right()->is_match(var_b) && + "Expected b."); + assert(factorize_cast->get_left()->is_match(graph::fma(var_a,var_c,1.0)) && + "Expected a*c + 1."); +// fma(a,c*b,b) -> b*fma(a,c,1) + auto factorize2 = graph::fma(var_a,var_c*var_b,var_b); + auto factorize2_cast = multiply_cast(factorize2); + assert(factorize2_cast.get() && "Expected a multiply node."); + assert(factorize2_cast->get_right()->is_match(var_b) && + "Expected b."); + assert(factorize2_cast->get_left()->is_match(graph::fma(var_a,var_c,1.0)) && + "Expected a*c + 1."); +// fma(a,b*c,b*d) -> b*fma(a,c,d) + auto factorize3 = graph::fma(var_a,var_b*var_c,var_b*var_d); + auto factorize3_cast = multiply_cast(factorize3); + assert(factorize3_cast.get() && "Expected a multiply node."); + assert(factorize3_cast->get_right()->is_match(var_b) && + "Expected b."); + assert(factorize3_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && + "Expected a*c + d."); +// fma(a,c*b,b*d) -> b*fma(a,c,d) + auto factorize4 = graph::fma(var_a,var_c*var_b,var_b*var_d); + auto factorize4_cast = multiply_cast(factorize4); + assert(factorize4_cast.get() && "Expected a multiply node."); + assert(factorize4_cast->get_right()->is_match(var_b) && + "Expected b."); + assert(factorize4_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && + "Expected a*c + d."); +// fma(a,b*c,d*b) -> b*fma(a,c,d) + auto factorize5 = graph::fma(var_a,var_b*var_c,var_d*var_b); + auto factorize5_cast = multiply_cast(factorize5); + assert(factorize5_cast.get() && "Expected a multiply node."); + assert(factorize5_cast->get_right()->is_match(var_b) && + "Expected b."); + assert(factorize5_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && + "Expected a*c + d."); +// fma(a,c*b,d*b) -> b*fma(a,c,d) + auto factorize6 = graph::fma(var_a,var_c*var_b,var_d*var_b); + auto factorize6_cast = multiply_cast(factorize6); + assert(factorize6_cast.get() && "Expected a multiply node."); + assert(factorize6_cast->get_right()->is_match(var_b) && + "Expected b."); + assert(factorize6_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && + "Expected a*c + d."); } //------------------------------------------------------------------------------ diff --git a/graph_tests/efit_test.cpp b/graph_tests/efit_test.cpp index 6b84399..4b16b5e 100644 --- a/graph_tests/efit_test.cpp +++ b/graph_tests/efit_test.cpp @@ -139,8 +139,6 @@ void run_test() { auto bvec = eq->get_magnetic_field(x, y, z); auto ne = eq->get_electron_density(x, y, z); - ne->to_latex(); - std::cout << std::endl << std::endl; auto te = eq->get_electron_temperature(x, y, z); workflow::manager work(0); @@ -155,15 +153,15 @@ void run_test() { work.run(); for (size_t i = 0, ie = gold.r_grid.size()*gold.z_grid.size(); i < ie; i++) { - check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 4.0E-11, + check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 9.0E-12, "Expected a match in bx."); check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20, "Expected a match in by."); - check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 3.0E-12, + check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 4.0E-12, "Expected a match in bz."); - check_error(work.check_value(i, ne), gold.ne_grid[i], 5.0E-13, + check_error(work.check_value(i, ne), gold.ne_grid[i], 8.0E-13, "Expected a match in ne."); - check_error(work.check_value(i, te), gold.te_grid[i], 5.0E-13, + check_error(work.check_value(i, te), gold.te_grid[i], 8.0E-13, "Expected a match in te."); } } -- GitLab From 196c46cc664f787197d006a7a13a421a19b3d08e Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 4 Feb 2025 11:04:05 -0500 Subject: [PATCH 4/4] Refactor splines to avoid powers. --- graph_framework/equilibrium.hpp | 72 +++++++++++---------------------- graph_tests/efit_test.cpp | 2 +- 2 files changed, 25 insertions(+), 49 deletions(-) diff --git a/graph_framework/equilibrium.hpp b/graph_framework/equilibrium.hpp index 779edda..fe51749 100644 --- a/graph_framework/equilibrium.hpp +++ b/graph_framework/equilibrium.hpp @@ -1069,22 +1069,12 @@ namespace equilibrium { auto c32_temp = graph::piecewise_2D(c32, num_cols, r_norm, z_norm); auto c33_temp = graph::piecewise_2D(c33, num_cols, r_norm, z_norm); - return c00_temp - + c01_temp*z_norm - + c02_temp*(z_norm*z_norm) - + c03_temp*(z_norm*z_norm*z_norm) - + r_norm*(c10_temp + - c11_temp*z_norm + - c12_temp*(z_norm*z_norm) + - c13_temp*(z_norm*z_norm*z_norm)) - + (r_norm*r_norm)*(c20_temp + - c21_temp*z_norm + - c22_temp*(z_norm*z_norm) + - c23_temp*(z_norm*z_norm*z_norm)) - + (r_norm*r_norm*r_norm)*(c30_temp + - c31_temp*z_norm + - c32_temp*(z_norm*z_norm) + - c33_temp*(z_norm*z_norm*z_norm)); + auto c0 = ((c03_temp*z_norm + c02_temp)*z_norm + c01_temp)*z_norm + c00_temp; + auto c1 = ((c13_temp*z_norm + c12_temp)*z_norm + c11_temp)*z_norm + c10_temp; + auto c2 = ((c23_temp*z_norm + c22_temp)*z_norm + c21_temp)*z_norm + c20_temp; + auto c3 = ((c33_temp*z_norm + c32_temp)*z_norm + c31_temp)*z_norm + c30_temp; + + return ((c3*r_norm + c2)*r_norm + c1)*r_norm + c0; } //------------------------------------------------------------------------------ @@ -1118,30 +1108,27 @@ namespace equilibrium { auto n2_temp = graph::piecewise_1D(ne_c2, psi_norm_cache); auto n3_temp = graph::piecewise_1D(ne_c3, psi_norm_cache); - ne_cache = ne_scale*(n0_temp + - n1_temp*psi_norm_cache + - n2_temp*psi_norm_cache*psi_norm_cache + - n3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache); + ne_cache = ne_scale + * (((n3_temp*psi_norm_cache + n2_temp) * + psi_norm_cache + n1_temp)*psi_norm_cache + n0_temp); auto t0_temp = graph::piecewise_1D(te_c0, psi_norm_cache); auto t1_temp = graph::piecewise_1D(te_c1, psi_norm_cache); auto t2_temp = graph::piecewise_1D(te_c2, psi_norm_cache); auto t3_temp = graph::piecewise_1D(te_c3, psi_norm_cache); - te_cache = te_scale*(t0_temp + - t1_temp*psi_norm_cache + - t2_temp*psi_norm_cache*psi_norm_cache + - t3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache); + te_cache = te_scale + * (((t3_temp*psi_norm_cache + t2_temp) * + psi_norm_cache + t1_temp)*psi_norm_cache + t0_temp); auto p0_temp = graph::piecewise_1D(pres_c0, psi_norm_cache); auto p1_temp = graph::piecewise_1D(pres_c1, psi_norm_cache); auto p2_temp = graph::piecewise_1D(pres_c2, psi_norm_cache); auto p3_temp = graph::piecewise_1D(pres_c3, psi_norm_cache); - auto pressure = pres_scale*(p0_temp + - p1_temp*psi_norm_cache + - p2_temp*psi_norm_cache*psi_norm_cache + - p3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache); + auto pressure = pres_scale + * (((p3_temp*psi_norm_cache + p2_temp) * + psi_norm_cache + p1_temp)*psi_norm_cache + p0_temp); auto q = graph::constant (static_cast (1.60218E-19)); @@ -1157,10 +1144,8 @@ namespace equilibrium { auto b2_temp = graph::piecewise_1D(fpol_c2, r_norm); auto b3_temp = graph::piecewise_1D(fpol_c3, r_norm); - auto bp = (b0_temp + - b1_temp*r_norm + - b2_temp*r_norm*r_norm + - b3_temp*r_norm*r_norm*r_norm)/r; + auto bp = (((b3_temp*r_norm + b2_temp) * + r_norm + b1_temp)*r_norm + b0_temp)/r; auto bz = -psi->df(r)/r; @@ -1835,10 +1820,7 @@ namespace equilibrium { auto c2_temp = graph::piecewise_1D(chi_c2, s_norm); auto c3_temp = graph::piecewise_1D(chi_c3, s_norm); - return c0_temp + - c1_temp*s_norm + - c2_temp*s_norm*s_norm + - c3_temp*s_norm*s_norm*s_norm; + return ((c3_temp*s_norm + c2_temp)*s_norm + c1_temp)*s_norm + c0_temp; } //------------------------------------------------------------------------------ @@ -1895,18 +1877,12 @@ namespace equilibrium { auto lmns_c2_temp = graph::piecewise_1D(lmns_c2[i], s_norm_h); auto lmns_c3_temp = graph::piecewise_1D(lmns_c3[i], s_norm_h); - auto rmnc = rmnc_c0_temp - + rmnc_c1_temp*s_norm_f - + rmnc_c2_temp*s_norm_f*s_norm_f - + rmnc_c3_temp*s_norm_f*s_norm_f*s_norm_f; - auto zmns = zmns_c0_temp - + zmns_c1_temp*s_norm_f - + zmns_c2_temp*s_norm_f*s_norm_f - + zmns_c3_temp*s_norm_f*s_norm_f*s_norm_f; - auto lmns = lmns_c0_temp - + lmns_c1_temp*s_norm_h - + lmns_c2_temp*s_norm_h*s_norm_h - + lmns_c3_temp*s_norm_h*s_norm_h*s_norm_h; + auto rmnc = ((rmnc_c3_temp*s_norm_f + rmnc_c2_temp)*s_norm_f + + rmnc_c1_temp)*s_norm_f + rmnc_c0_temp; + auto zmns = ((zmns_c3_temp*s_norm_f + zmns_c2_temp)*s_norm_f + + zmns_c1_temp)*s_norm_f + zmns_c0_temp; + auto lmns = ((lmns_c3_temp*s_norm_h + lmns_c2_temp)*s_norm_h + + lmns_c1_temp)*s_norm_h + lmns_c0_temp; auto m = graph::constant (xm[i]); auto n = graph::constant (xn[i]); diff --git a/graph_tests/efit_test.cpp b/graph_tests/efit_test.cpp index 4b16b5e..db8956a 100644 --- a/graph_tests/efit_test.cpp +++ b/graph_tests/efit_test.cpp @@ -153,7 +153,7 @@ void run_test() { work.run(); for (size_t i = 0, ie = gold.r_grid.size()*gold.z_grid.size(); i < ie; i++) { - check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 9.0E-12, + check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 10.0E-11, "Expected a match in bx."); check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20, "Expected a match in by."); -- GitLab