Loading graph_framework.xcodeproj/project.pbxproj +4 −0 Original line number Diff line number Diff line Loading @@ -1363,6 +1363,7 @@ "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", Loading Loading @@ -1468,6 +1469,7 @@ "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", Loading Loading @@ -1848,6 +1850,7 @@ "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", Loading Loading @@ -1947,6 +1950,7 @@ "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", Loading graph_framework/arithmetic.hpp +143 −3 Original line number Diff line number Diff line Loading @@ -1903,6 +1903,33 @@ namespace graph { if (is_variable_promotable(rm->get_right(), this->left)) { return (this->left*rm->get_left())*rm->get_right(); } auto rmlfma = fma_cast(rm->get_left()); if (rmlfma.get()) { if (is_constant_combineable(this->left, rmlfma->get_left()) && is_constant_combineable(this->left, rmlfma->get_right())) { return fma(this->left*rmlfma->get_left(), rmlfma->get_middle(), this->left*rmlfma->get_right())*rm->get_right(); } auto rmlfmalfma = fma_cast(rmlfma->get_left()); if (rmlfmalfma.get()) { if (is_constant_combineable(this->left, rmlfmalfma->get_left()) && is_constant_combineable(this->left, rmlfmalfma->get_right()) && is_constant_combineable(this->left, rmlfma->get_right())) { return fma(fma(this->left*rmlfmalfma->get_left(), rmlfmalfma->get_middle(), this->left*rmlfmalfma->get_right()), rmlfma->get_middle(), this->left*rmlfma->get_right())*rm->get_right(); } } } } // v1*(c*v2) -> c*(v1*v2) Loading Loading @@ -2299,6 +2326,79 @@ namespace graph { } } // c3*fma(c1,a,c2) -> fma(c4,a,c5) auto rfma = fma_cast(this->right); if (rfma.get()) { if (is_constant_combineable(this->left, rfma->get_left()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(this->left*rfma->get_left(), rfma->get_middle(), this->left*rfma->get_right()); } auto rfmalfma = fma_cast(rfma->get_left()); if (rfmalfma.get()) { if (is_constant_combineable(this->left, rfmalfma->get_left()) && is_constant_combineable(this->left, rfmalfma->get_right()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(fma(this->left*rfmalfma->get_left(), rfmalfma->get_middle(), this->left*rfmalfma->get_right()), rfma->get_middle(), this->left*rfma->get_right()); } auto rfmalfmalfma = fma_cast(rfmalfma->get_left()); if (rfmalfmalfma.get()) { if (is_constant_combineable(this->left, rfmalfmalfma->get_left()) && is_constant_combineable(this->left, rfmalfmalfma->get_right()) && is_constant_combineable(this->left, rfmalfma->get_right()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(fma(fma(this->left*rfmalfmalfma->get_left(), rfmalfmalfma->get_middle(), this->left*rfmalfmalfma->get_right()), rfmalfma->get_middle(), this->left*rfmalfma->get_right()), rfma->get_middle(), this->left*rfma->get_right()); } } } } // fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) auto lfma = fma_cast(this->left); auto ra = add_cast(this->right); if (lfma.get() && ra.get()) { if (ra->get_right()->is_match(lfma->get_middle()) && is_constant_combineable(ra->get_left(), lfma->get_left()) && is_constant_combineable(ra->get_left(), lfma->get_right())) { return fma(fma(lfma->get_left(), ra->get_right(), ra->get_left()*lfma->get_left() + lfma->get_right()), ra->get_right(), lfma->get_right()*ra->get_left()); } // fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) auto lfmalfma = fma_cast(lfma->get_left()); if (ra->get_right()->is_match(lfma->get_middle()) && ra->get_right()->is_match(lfmalfma->get_middle()) && is_constant_combineable(ra->get_left(), lfma->get_right()) && is_constant_combineable(ra->get_left(), lfmalfma->get_right()) && is_constant_combineable(ra->get_left(), lfmalfma->get_left())) { return fma(fma(fma(lfmalfma->get_left(), ra->get_right(), ra->get_left()*lfmalfma->get_left() + lfmalfma->get_right()), ra->get_right(), ra->get_left()*lfmalfma->get_right() + lfma->get_right()), ra->get_right(), ra->get_left()*lfma->get_right()); } } // Cases like // (c/exp(a))*(exp(b)/d) -> (c/d)*(exp(b)/exp(a)) // (c/exp(a))*(d/exp(b)) -> (c*e)/(exp(b)*exp(a)) Loading Loading @@ -3673,13 +3773,53 @@ namespace graph { } } // fma(c1,c2 - a,c3) -> c4 - c5*a // fma(c1,c2 - a,c3) -> fma(-c1,a,c1*c2 + c3) // fma(c1,a - c2,c3) -> fma(c1,a,c3 - c1*c2) auto ms = subtract_cast(this->middle); if (ms.get()) { if (is_constant_combineable(this->left, ms->get_left()) && is_constant_combineable(this->left, this->right)) { return fma(this->left, ms->get_left(), this->right) - this->left*ms->get_right(); return fma(-this->left, ms->get_right(), this->left*ms->get_left() + this->right); } else if (is_constant_combineable(this->left, ms->get_right()) && is_constant_combineable(this->left, this->right)) { return fma(this->left, ms->get_left(), this->right - this->left*ms->get_right()); } auto lfma = fma_cast(this->left); if (lfma.get()) { if (is_constant_combineable(ms->get_right(), lfma->get_left()) && is_constant_combineable(ms->get_right(), lfma->get_right()) && is_constant_combineable(this->right, lfma->get_right()) && lfma->get_middle()->is_match(ms->get_left())) { return fma(fma(lfma->get_left(), ms->get_left(), lfma->get_right() - lfma->get_left()*ms->get_right()), ms->get_left(), this->right - lfma->get_right()*ms->get_right()); } auto lfmalfma = fma_cast(lfma->get_left()); if (lfmalfma.get()) { if (lfma->get_middle()->is_match(ms->get_left()) && lfmalfma->get_middle()->is_match(ms->get_left()) && is_constant_combineable(ms->get_right(), lfmalfma->get_left()) && is_constant_combineable(ms->get_right(), lfmalfma->get_right()) && is_constant_combineable(ms->get_right(), lfma->get_right()) && is_constant_combineable(ms->get_right(), this->right)) { return fma(fma(fma(lfmalfma->get_left(), ms->get_left(), lfmalfma->get_right() - lfmalfma->get_left()*ms->get_right()), ms->get_left(), lfma->get_right() - lfmalfma->get_right()*ms->get_right()), ms->get_left(), this->right - lfma->get_right()*ms->get_right()); } } } } Loading graph_tests/arithmetic_test.cpp +72 −0 Original line number Diff line number Diff line Loading @@ -1925,6 +1925,48 @@ template<jit::float_scalar T> void test_multiply() { assert(gather_power5_cast->get_right()->is_match(graph::pow(var_a*var_b, var_c)) && "Expected (a*b)^c."); // c3*fma(c1,a,c2) -> fma(c4,a,c5) auto constant_reduction = 0.25*fma(2.0, v1, 3.0); assert(constant_reduction->is_match(fma(2.0*0.25, v1, 3.0*0.25)) && "Expected (0.5*a + 0.75)"); // c3*(fma(c1,a,c2)*b) -> fma(c4,a,c5)*b auto constant_reduction2 = 0.25*(fma(2.0, v1, 3.0)*v2); assert(constant_reduction2->is_match(fma(2.0*0.25, v1, 3.0*0.25)*v2) && "Expected (0.5*a + 0.75)*b"); // c1*(fma(c2,a,c3)*b + c4) -> fma(c5,a,c6)*b + c7 auto constant_reduction3 = 0.25*(fma(fma(2.0, v1, 3.0),v2,2.0)); assert(constant_reduction3->is_match(fma(fma(2.0*0.25, v1, 3.0*0.25),v2,0.5)) && "Expected (0.5*a + 0.75)*b + 0.5"); // c1*((fma(c2,a,c3)*b + c4)*c) -> (fma(c5,a,c6)*b + c7)*c auto constant_reduction4 = 0.25*(fma(fma(2.0, v1, 3.0),v2,2.0)*v1); assert(constant_reduction4->is_match(fma(fma(2.0*0.25, v1, 3.0*0.25),v2,0.5)*v1) && "Expected ((0.5*a + 0.75)*b + 0.5)*c"); // fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) auto expand = graph::fma(0.2, v1, 3.0)*(4.0 + v1); assert(expand->is_match(graph::fma(graph::fma(0.2, v1, 3.8), v1, 12.0))); // fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) auto expand2 = graph::fma(fma(0.2,v1,2.3),v1,3.0)*(4.0 + v1); assert(expand2->is_match(graph::fma(graph::fma(graph::fma(0.2, v1, 0.2*4.0 + 2.3), v1, 12.2), v1, 12.0)) && "Exptected (((0.2*x + 3.1)*x + 12.2)*x + 12"); // c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9) auto consume = 10.0*(graph::fma(graph::fma(graph::fma(0.2,v1,2.3),v1,3.0),v1,0.1)); assert(consume->is_match(graph::fma(graph::fma(graph::fma(2.0, v1, 23.0), v1, 30.0), v1, 1.0)) && "Expected fma(fma(fma(2,x,23),x,30,x,1))"); } //------------------------------------------------------------------------------ Loading Loading @@ -3705,6 +3747,36 @@ template<jit::float_scalar T> void test_fma() { "Expected b."); assert(factorize6_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); // fma(c1,a - c2,c3) -> fma(c1,a,c4) auto consume = graph::fma(2.0,var_a - 3.0,20.0); assert(consume->is_match(graph::fma(2.0,var_a,14.0)) && "Expected fma(2,x,14)"); // fma(c1,c2 - a,c3) -> fma(-c1,a,c4) auto consume2 = graph::fma(2.0,3.0 - var_a,20.0); assert(consume2->is_match(graph::fma(-2.0,var_a,26.0)) && "Expected fma(-2,x,26)"); // fma(fma(c1,a,c2),a - c3,c4) -> fma(fma(c1,x,c5),x,c6) auto gather = graph::fma(graph::fma(2.0,var_a,20.0),var_a - 2.0,30.0); assert(gather->is_match(graph::fma(graph::fma(2.0,var_a,16.0),var_a,-10.0)) && "Expected fma(fma(2,x,16),x,-10)"); // fma(fma(fma(c1,a,c2),a,c3),a - c4,c5) -> fma(fma(c1,x,c6),x,c6),x,c8) auto gather2 = graph::fma(graph::fma(graph::fma(2.0, var_a, 20.0), var_a, 30.0), var_a - 2.0, 50.0); assert(gather2->is_match(graph::fma(graph::fma(graph::fma(2.0, var_a, 16.0), var_a, -10.0), var_a, -10.0)) && "Expected fma(fma(fma(2,x,16),x,-10),x,-10)"); } //------------------------------------------------------------------------------ Loading graph_tests/efit_test.cpp +3 −3 Original line number Diff line number Diff line Loading @@ -157,11 +157,11 @@ void run_test() { "Expected a match in bx."); check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20, "Expected a match in by."); check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 4.0E-12, check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 5.0E-12, "Expected a match in bz."); check_error(work.check_value(i, ne), gold.ne_grid[i], 8.0E-13, check_error(work.check_value(i, ne), gold.ne_grid[i], 2.1E-12, "Expected a match in ne."); check_error(work.check_value(i, te), gold.te_grid[i], 8.0E-13, check_error(work.check_value(i, te), gold.te_grid[i], 2.1E-12, "Expected a match in te."); } } Loading graph_tests/piecewise_test.cpp +16 −12 Original line number Diff line number Diff line Loading @@ -241,12 +241,14 @@ template<jit::float_scalar T> void piecewise_1D() { "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - a,p3); auto fma_combine2_cast = graph::subtract_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected an subtract node."); assert(fma_combine2_cast->get_right()->is_match(p1*a) && "Expected p1*a on the right."); assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && "Expected p1 + p3 on the left."); auto fma_combine2_cast = graph::fma_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected a fma node."); assert(fma_combine2_cast->get_left()->is_match(-p1) && "Expected -p1 on the left."); assert(fma_combine2_cast->get_middle()->is_match(a) && "Expected a in the middle."); assert(fma_combine2_cast->get_right()->is_match(p1 + p3) && "Expected p1 + p3 on the right."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + a) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); Loading Loading @@ -651,12 +653,14 @@ template<jit::float_scalar T> void piecewise_2D() { "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - ax,p3); auto fma_combine2_cast = graph::subtract_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected an subtract node."); assert(fma_combine2_cast->get_right()->is_match(p1*ax) && "Expected p1*a on the right."); assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && "Expected p1 + p3 on the left."); auto fma_combine2_cast = graph::fma_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected a fma node."); assert(fma_combine2_cast->get_left()->is_match(-p1) && "Expected -p1 on the right."); assert(fma_combine2_cast->get_middle()->is_match(ax) && "Expected a in the middle."); assert(fma_combine2_cast->get_right()->is_match(p1 + p3) && "Expected p1 + p3 on the right."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + ax) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); Loading graph_framework/equilibrium.hpp +1 −1 File changed.Contains only whitespace changes. Show changes Loading
graph_framework.xcodeproj/project.pbxproj +4 −0 Original line number Diff line number Diff line Loading @@ -1363,6 +1363,7 @@ "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", Loading Loading @@ -1468,6 +1469,7 @@ "-lLLVMAArch64CodeGen", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", Loading Loading @@ -1848,6 +1850,7 @@ "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", Loading Loading @@ -1947,6 +1950,7 @@ "-lLLVMCodeGenData", "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", Loading
graph_framework/arithmetic.hpp +143 −3 Original line number Diff line number Diff line Loading @@ -1903,6 +1903,33 @@ namespace graph { if (is_variable_promotable(rm->get_right(), this->left)) { return (this->left*rm->get_left())*rm->get_right(); } auto rmlfma = fma_cast(rm->get_left()); if (rmlfma.get()) { if (is_constant_combineable(this->left, rmlfma->get_left()) && is_constant_combineable(this->left, rmlfma->get_right())) { return fma(this->left*rmlfma->get_left(), rmlfma->get_middle(), this->left*rmlfma->get_right())*rm->get_right(); } auto rmlfmalfma = fma_cast(rmlfma->get_left()); if (rmlfmalfma.get()) { if (is_constant_combineable(this->left, rmlfmalfma->get_left()) && is_constant_combineable(this->left, rmlfmalfma->get_right()) && is_constant_combineable(this->left, rmlfma->get_right())) { return fma(fma(this->left*rmlfmalfma->get_left(), rmlfmalfma->get_middle(), this->left*rmlfmalfma->get_right()), rmlfma->get_middle(), this->left*rmlfma->get_right())*rm->get_right(); } } } } // v1*(c*v2) -> c*(v1*v2) Loading Loading @@ -2299,6 +2326,79 @@ namespace graph { } } // c3*fma(c1,a,c2) -> fma(c4,a,c5) auto rfma = fma_cast(this->right); if (rfma.get()) { if (is_constant_combineable(this->left, rfma->get_left()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(this->left*rfma->get_left(), rfma->get_middle(), this->left*rfma->get_right()); } auto rfmalfma = fma_cast(rfma->get_left()); if (rfmalfma.get()) { if (is_constant_combineable(this->left, rfmalfma->get_left()) && is_constant_combineable(this->left, rfmalfma->get_right()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(fma(this->left*rfmalfma->get_left(), rfmalfma->get_middle(), this->left*rfmalfma->get_right()), rfma->get_middle(), this->left*rfma->get_right()); } auto rfmalfmalfma = fma_cast(rfmalfma->get_left()); if (rfmalfmalfma.get()) { if (is_constant_combineable(this->left, rfmalfmalfma->get_left()) && is_constant_combineable(this->left, rfmalfmalfma->get_right()) && is_constant_combineable(this->left, rfmalfma->get_right()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(fma(fma(this->left*rfmalfmalfma->get_left(), rfmalfmalfma->get_middle(), this->left*rfmalfmalfma->get_right()), rfmalfma->get_middle(), this->left*rfmalfma->get_right()), rfma->get_middle(), this->left*rfma->get_right()); } } } } // fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) auto lfma = fma_cast(this->left); auto ra = add_cast(this->right); if (lfma.get() && ra.get()) { if (ra->get_right()->is_match(lfma->get_middle()) && is_constant_combineable(ra->get_left(), lfma->get_left()) && is_constant_combineable(ra->get_left(), lfma->get_right())) { return fma(fma(lfma->get_left(), ra->get_right(), ra->get_left()*lfma->get_left() + lfma->get_right()), ra->get_right(), lfma->get_right()*ra->get_left()); } // fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) auto lfmalfma = fma_cast(lfma->get_left()); if (ra->get_right()->is_match(lfma->get_middle()) && ra->get_right()->is_match(lfmalfma->get_middle()) && is_constant_combineable(ra->get_left(), lfma->get_right()) && is_constant_combineable(ra->get_left(), lfmalfma->get_right()) && is_constant_combineable(ra->get_left(), lfmalfma->get_left())) { return fma(fma(fma(lfmalfma->get_left(), ra->get_right(), ra->get_left()*lfmalfma->get_left() + lfmalfma->get_right()), ra->get_right(), ra->get_left()*lfmalfma->get_right() + lfma->get_right()), ra->get_right(), ra->get_left()*lfma->get_right()); } } // Cases like // (c/exp(a))*(exp(b)/d) -> (c/d)*(exp(b)/exp(a)) // (c/exp(a))*(d/exp(b)) -> (c*e)/(exp(b)*exp(a)) Loading Loading @@ -3673,13 +3773,53 @@ namespace graph { } } // fma(c1,c2 - a,c3) -> c4 - c5*a // fma(c1,c2 - a,c3) -> fma(-c1,a,c1*c2 + c3) // fma(c1,a - c2,c3) -> fma(c1,a,c3 - c1*c2) auto ms = subtract_cast(this->middle); if (ms.get()) { if (is_constant_combineable(this->left, ms->get_left()) && is_constant_combineable(this->left, this->right)) { return fma(this->left, ms->get_left(), this->right) - this->left*ms->get_right(); return fma(-this->left, ms->get_right(), this->left*ms->get_left() + this->right); } else if (is_constant_combineable(this->left, ms->get_right()) && is_constant_combineable(this->left, this->right)) { return fma(this->left, ms->get_left(), this->right - this->left*ms->get_right()); } auto lfma = fma_cast(this->left); if (lfma.get()) { if (is_constant_combineable(ms->get_right(), lfma->get_left()) && is_constant_combineable(ms->get_right(), lfma->get_right()) && is_constant_combineable(this->right, lfma->get_right()) && lfma->get_middle()->is_match(ms->get_left())) { return fma(fma(lfma->get_left(), ms->get_left(), lfma->get_right() - lfma->get_left()*ms->get_right()), ms->get_left(), this->right - lfma->get_right()*ms->get_right()); } auto lfmalfma = fma_cast(lfma->get_left()); if (lfmalfma.get()) { if (lfma->get_middle()->is_match(ms->get_left()) && lfmalfma->get_middle()->is_match(ms->get_left()) && is_constant_combineable(ms->get_right(), lfmalfma->get_left()) && is_constant_combineable(ms->get_right(), lfmalfma->get_right()) && is_constant_combineable(ms->get_right(), lfma->get_right()) && is_constant_combineable(ms->get_right(), this->right)) { return fma(fma(fma(lfmalfma->get_left(), ms->get_left(), lfmalfma->get_right() - lfmalfma->get_left()*ms->get_right()), ms->get_left(), lfma->get_right() - lfmalfma->get_right()*ms->get_right()), ms->get_left(), this->right - lfma->get_right()*ms->get_right()); } } } } Loading
graph_tests/arithmetic_test.cpp +72 −0 Original line number Diff line number Diff line Loading @@ -1925,6 +1925,48 @@ template<jit::float_scalar T> void test_multiply() { assert(gather_power5_cast->get_right()->is_match(graph::pow(var_a*var_b, var_c)) && "Expected (a*b)^c."); // c3*fma(c1,a,c2) -> fma(c4,a,c5) auto constant_reduction = 0.25*fma(2.0, v1, 3.0); assert(constant_reduction->is_match(fma(2.0*0.25, v1, 3.0*0.25)) && "Expected (0.5*a + 0.75)"); // c3*(fma(c1,a,c2)*b) -> fma(c4,a,c5)*b auto constant_reduction2 = 0.25*(fma(2.0, v1, 3.0)*v2); assert(constant_reduction2->is_match(fma(2.0*0.25, v1, 3.0*0.25)*v2) && "Expected (0.5*a + 0.75)*b"); // c1*(fma(c2,a,c3)*b + c4) -> fma(c5,a,c6)*b + c7 auto constant_reduction3 = 0.25*(fma(fma(2.0, v1, 3.0),v2,2.0)); assert(constant_reduction3->is_match(fma(fma(2.0*0.25, v1, 3.0*0.25),v2,0.5)) && "Expected (0.5*a + 0.75)*b + 0.5"); // c1*((fma(c2,a,c3)*b + c4)*c) -> (fma(c5,a,c6)*b + c7)*c auto constant_reduction4 = 0.25*(fma(fma(2.0, v1, 3.0),v2,2.0)*v1); assert(constant_reduction4->is_match(fma(fma(2.0*0.25, v1, 3.0*0.25),v2,0.5)*v1) && "Expected ((0.5*a + 0.75)*b + 0.5)*c"); // fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) auto expand = graph::fma(0.2, v1, 3.0)*(4.0 + v1); assert(expand->is_match(graph::fma(graph::fma(0.2, v1, 3.8), v1, 12.0))); // fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) auto expand2 = graph::fma(fma(0.2,v1,2.3),v1,3.0)*(4.0 + v1); assert(expand2->is_match(graph::fma(graph::fma(graph::fma(0.2, v1, 0.2*4.0 + 2.3), v1, 12.2), v1, 12.0)) && "Exptected (((0.2*x + 3.1)*x + 12.2)*x + 12"); // c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9) auto consume = 10.0*(graph::fma(graph::fma(graph::fma(0.2,v1,2.3),v1,3.0),v1,0.1)); assert(consume->is_match(graph::fma(graph::fma(graph::fma(2.0, v1, 23.0), v1, 30.0), v1, 1.0)) && "Expected fma(fma(fma(2,x,23),x,30,x,1))"); } //------------------------------------------------------------------------------ Loading Loading @@ -3705,6 +3747,36 @@ template<jit::float_scalar T> void test_fma() { "Expected b."); assert(factorize6_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); // fma(c1,a - c2,c3) -> fma(c1,a,c4) auto consume = graph::fma(2.0,var_a - 3.0,20.0); assert(consume->is_match(graph::fma(2.0,var_a,14.0)) && "Expected fma(2,x,14)"); // fma(c1,c2 - a,c3) -> fma(-c1,a,c4) auto consume2 = graph::fma(2.0,3.0 - var_a,20.0); assert(consume2->is_match(graph::fma(-2.0,var_a,26.0)) && "Expected fma(-2,x,26)"); // fma(fma(c1,a,c2),a - c3,c4) -> fma(fma(c1,x,c5),x,c6) auto gather = graph::fma(graph::fma(2.0,var_a,20.0),var_a - 2.0,30.0); assert(gather->is_match(graph::fma(graph::fma(2.0,var_a,16.0),var_a,-10.0)) && "Expected fma(fma(2,x,16),x,-10)"); // fma(fma(fma(c1,a,c2),a,c3),a - c4,c5) -> fma(fma(c1,x,c6),x,c6),x,c8) auto gather2 = graph::fma(graph::fma(graph::fma(2.0, var_a, 20.0), var_a, 30.0), var_a - 2.0, 50.0); assert(gather2->is_match(graph::fma(graph::fma(graph::fma(2.0, var_a, 16.0), var_a, -10.0), var_a, -10.0)) && "Expected fma(fma(fma(2,x,16),x,-10),x,-10)"); } //------------------------------------------------------------------------------ Loading
graph_tests/efit_test.cpp +3 −3 Original line number Diff line number Diff line Loading @@ -157,11 +157,11 @@ void run_test() { "Expected a match in bx."); check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20, "Expected a match in by."); check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 4.0E-12, check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 5.0E-12, "Expected a match in bz."); check_error(work.check_value(i, ne), gold.ne_grid[i], 8.0E-13, check_error(work.check_value(i, ne), gold.ne_grid[i], 2.1E-12, "Expected a match in ne."); check_error(work.check_value(i, te), gold.te_grid[i], 8.0E-13, check_error(work.check_value(i, te), gold.te_grid[i], 2.1E-12, "Expected a match in te."); } } Loading
graph_tests/piecewise_test.cpp +16 −12 Original line number Diff line number Diff line Loading @@ -241,12 +241,14 @@ template<jit::float_scalar T> void piecewise_1D() { "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - a,p3); auto fma_combine2_cast = graph::subtract_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected an subtract node."); assert(fma_combine2_cast->get_right()->is_match(p1*a) && "Expected p1*a on the right."); assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && "Expected p1 + p3 on the left."); auto fma_combine2_cast = graph::fma_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected a fma node."); assert(fma_combine2_cast->get_left()->is_match(-p1) && "Expected -p1 on the left."); assert(fma_combine2_cast->get_middle()->is_match(a) && "Expected a in the middle."); assert(fma_combine2_cast->get_right()->is_match(p1 + p3) && "Expected p1 + p3 on the right."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + a) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); Loading Loading @@ -651,12 +653,14 @@ template<jit::float_scalar T> void piecewise_2D() { "Expected p1 + p3 on the right."); // fma(p1,c1 - a,p2) -> p3 - p1*a auto fma_combine2 = fma(p1,1.0 - ax,p3); auto fma_combine2_cast = graph::subtract_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected an subtract node."); assert(fma_combine2_cast->get_right()->is_match(p1*ax) && "Expected p1*a on the right."); assert(fma_combine2_cast->get_left()->is_match(p1 + p3) && "Expected p1 + p3 on the left."); auto fma_combine2_cast = graph::fma_cast(fma_combine2); assert(fma_combine2_cast.get() && "Expected a fma node."); assert(fma_combine2_cast->get_left()->is_match(-p1) && "Expected -p1 on the right."); assert(fma_combine2_cast->get_middle()->is_match(ax) && "Expected a in the middle."); assert(fma_combine2_cast->get_right()->is_match(p1 + p3) && "Expected p1 + p3 on the right."); // p1*(c1 + a) - p2 -> fma(p1,a,p3) auto fma_combine3 = p1*(1.0 + ax) - p3; auto fma_combine3_cast = graph::fma_cast(fma_combine3); Loading