From 22ab0b6f07e9c37d74b973e4180952469f41c631 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 11 Feb 2025 18:07:07 -0500 Subject: [PATCH] Combine and reduce more constants in nested fma nodes. This reduces complexity of splines. --- graph_framework.xcodeproj/project.pbxproj | 4 + graph_framework/arithmetic.hpp | 146 +++++++++++++++++++++- graph_framework/equilibrium.hpp | 2 +- graph_tests/arithmetic_test.cpp | 72 +++++++++++ graph_tests/efit_test.cpp | 6 +- graph_tests/piecewise_test.cpp | 28 +++-- 6 files changed, 239 insertions(+), 19 deletions(-) diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index 794b6ae..84ebb0b 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -1363,6 +1363,7 @@ "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", + "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", @@ -1468,6 +1469,7 @@ "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", + "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", @@ -1848,6 +1850,7 @@ "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", + "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", @@ -1947,6 +1950,7 @@ "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", + "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index efcbb9a..e6efea1 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -1903,6 +1903,33 @@ namespace graph { if (is_variable_promotable(rm->get_right(), this->left)) { return (this->left*rm->get_left())*rm->get_right(); } + + auto rmlfma = fma_cast(rm->get_left()); + if (rmlfma.get()) { + if (is_constant_combineable(this->left, + rmlfma->get_left()) && + is_constant_combineable(this->left, + rmlfma->get_right())) { + return fma(this->left*rmlfma->get_left(), + rmlfma->get_middle(), + this->left*rmlfma->get_right())*rm->get_right(); + } + + auto rmlfmalfma = fma_cast(rmlfma->get_left()); + if (rmlfmalfma.get()) { + if (is_constant_combineable(this->left, + rmlfmalfma->get_left()) && + is_constant_combineable(this->left, + rmlfmalfma->get_right()) && + is_constant_combineable(this->left, rmlfma->get_right())) { + return fma(fma(this->left*rmlfmalfma->get_left(), + rmlfmalfma->get_middle(), + this->left*rmlfmalfma->get_right()), + rmlfma->get_middle(), + this->left*rmlfma->get_right())*rm->get_right(); + } + } + } } // v1*(c*v2) -> c*(v1*v2) @@ -2299,6 +2326,79 @@ namespace graph { } } +// c3*fma(c1,a,c2) -> fma(c4,a,c5) + auto rfma = fma_cast(this->right); + if (rfma.get()) { + if (is_constant_combineable(this->left, rfma->get_left()) && + is_constant_combineable(this->left, rfma->get_right())) { + return fma(this->left*rfma->get_left(), + rfma->get_middle(), + this->left*rfma->get_right()); + } + + auto rfmalfma = fma_cast(rfma->get_left()); + if (rfmalfma.get()) { + if (is_constant_combineable(this->left, rfmalfma->get_left()) && + is_constant_combineable(this->left, rfmalfma->get_right()) && + is_constant_combineable(this->left, rfma->get_right())) { + return fma(fma(this->left*rfmalfma->get_left(), + rfmalfma->get_middle(), + this->left*rfmalfma->get_right()), + rfma->get_middle(), + this->left*rfma->get_right()); + } + + auto rfmalfmalfma = fma_cast(rfmalfma->get_left()); + if (rfmalfmalfma.get()) { + if (is_constant_combineable(this->left, rfmalfmalfma->get_left()) && + is_constant_combineable(this->left, rfmalfmalfma->get_right()) && + is_constant_combineable(this->left, rfmalfma->get_right()) && + is_constant_combineable(this->left, rfma->get_right())) { + return fma(fma(fma(this->left*rfmalfmalfma->get_left(), + rfmalfmalfma->get_middle(), + this->left*rfmalfmalfma->get_right()), + rfmalfma->get_middle(), + this->left*rfmalfma->get_right()), + rfma->get_middle(), + this->left*rfma->get_right()); + } + } + } + } + +// fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) + auto lfma = fma_cast(this->left); + auto ra = add_cast(this->right); + if (lfma.get() && ra.get()) { + if (ra->get_right()->is_match(lfma->get_middle()) && + is_constant_combineable(ra->get_left(), lfma->get_left()) && + is_constant_combineable(ra->get_left(), lfma->get_right())) { + return fma(fma(lfma->get_left(), + ra->get_right(), + ra->get_left()*lfma->get_left() + lfma->get_right()), + ra->get_right(), + lfma->get_right()*ra->get_left()); + } + +// fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) + auto lfmalfma = fma_cast(lfma->get_left()); + if (ra->get_right()->is_match(lfma->get_middle()) && + ra->get_right()->is_match(lfmalfma->get_middle()) && + is_constant_combineable(ra->get_left(), lfma->get_right()) && + is_constant_combineable(ra->get_left(), lfmalfma->get_right()) && + is_constant_combineable(ra->get_left(), lfmalfma->get_left())) { + return fma(fma(fma(lfmalfma->get_left(), + ra->get_right(), + ra->get_left()*lfmalfma->get_left() + + lfmalfma->get_right()), + ra->get_right(), + ra->get_left()*lfmalfma->get_right() + + lfma->get_right()), + ra->get_right(), + ra->get_left()*lfma->get_right()); + } + } + // Cases like // (c/exp(a))*(exp(b)/d) -> (c/d)*(exp(b)/exp(a)) // (c/exp(a))*(d/exp(b)) -> (c*e)/(exp(b)*exp(a)) @@ -3673,13 +3773,53 @@ namespace graph { } } -// fma(c1,c2 - a,c3) -> c4 - c5*a +// fma(c1,c2 - a,c3) -> fma(-c1,a,c1*c2 + c3) +// fma(c1,a - c2,c3) -> fma(c1,a,c3 - c1*c2) auto ms = subtract_cast(this->middle); if (ms.get()) { if (is_constant_combineable(this->left, ms->get_left()) && is_constant_combineable(this->left, this->right)) { - return fma(this->left, ms->get_left(), this->right) - - this->left*ms->get_right(); + return fma(-this->left, ms->get_right(), + this->left*ms->get_left() + this->right); + } else if (is_constant_combineable(this->left, ms->get_right()) && + is_constant_combineable(this->left, this->right)) { + return fma(this->left, ms->get_left(), + this->right - this->left*ms->get_right()); + } + + auto lfma = fma_cast(this->left); + if (lfma.get()) { + if (is_constant_combineable(ms->get_right(), lfma->get_left()) && + is_constant_combineable(ms->get_right(), lfma->get_right()) && + is_constant_combineable(this->right, lfma->get_right()) && + lfma->get_middle()->is_match(ms->get_left())) { + return fma(fma(lfma->get_left(), + ms->get_left(), + lfma->get_right() - lfma->get_left()*ms->get_right()), + ms->get_left(), + this->right - lfma->get_right()*ms->get_right()); + } + + auto lfmalfma = fma_cast(lfma->get_left()); + if (lfmalfma.get()) { + if (lfma->get_middle()->is_match(ms->get_left()) && + lfmalfma->get_middle()->is_match(ms->get_left()) && + is_constant_combineable(ms->get_right(), lfmalfma->get_left()) && + is_constant_combineable(ms->get_right(), lfmalfma->get_right()) && + is_constant_combineable(ms->get_right(), lfma->get_right()) && + is_constant_combineable(ms->get_right(), this->right)) { + return fma(fma(fma(lfmalfma->get_left(), + ms->get_left(), + lfmalfma->get_right() - + lfmalfma->get_left()*ms->get_right()), + ms->get_left(), + lfma->get_right() - + lfmalfma->get_right()*ms->get_right()), + ms->get_left(), + this->right - + lfma->get_right()*ms->get_right()); + } + } } } diff --git a/graph_framework/equilibrium.hpp b/graph_framework/equilibrium.hpp index fe51749..4fd4621 100644 --- a/graph_framework/equilibrium.hpp +++ b/graph_framework/equilibrium.hpp @@ -1143,7 +1143,7 @@ namespace equilibrium { auto b1_temp = graph::piecewise_1D(fpol_c1, r_norm); auto b2_temp = graph::piecewise_1D(fpol_c2, r_norm); auto b3_temp = graph::piecewise_1D(fpol_c3, r_norm); - + auto bp = (((b3_temp*r_norm + b2_temp) * r_norm + b1_temp)*r_norm + b0_temp)/r; diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 2565987..feb0ff2 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -1925,6 +1925,48 @@ template void test_multiply() { assert(gather_power5_cast->get_right()->is_match(graph::pow(var_a*var_b, var_c)) && "Expected (a*b)^c."); + +// c3*fma(c1,a,c2) -> fma(c4,a,c5) + auto constant_reduction = 0.25*fma(2.0, v1, 3.0); + assert(constant_reduction->is_match(fma(2.0*0.25, v1, 3.0*0.25)) && + "Expected (0.5*a + 0.75)"); +// c3*(fma(c1,a,c2)*b) -> fma(c4,a,c5)*b + auto constant_reduction2 = 0.25*(fma(2.0, v1, 3.0)*v2); + assert(constant_reduction2->is_match(fma(2.0*0.25, v1, 3.0*0.25)*v2) && + "Expected (0.5*a + 0.75)*b"); +// c1*(fma(c2,a,c3)*b + c4) -> fma(c5,a,c6)*b + c7 + auto constant_reduction3 = 0.25*(fma(fma(2.0, v1, 3.0),v2,2.0)); + assert(constant_reduction3->is_match(fma(fma(2.0*0.25, v1, 3.0*0.25),v2,0.5)) && + "Expected (0.5*a + 0.75)*b + 0.5"); +// c1*((fma(c2,a,c3)*b + c4)*c) -> (fma(c5,a,c6)*b + c7)*c + auto constant_reduction4 = 0.25*(fma(fma(2.0, v1, 3.0),v2,2.0)*v1); + assert(constant_reduction4->is_match(fma(fma(2.0*0.25, v1, 3.0*0.25),v2,0.5)*v1) && + "Expected ((0.5*a + 0.75)*b + 0.5)*c"); + +// fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) + auto expand = graph::fma(0.2, v1, 3.0)*(4.0 + v1); + assert(expand->is_match(graph::fma(graph::fma(0.2, v1, 3.8), v1, 12.0))); +// fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) + auto expand2 = graph::fma(fma(0.2,v1,2.3),v1,3.0)*(4.0 + v1); + assert(expand2->is_match(graph::fma(graph::fma(graph::fma(0.2, + v1, + 0.2*4.0 + 2.3), + v1, + 12.2), + v1, + 12.0)) && + "Exptected (((0.2*x + 3.1)*x + 12.2)*x + 12"); + +// c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9) + auto consume = 10.0*(graph::fma(graph::fma(graph::fma(0.2,v1,2.3),v1,3.0),v1,0.1)); + assert(consume->is_match(graph::fma(graph::fma(graph::fma(2.0, + v1, + 23.0), + v1, + 30.0), + v1, + 1.0)) && + "Expected fma(fma(fma(2,x,23),x,30,x,1))"); } //------------------------------------------------------------------------------ @@ -3705,6 +3747,36 @@ template void test_fma() { "Expected b."); assert(factorize6_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); + +// fma(c1,a - c2,c3) -> fma(c1,a,c4) + auto consume = graph::fma(2.0,var_a - 3.0,20.0); + assert(consume->is_match(graph::fma(2.0,var_a,14.0)) && + "Expected fma(2,x,14)"); +// fma(c1,c2 - a,c3) -> fma(-c1,a,c4) + auto consume2 = graph::fma(2.0,3.0 - var_a,20.0); + assert(consume2->is_match(graph::fma(-2.0,var_a,26.0)) && + "Expected fma(-2,x,26)"); + +// fma(fma(c1,a,c2),a - c3,c4) -> fma(fma(c1,x,c5),x,c6) + auto gather = graph::fma(graph::fma(2.0,var_a,20.0),var_a - 2.0,30.0); + assert(gather->is_match(graph::fma(graph::fma(2.0,var_a,16.0),var_a,-10.0)) && + "Expected fma(fma(2,x,16),x,-10)"); +// fma(fma(fma(c1,a,c2),a,c3),a - c4,c5) -> fma(fma(c1,x,c6),x,c6),x,c8) + auto gather2 = graph::fma(graph::fma(graph::fma(2.0, + var_a, + 20.0), + var_a, + 30.0), + var_a - 2.0, + 50.0); + assert(gather2->is_match(graph::fma(graph::fma(graph::fma(2.0, + var_a, + 16.0), + var_a, + -10.0), + var_a, + -10.0)) && + "Expected fma(fma(fma(2,x,16),x,-10),x,-10)"); } //------------------------------------------------------------------------------ diff --git a/graph_tests/efit_test.cpp b/graph_tests/efit_test.cpp index db8956a..c565e9b 100644 --- a/graph_tests/efit_test.cpp +++ b/graph_tests/efit_test.cpp @@ -157,11 +157,11 @@ void run_test() { "Expected a match in bx."); check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20, "Expected a match in by."); - check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 4.0E-12, + check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 5.0E-12, "Expected a match in bz."); - check_error(work.check_value(i, ne), gold.ne_grid[i], 8.0E-13, + check_error(work.check_value(i, ne), gold.ne_grid[i], 2.1E-12, "Expected a match in ne."); - check_error(work.check_value(i, te), gold.te_grid[i], 8.0E-13, + check_error(work.check_value(i, te), gold.te_grid[i], 2.1E-12, "Expected a match in te."); } } diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index e6927b4..b2e7b6c 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -241,12 +241,14 @@ template void piecewise_1D() { "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - a,p3); - auto fma_combine2_cast = graph::subtract_cast(fma_combine2); - assert(fma_combine2_cast.get() && "Expected an subtract node."); - assert(fma_combine2_cast->get_right()->is_match(p1*a) && - "Expected p1*a on the right."); - assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && - "Expected p1 + p3 on the left."); + auto fma_combine2_cast = graph::fma_cast(fma_combine2); + assert(fma_combine2_cast.get() && "Expected a fma node."); + assert(fma_combine2_cast->get_left()->is_match(-p1) && + "Expected -p1 on the left."); + assert(fma_combine2_cast->get_middle()->is_match(a) && + "Expected a in the middle."); + assert(fma_combine2_cast->get_right()->is_match(p1 + p3) && + "Expected p1 + p3 on the right."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + a) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); @@ -651,12 +653,14 @@ template void piecewise_2D() { "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - ax,p3); - auto fma_combine2_cast = graph::subtract_cast(fma_combine2); - assert(fma_combine2_cast.get() && "Expected an subtract node."); - assert(fma_combine2_cast->get_right()->is_match(p1*ax) && - "Expected p1*a on the right."); - assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && - "Expected p1 + p3 on the left."); + auto fma_combine2_cast = graph::fma_cast(fma_combine2); + assert(fma_combine2_cast.get() && "Expected a fma node."); + assert(fma_combine2_cast->get_left()->is_match(-p1) && + "Expected -p1 on the right."); + assert(fma_combine2_cast->get_middle()->is_match(ax) && + "Expected a in the middle."); + assert(fma_combine2_cast->get_right()->is_match(p1 + p3) && + "Expected p1 + p3 on the right."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + ax) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); -- GitLab