Loading graph_framework/arithmetic.hpp +52 −0 Original line number Diff line number Diff line Loading @@ -3753,6 +3753,30 @@ namespace graph { } } // fma(a,b*c,b*d) -> b*fma(a,c,d) // fma(a,c*b,b*d) -> b*fma(a,c,d) // fma(a,b*c,d*b) -> b*fma(a,c,d) // fma(a,c*b,d*b) -> b*fma(a,c,d) if (mm.get()) { if (mm->get_left()->is_match(rm->get_left())) { return mm->get_left()*fma(this->left, mm->get_right(), rm->get_right()); } else if (mm->get_left()->is_match(rm->get_right())) { return mm->get_left()*fma(this->left, mm->get_right(), rm->get_left()); } else if (mm->get_right()->is_match(rm->get_left())) { return mm->get_right()*fma(this->left, mm->get_left(), rm->get_right()); } else if (mm->get_right()->is_match(rm->get_right())) { return mm->get_right()*fma(this->left, mm->get_left(), rm->get_left()); } } // Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c) // Convert fma(a,b*c,d*e) -> fma(d,e,a*b*c) if ((lm.get() || mm.get()) && Loading Loading @@ -3850,6 +3874,19 @@ namespace graph { } } // fma(a,b*c,b) -> b*fma(a,c,1) if (mm.get()) { if (mm->get_left()->is_match(this->right)) { return mm->get_left()*fma(this->left, mm->get_right(), 1.0); } else if (mm->get_right()->is_match(this->right)) { return mm->get_right()*fma(this->left, mm->get_left(), 1.0); } } // fma(c1,a,c2/b) -> c1*(a + c3/b) // fma(a,c1,c2/b) -> c1*(a + c3/b) auto rd = divide_cast(this->right); Loading Loading @@ -4435,6 +4472,21 @@ namespace graph { rfma->get_middle()), rfma->get_right()); } // fma(2,(a*b)^2,fma(3,a^2*b,c)) -> a^2*fma(2,b^2,fma(3,b,c)) auto rfmamm = multiply_cast(rfma->get_middle()); if (rfmamm.get()) { if (is_variable_combineable(mplm->get_left(), rfmamm->get_left())) { auto temp = pow(mplm->get_left(), mp->get_right()); return temp*fma(this->left, this->middle/temp, fma(rfma->get_left(), rfma->get_middle()/temp, rfma->get_right())); } } } } Loading graph_framework/equilibrium.hpp +12 −12 Original line number Diff line number Diff line Loading @@ -1073,18 +1073,18 @@ namespace equilibrium { + c01_temp*z_norm + c02_temp*(z_norm*z_norm) + c03_temp*(z_norm*z_norm*z_norm) + c10_temp*r_norm + c11_temp*r_norm*z_norm + c12_temp*r_norm*(z_norm*z_norm) + c13_temp*r_norm*(z_norm*z_norm*z_norm) + c20_temp*(r_norm*r_norm) + c21_temp*(r_norm*r_norm)*z_norm + c22_temp*(r_norm*r_norm)*(z_norm*z_norm) + c23_temp*(r_norm*r_norm)*(z_norm*z_norm*z_norm) + c30_temp*(r_norm*r_norm*r_norm) + c31_temp*(r_norm*r_norm*r_norm)*z_norm + c32_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm) + c33_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm*z_norm); + r_norm*(c10_temp + c11_temp*z_norm + c12_temp*(z_norm*z_norm) + c13_temp*(z_norm*z_norm*z_norm)) + (r_norm*r_norm)*(c20_temp + c21_temp*z_norm + c22_temp*(z_norm*z_norm) + c23_temp*(z_norm*z_norm*z_norm)) + (r_norm*r_norm*r_norm)*(c30_temp + c31_temp*z_norm + c32_temp*(z_norm*z_norm) + c33_temp*(z_norm*z_norm*z_norm)); } //------------------------------------------------------------------------------ Loading graph_korc/xkorc.cpp +3 −2 Original line number Diff line number Diff line Loading @@ -143,11 +143,11 @@ void run_korc() { const timeing::measure_diagnostic t_run("Run Time"); work.pre_run(); for (size_t i = 0; i < 1000000; i++) { sync.join(); /* sync.join(); work.wait(); sync = std::thread([&file, &dataset] () -> void { dataset.write(file); }); });*/ work.run(); } Loading Loading @@ -181,6 +181,7 @@ int main(int argc, const char * argv[]) { (void)argv; run_korc<double> (); // run_korc<float> (); END_GPU } graph_tests/arithmetic_test.cpp +84 −8 Original line number Diff line number Diff line Loading @@ -2822,7 +2822,7 @@ template<jit::float_scalar T> void test_fma() { "Expected common var_b"); // fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) auto var_d = graph::variable<T> (1, ""); auto var_d = graph::variable<T> (1, "d"); auto match1 = graph::fma(var_b, var_a + var_c, var_d); auto nested_fma1 = graph::fma(var_a, var_b, graph::fma(var_c, var_b, var_d)); Loading Loading @@ -3578,11 +3578,11 @@ template<jit::float_scalar T> void test_fma() { auto commom_power_cast = graph::multiply_cast(commom_power); assert(commom_power_cast.get() && "Expected a multiply node."); assert(commom_power_cast->get_right()->is_match(var_b) && "Expced b"); "Expeced b"); assert(commom_power_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); // fma(2,(ba)^2,a^2b) -> a^2*fma(2, b^2, b) "Expeced a^2*fma(2, b, 1)"); // fma(2,(ba)^2,a^2b) -> a^2*b*fma(2, b, 1) auto commom_power2 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::pow(var_a, 2.0)*var_b); auto commom_power2_cast = graph::multiply_cast(commom_power2); assert(commom_power2_cast.get() && "Expected a multiply node."); Loading @@ -3591,7 +3591,7 @@ template<jit::float_scalar T> void test_fma() { assert(commom_power2_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2*b,fma(2,b,1),c) auto commom_power3 = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::fma(graph::pow(var_a, 2.0), Loading @@ -3601,7 +3601,12 @@ template<jit::float_scalar T> void test_fma() { assert(commom_power3_cast.get() && "Expected a fma node."); assert(commom_power3_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) assert(commom_power3_cast->get_middle()->is_match(var_b*graph::fma(2.0, var_b, 1.0)) && "Expected b*fma(2,b,1)"); assert(commom_power3_cast->get_right()->is_match(var_c) && "Expected c"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2*b,fma(2,b,1),c) auto commom_power4 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::fma(graph::pow(var_a, 2.0), Loading @@ -3611,6 +3616,11 @@ template<jit::float_scalar T> void test_fma() { assert(commom_power4_cast.get() && "Expected a fma node."); assert(commom_power4_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); assert(commom_power4_cast->get_middle()->is_match(var_b*graph::fma(2.0, var_b, 1.0)) && "Expected b*fma(2,b,1)"); assert(commom_power4_cast->get_right()->is_match(var_c) && "Expected c"); // fma(2,a^2,a) -> a*fma(2,a,1) auto common_power5 = graph::fma(2.0,var_a*var_a,var_a); Loading @@ -3621,14 +3631,80 @@ template<jit::float_scalar T> void test_fma() { assert(commom_power5_cast->get_right()->is_match(var_a) && "Expected a."); // fma(2,a,a^2) -> a*(2 + a) auto temp = var_a*var_a; auto common_power6 = graph::fma(2.0,var_a,temp); auto common_power6 = graph::fma(2.0,var_a,var_a*var_a); auto commom_power6_cast = graph::multiply_cast(common_power6); assert(commom_power6_cast.get() && "Expected a multiply node."); assert(commom_power6_cast->get_left()->is_match(2.0 + var_a) && "Expected (2 + a)."); assert(commom_power6_cast->get_right()->is_match(var_a) && "Expected a."); // fma(2,(a*b)^2,fma(3,a^2*b,c)) -> fma(a^2*b,fma(2,b,3),c) auto common_power7 = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::fma(3.0, var_a*var_a*var_b, var_c)); auto common_power7_cast = graph::multiply_cast(common_power7); assert(common_power7_cast.get() && "Expected a multiply node."); assert(common_power7_cast->get_left()->is_match(graph::fma(var_b, graph::fma(2.0, var_b, 3.0), var_c)) && "Expected fma(b,fma(2,b,3),c)"); assert(common_power7_cast->get_right()->is_match(var_a*var_a) && "Expected a^2"); // fma(a,b*c,b) -> b*fma(a,c,1) auto factorize = graph::fma(var_a,var_b*var_c,var_b); auto factorize_cast = multiply_cast(factorize); assert(factorize_cast.get() && "Expected a multiply node."); assert(factorize_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize_cast->get_left()->is_match(graph::fma(var_a,var_c,1.0)) && "Expected a*c + 1."); // fma(a,c*b,b) -> b*fma(a,c,1) auto factorize2 = graph::fma(var_a,var_c*var_b,var_b); auto factorize2_cast = multiply_cast(factorize2); assert(factorize2_cast.get() && "Expected a multiply node."); assert(factorize2_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize2_cast->get_left()->is_match(graph::fma(var_a,var_c,1.0)) && "Expected a*c + 1."); // fma(a,b*c,b*d) -> b*fma(a,c,d) auto factorize3 = graph::fma(var_a,var_b*var_c,var_b*var_d); auto factorize3_cast = multiply_cast(factorize3); assert(factorize3_cast.get() && "Expected a multiply node."); assert(factorize3_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize3_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); // fma(a,c*b,b*d) -> b*fma(a,c,d) auto factorize4 = graph::fma(var_a,var_c*var_b,var_b*var_d); auto factorize4_cast = multiply_cast(factorize4); assert(factorize4_cast.get() && "Expected a multiply node."); assert(factorize4_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize4_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); // fma(a,b*c,d*b) -> b*fma(a,c,d) auto factorize5 = graph::fma(var_a,var_b*var_c,var_d*var_b); auto factorize5_cast = multiply_cast(factorize5); assert(factorize5_cast.get() && "Expected a multiply node."); assert(factorize5_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize5_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); // fma(a,c*b,d*b) -> b*fma(a,c,d) auto factorize6 = graph::fma(var_a,var_c*var_b,var_d*var_b); auto factorize6_cast = multiply_cast(factorize6); assert(factorize6_cast.get() && "Expected a multiply node."); assert(factorize6_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize6_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); } //------------------------------------------------------------------------------ Loading graph_tests/efit_test.cpp +4 −6 Original line number Diff line number Diff line Loading @@ -139,8 +139,6 @@ void run_test() { auto bvec = eq->get_magnetic_field(x, y, z); auto ne = eq->get_electron_density(x, y, z); ne->to_latex(); std::cout << std::endl << std::endl; auto te = eq->get_electron_temperature(x, y, z); workflow::manager<T> work(0); Loading @@ -155,15 +153,15 @@ void run_test() { work.run(); for (size_t i = 0, ie = gold.r_grid.size()*gold.z_grid.size(); i < ie; i++) { check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 4.0E-11, check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 9.0E-12, "Expected a match in bx."); check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20, "Expected a match in by."); check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 3.0E-12, check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 4.0E-12, "Expected a match in bz."); check_error(work.check_value(i, ne), gold.ne_grid[i], 5.0E-13, check_error(work.check_value(i, ne), gold.ne_grid[i], 8.0E-13, "Expected a match in ne."); check_error(work.check_value(i, te), gold.te_grid[i], 5.0E-13, check_error(work.check_value(i, te), gold.te_grid[i], 8.0E-13, "Expected a match in te."); } } Loading Loading
graph_framework/arithmetic.hpp +52 −0 Original line number Diff line number Diff line Loading @@ -3753,6 +3753,30 @@ namespace graph { } } // fma(a,b*c,b*d) -> b*fma(a,c,d) // fma(a,c*b,b*d) -> b*fma(a,c,d) // fma(a,b*c,d*b) -> b*fma(a,c,d) // fma(a,c*b,d*b) -> b*fma(a,c,d) if (mm.get()) { if (mm->get_left()->is_match(rm->get_left())) { return mm->get_left()*fma(this->left, mm->get_right(), rm->get_right()); } else if (mm->get_left()->is_match(rm->get_right())) { return mm->get_left()*fma(this->left, mm->get_right(), rm->get_left()); } else if (mm->get_right()->is_match(rm->get_left())) { return mm->get_right()*fma(this->left, mm->get_left(), rm->get_right()); } else if (mm->get_right()->is_match(rm->get_right())) { return mm->get_right()*fma(this->left, mm->get_left(), rm->get_left()); } } // Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c) // Convert fma(a,b*c,d*e) -> fma(d,e,a*b*c) if ((lm.get() || mm.get()) && Loading Loading @@ -3850,6 +3874,19 @@ namespace graph { } } // fma(a,b*c,b) -> b*fma(a,c,1) if (mm.get()) { if (mm->get_left()->is_match(this->right)) { return mm->get_left()*fma(this->left, mm->get_right(), 1.0); } else if (mm->get_right()->is_match(this->right)) { return mm->get_right()*fma(this->left, mm->get_left(), 1.0); } } // fma(c1,a,c2/b) -> c1*(a + c3/b) // fma(a,c1,c2/b) -> c1*(a + c3/b) auto rd = divide_cast(this->right); Loading Loading @@ -4435,6 +4472,21 @@ namespace graph { rfma->get_middle()), rfma->get_right()); } // fma(2,(a*b)^2,fma(3,a^2*b,c)) -> a^2*fma(2,b^2,fma(3,b,c)) auto rfmamm = multiply_cast(rfma->get_middle()); if (rfmamm.get()) { if (is_variable_combineable(mplm->get_left(), rfmamm->get_left())) { auto temp = pow(mplm->get_left(), mp->get_right()); return temp*fma(this->left, this->middle/temp, fma(rfma->get_left(), rfma->get_middle()/temp, rfma->get_right())); } } } } Loading
graph_framework/equilibrium.hpp +12 −12 Original line number Diff line number Diff line Loading @@ -1073,18 +1073,18 @@ namespace equilibrium { + c01_temp*z_norm + c02_temp*(z_norm*z_norm) + c03_temp*(z_norm*z_norm*z_norm) + c10_temp*r_norm + c11_temp*r_norm*z_norm + c12_temp*r_norm*(z_norm*z_norm) + c13_temp*r_norm*(z_norm*z_norm*z_norm) + c20_temp*(r_norm*r_norm) + c21_temp*(r_norm*r_norm)*z_norm + c22_temp*(r_norm*r_norm)*(z_norm*z_norm) + c23_temp*(r_norm*r_norm)*(z_norm*z_norm*z_norm) + c30_temp*(r_norm*r_norm*r_norm) + c31_temp*(r_norm*r_norm*r_norm)*z_norm + c32_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm) + c33_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm*z_norm); + r_norm*(c10_temp + c11_temp*z_norm + c12_temp*(z_norm*z_norm) + c13_temp*(z_norm*z_norm*z_norm)) + (r_norm*r_norm)*(c20_temp + c21_temp*z_norm + c22_temp*(z_norm*z_norm) + c23_temp*(z_norm*z_norm*z_norm)) + (r_norm*r_norm*r_norm)*(c30_temp + c31_temp*z_norm + c32_temp*(z_norm*z_norm) + c33_temp*(z_norm*z_norm*z_norm)); } //------------------------------------------------------------------------------ Loading
graph_korc/xkorc.cpp +3 −2 Original line number Diff line number Diff line Loading @@ -143,11 +143,11 @@ void run_korc() { const timeing::measure_diagnostic t_run("Run Time"); work.pre_run(); for (size_t i = 0; i < 1000000; i++) { sync.join(); /* sync.join(); work.wait(); sync = std::thread([&file, &dataset] () -> void { dataset.write(file); }); });*/ work.run(); } Loading Loading @@ -181,6 +181,7 @@ int main(int argc, const char * argv[]) { (void)argv; run_korc<double> (); // run_korc<float> (); END_GPU }
graph_tests/arithmetic_test.cpp +84 −8 Original line number Diff line number Diff line Loading @@ -2822,7 +2822,7 @@ template<jit::float_scalar T> void test_fma() { "Expected common var_b"); // fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) auto var_d = graph::variable<T> (1, ""); auto var_d = graph::variable<T> (1, "d"); auto match1 = graph::fma(var_b, var_a + var_c, var_d); auto nested_fma1 = graph::fma(var_a, var_b, graph::fma(var_c, var_b, var_d)); Loading Loading @@ -3578,11 +3578,11 @@ template<jit::float_scalar T> void test_fma() { auto commom_power_cast = graph::multiply_cast(commom_power); assert(commom_power_cast.get() && "Expected a multiply node."); assert(commom_power_cast->get_right()->is_match(var_b) && "Expced b"); "Expeced b"); assert(commom_power_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); // fma(2,(ba)^2,a^2b) -> a^2*fma(2, b^2, b) "Expeced a^2*fma(2, b, 1)"); // fma(2,(ba)^2,a^2b) -> a^2*b*fma(2, b, 1) auto commom_power2 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::pow(var_a, 2.0)*var_b); auto commom_power2_cast = graph::multiply_cast(commom_power2); assert(commom_power2_cast.get() && "Expected a multiply node."); Loading @@ -3591,7 +3591,7 @@ template<jit::float_scalar T> void test_fma() { assert(commom_power2_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2*b,fma(2,b,1),c) auto commom_power3 = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::fma(graph::pow(var_a, 2.0), Loading @@ -3601,7 +3601,12 @@ template<jit::float_scalar T> void test_fma() { assert(commom_power3_cast.get() && "Expected a fma node."); assert(commom_power3_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) assert(commom_power3_cast->get_middle()->is_match(var_b*graph::fma(2.0, var_b, 1.0)) && "Expected b*fma(2,b,1)"); assert(commom_power3_cast->get_right()->is_match(var_c) && "Expected c"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2*b,fma(2,b,1),c) auto commom_power4 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::fma(graph::pow(var_a, 2.0), Loading @@ -3611,6 +3616,11 @@ template<jit::float_scalar T> void test_fma() { assert(commom_power4_cast.get() && "Expected a fma node."); assert(commom_power4_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); assert(commom_power4_cast->get_middle()->is_match(var_b*graph::fma(2.0, var_b, 1.0)) && "Expected b*fma(2,b,1)"); assert(commom_power4_cast->get_right()->is_match(var_c) && "Expected c"); // fma(2,a^2,a) -> a*fma(2,a,1) auto common_power5 = graph::fma(2.0,var_a*var_a,var_a); Loading @@ -3621,14 +3631,80 @@ template<jit::float_scalar T> void test_fma() { assert(commom_power5_cast->get_right()->is_match(var_a) && "Expected a."); // fma(2,a,a^2) -> a*(2 + a) auto temp = var_a*var_a; auto common_power6 = graph::fma(2.0,var_a,temp); auto common_power6 = graph::fma(2.0,var_a,var_a*var_a); auto commom_power6_cast = graph::multiply_cast(common_power6); assert(commom_power6_cast.get() && "Expected a multiply node."); assert(commom_power6_cast->get_left()->is_match(2.0 + var_a) && "Expected (2 + a)."); assert(commom_power6_cast->get_right()->is_match(var_a) && "Expected a."); // fma(2,(a*b)^2,fma(3,a^2*b,c)) -> fma(a^2*b,fma(2,b,3),c) auto common_power7 = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::fma(3.0, var_a*var_a*var_b, var_c)); auto common_power7_cast = graph::multiply_cast(common_power7); assert(common_power7_cast.get() && "Expected a multiply node."); assert(common_power7_cast->get_left()->is_match(graph::fma(var_b, graph::fma(2.0, var_b, 3.0), var_c)) && "Expected fma(b,fma(2,b,3),c)"); assert(common_power7_cast->get_right()->is_match(var_a*var_a) && "Expected a^2"); // fma(a,b*c,b) -> b*fma(a,c,1) auto factorize = graph::fma(var_a,var_b*var_c,var_b); auto factorize_cast = multiply_cast(factorize); assert(factorize_cast.get() && "Expected a multiply node."); assert(factorize_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize_cast->get_left()->is_match(graph::fma(var_a,var_c,1.0)) && "Expected a*c + 1."); // fma(a,c*b,b) -> b*fma(a,c,1) auto factorize2 = graph::fma(var_a,var_c*var_b,var_b); auto factorize2_cast = multiply_cast(factorize2); assert(factorize2_cast.get() && "Expected a multiply node."); assert(factorize2_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize2_cast->get_left()->is_match(graph::fma(var_a,var_c,1.0)) && "Expected a*c + 1."); // fma(a,b*c,b*d) -> b*fma(a,c,d) auto factorize3 = graph::fma(var_a,var_b*var_c,var_b*var_d); auto factorize3_cast = multiply_cast(factorize3); assert(factorize3_cast.get() && "Expected a multiply node."); assert(factorize3_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize3_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); // fma(a,c*b,b*d) -> b*fma(a,c,d) auto factorize4 = graph::fma(var_a,var_c*var_b,var_b*var_d); auto factorize4_cast = multiply_cast(factorize4); assert(factorize4_cast.get() && "Expected a multiply node."); assert(factorize4_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize4_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); // fma(a,b*c,d*b) -> b*fma(a,c,d) auto factorize5 = graph::fma(var_a,var_b*var_c,var_d*var_b); auto factorize5_cast = multiply_cast(factorize5); assert(factorize5_cast.get() && "Expected a multiply node."); assert(factorize5_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize5_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); // fma(a,c*b,d*b) -> b*fma(a,c,d) auto factorize6 = graph::fma(var_a,var_c*var_b,var_d*var_b); auto factorize6_cast = multiply_cast(factorize6); assert(factorize6_cast.get() && "Expected a multiply node."); assert(factorize6_cast->get_right()->is_match(var_b) && "Expected b."); assert(factorize6_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) && "Expected a*c + d."); } //------------------------------------------------------------------------------ Loading
graph_tests/efit_test.cpp +4 −6 Original line number Diff line number Diff line Loading @@ -139,8 +139,6 @@ void run_test() { auto bvec = eq->get_magnetic_field(x, y, z); auto ne = eq->get_electron_density(x, y, z); ne->to_latex(); std::cout << std::endl << std::endl; auto te = eq->get_electron_temperature(x, y, z); workflow::manager<T> work(0); Loading @@ -155,15 +153,15 @@ void run_test() { work.run(); for (size_t i = 0, ie = gold.r_grid.size()*gold.z_grid.size(); i < ie; i++) { check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 4.0E-11, check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 9.0E-12, "Expected a match in bx."); check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20, "Expected a match in by."); check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 3.0E-12, check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 4.0E-12, "Expected a match in bz."); check_error(work.check_value(i, ne), gold.ne_grid[i], 5.0E-13, check_error(work.check_value(i, ne), gold.ne_grid[i], 8.0E-13, "Expected a match in ne."); check_error(work.check_value(i, te), gold.te_grid[i], 5.0E-13, check_error(work.check_value(i, te), gold.te_grid[i], 8.0E-13, "Expected a match in te."); } } Loading