Loading graph_framework/arithmetic.hpp +54 −8 Original line number Diff line number Diff line Loading @@ -4341,8 +4341,10 @@ namespace graph { } } else if (this->middle->is_all_variables()) { auto rdm = this->right/this->middle; if (rdm->get_complexity() < this->middle->get_complexity() + this->right->get_complexity()) { auto rdmc = constant_cast(rdm->get_power_exponent()); if ((rdm->get_complexity() < this->middle->get_complexity() + this->right->get_complexity()) && !(rdmc.get() && rdmc->evaluate().is_negative())) { return (this->left + rdm)*this->middle; } } Loading @@ -4365,6 +4367,19 @@ namespace graph { return this->left/pow(mp->get_left(), -mp->get_right()) + this->right; } // fma(2,a^2,a) -> a*fma(2,a,1) // Note this case is handled eailer. fma(2,a,a^2) -> a*fma(2,1,a) if (is_variable_combineable(this->middle, this->right)) { auto temp = this->right/this->middle; auto temp_exponent = constant_cast(temp->get_power_exponent()); if (temp_exponent.get() && temp_exponent->evaluate().is_negative()) { return this->right*fma(this->left, this->middle/this->right, 1.0); } } } // a^b*c^b + d -> (a*c)^b + d Loading @@ -4382,12 +4397,43 @@ namespace graph { if (mplm.get()) { if (is_variable_combineable(mplm->get_left(), rm->get_left())) { return pow(mplm->get_left(), mp->get_right()) * auto temp = pow(mplm->get_left(), mp->get_right()); return temp*fma(this->left, this->middle/temp, this->right/temp); } else if (is_variable_combineable(mplm->get_right(), rm->get_left())) { auto temp = pow(mplm->get_right(), mp->get_right()); return temp*fma(this->left, this->middle/temp, this->right/temp); } } } // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) if (rfma.get() && mp.get()) { auto mplm = multiply_cast(mp->get_left()); if (mplm.get()) { if (is_variable_combineable(mplm->get_left(), rfma->get_left())) { auto temp = pow(mplm->get_left(), mp->get_right()); return fma(temp, fma(this->left, pow(mplm->get_right(), mp->get_right()), this->right/mplm->get_left()); this->middle/temp, rfma->get_middle()), rfma->get_right()); } else if (is_variable_combineable(mplm->get_right(), rfma->get_left())) { auto temp = pow(mplm->get_right(), mp->get_right()); return fma(temp, fma(this->left, this->middle/temp, rfma->get_middle()), rfma->get_right()); } } } Loading graph_framework/node.hpp +13 −0 Original line number Diff line number Diff line Loading @@ -660,6 +660,19 @@ namespace graph { return constant<T, SAFE_MATH> (static_cast<T> (1.0)); } //------------------------------------------------------------------------------ /// @brief Create a one constant. /// /// @tparam T Base type of the calculation. /// @tparam SAFE_MATH Use safe math operations. /// /// @returns A one constant. //------------------------------------------------------------------------------ template<jit::float_scalar T, bool SAFE_MATH> constexpr shared_leaf<T, SAFE_MATH> none() { return constant<T, SAFE_MATH> (static_cast<T> (-1.0)); } /// Convinece type for imaginary constant. template<jit::complex_scalar T> constexpr T i = T(0.0, 1.0); Loading graph_tests/arithmetic_test.cpp +54 −5 Original line number Diff line number Diff line Loading @@ -2773,7 +2773,7 @@ template<jit::float_scalar T> void test_fma() { // Test reduction. auto var_a = graph::variable<T> (1, "a"); auto var_b = graph::variable<T> (1, "b"); auto var_c = graph::variable<T> (1, ""); auto var_c = graph::variable<T> (1, "c"); // fma(1,a,b) = a + b auto one_times_vara_plus_varb = graph::fma(one, var_a, var_b); Loading Loading @@ -3573,13 +3573,62 @@ template<jit::float_scalar T> void test_fma() { var_b) + var_d) && "Expected a power node."); // fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) // fma(2,(ab)^2,a^2b) -> a^2*b*fma(2, b, 1) auto commom_power = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::pow(var_a, 2.0)*var_b); commom_power->to_latex(); std::cout << std::endl << std::endl; auto commom_power_cast = graph::multiply_cast(commom_power); assert(commom_power_cast.get() && "Expected a multiply node."); // fma(2,(a*b)^2,fma()) assert(commom_power_cast->get_right()->is_match(var_b) && "Expced b"); assert(commom_power_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); // fma(2,(ba)^2,a^2b) -> a^2*fma(2, b^2, b) auto commom_power2 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::pow(var_a, 2.0)*var_b); auto commom_power2_cast = graph::multiply_cast(commom_power2); assert(commom_power2_cast.get() && "Expected a multiply node."); assert(commom_power2_cast->get_right()->is_match(var_b) && "Expced b"); assert(commom_power2_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) auto commom_power3 = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::fma(graph::pow(var_a, 2.0), var_b, var_c)); auto commom_power3_cast = graph::fma_cast(commom_power3); assert(commom_power3_cast.get() && "Expected a fma node."); assert(commom_power3_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) auto commom_power4 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::fma(graph::pow(var_a, 2.0), var_b, var_c)); auto commom_power4_cast = graph::fma_cast(commom_power4); assert(commom_power4_cast.get() && "Expected a fma node."); assert(commom_power4_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); // fma(2,a^2,a) -> a*fma(2,a,1) auto common_power5 = graph::fma(2.0,var_a*var_a,var_a); auto commom_power5_cast = graph::multiply_cast(common_power5); assert(commom_power5_cast.get() && "Expected a multiply node."); assert(commom_power5_cast->get_left()->is_match(graph::fma(2.0,var_a,1.0)) && "Expected fma(2,a,1)."); assert(commom_power5_cast->get_right()->is_match(var_a) && "Expected a."); // fma(2,a,a^2) -> a*(2 + a) auto temp = var_a*var_a; auto common_power6 = graph::fma(2.0,var_a,temp); auto commom_power6_cast = graph::multiply_cast(common_power6); assert(commom_power6_cast.get() && "Expected a multiply node."); assert(commom_power6_cast->get_left()->is_match(2.0 + var_a) && "Expected (2 + a)."); assert(commom_power6_cast->get_right()->is_match(var_a) && "Expected a."); } //------------------------------------------------------------------------------ Loading graph_tests/efit_test.cpp +2 −0 Original line number Diff line number Diff line Loading @@ -139,6 +139,8 @@ void run_test() { auto bvec = eq->get_magnetic_field(x, y, z); auto ne = eq->get_electron_density(x, y, z); ne->to_latex(); std::cout << std::endl << std::endl; auto te = eq->get_electron_temperature(x, y, z); workflow::manager<T> work(0); Loading Loading
graph_framework/arithmetic.hpp +54 −8 Original line number Diff line number Diff line Loading @@ -4341,8 +4341,10 @@ namespace graph { } } else if (this->middle->is_all_variables()) { auto rdm = this->right/this->middle; if (rdm->get_complexity() < this->middle->get_complexity() + this->right->get_complexity()) { auto rdmc = constant_cast(rdm->get_power_exponent()); if ((rdm->get_complexity() < this->middle->get_complexity() + this->right->get_complexity()) && !(rdmc.get() && rdmc->evaluate().is_negative())) { return (this->left + rdm)*this->middle; } } Loading @@ -4365,6 +4367,19 @@ namespace graph { return this->left/pow(mp->get_left(), -mp->get_right()) + this->right; } // fma(2,a^2,a) -> a*fma(2,a,1) // Note this case is handled eailer. fma(2,a,a^2) -> a*fma(2,1,a) if (is_variable_combineable(this->middle, this->right)) { auto temp = this->right/this->middle; auto temp_exponent = constant_cast(temp->get_power_exponent()); if (temp_exponent.get() && temp_exponent->evaluate().is_negative()) { return this->right*fma(this->left, this->middle/this->right, 1.0); } } } // a^b*c^b + d -> (a*c)^b + d Loading @@ -4382,12 +4397,43 @@ namespace graph { if (mplm.get()) { if (is_variable_combineable(mplm->get_left(), rm->get_left())) { return pow(mplm->get_left(), mp->get_right()) * auto temp = pow(mplm->get_left(), mp->get_right()); return temp*fma(this->left, this->middle/temp, this->right/temp); } else if (is_variable_combineable(mplm->get_right(), rm->get_left())) { auto temp = pow(mplm->get_right(), mp->get_right()); return temp*fma(this->left, this->middle/temp, this->right/temp); } } } // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) if (rfma.get() && mp.get()) { auto mplm = multiply_cast(mp->get_left()); if (mplm.get()) { if (is_variable_combineable(mplm->get_left(), rfma->get_left())) { auto temp = pow(mplm->get_left(), mp->get_right()); return fma(temp, fma(this->left, pow(mplm->get_right(), mp->get_right()), this->right/mplm->get_left()); this->middle/temp, rfma->get_middle()), rfma->get_right()); } else if (is_variable_combineable(mplm->get_right(), rfma->get_left())) { auto temp = pow(mplm->get_right(), mp->get_right()); return fma(temp, fma(this->left, this->middle/temp, rfma->get_middle()), rfma->get_right()); } } } Loading
graph_framework/node.hpp +13 −0 Original line number Diff line number Diff line Loading @@ -660,6 +660,19 @@ namespace graph { return constant<T, SAFE_MATH> (static_cast<T> (1.0)); } //------------------------------------------------------------------------------ /// @brief Create a one constant. /// /// @tparam T Base type of the calculation. /// @tparam SAFE_MATH Use safe math operations. /// /// @returns A one constant. //------------------------------------------------------------------------------ template<jit::float_scalar T, bool SAFE_MATH> constexpr shared_leaf<T, SAFE_MATH> none() { return constant<T, SAFE_MATH> (static_cast<T> (-1.0)); } /// Convinece type for imaginary constant. template<jit::complex_scalar T> constexpr T i = T(0.0, 1.0); Loading
graph_tests/arithmetic_test.cpp +54 −5 Original line number Diff line number Diff line Loading @@ -2773,7 +2773,7 @@ template<jit::float_scalar T> void test_fma() { // Test reduction. auto var_a = graph::variable<T> (1, "a"); auto var_b = graph::variable<T> (1, "b"); auto var_c = graph::variable<T> (1, ""); auto var_c = graph::variable<T> (1, "c"); // fma(1,a,b) = a + b auto one_times_vara_plus_varb = graph::fma(one, var_a, var_b); Loading Loading @@ -3573,13 +3573,62 @@ template<jit::float_scalar T> void test_fma() { var_b) + var_d) && "Expected a power node."); // fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b) // fma(2,(ab)^2,a^2b) -> a^2*b*fma(2, b, 1) auto commom_power = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::pow(var_a, 2.0)*var_b); commom_power->to_latex(); std::cout << std::endl << std::endl; auto commom_power_cast = graph::multiply_cast(commom_power); assert(commom_power_cast.get() && "Expected a multiply node."); // fma(2,(a*b)^2,fma()) assert(commom_power_cast->get_right()->is_match(var_b) && "Expced b"); assert(commom_power_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); // fma(2,(ba)^2,a^2b) -> a^2*fma(2, b^2, b) auto commom_power2 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::pow(var_a, 2.0)*var_b); auto commom_power2_cast = graph::multiply_cast(commom_power2); assert(commom_power2_cast.get() && "Expected a multiply node."); assert(commom_power2_cast->get_right()->is_match(var_b) && "Expced b"); assert(commom_power2_cast->get_left()->is_match(graph::pow(var_a, 2.0) * graph::fma(2.0, var_b, 1.0)) && "Expced a^2*fma(2, b, 1)"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) auto commom_power3 = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::fma(graph::pow(var_a, 2.0), var_b, var_c)); auto commom_power3_cast = graph::fma_cast(commom_power3); assert(commom_power3_cast.get() && "Expected a fma node."); assert(commom_power3_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); // fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c) auto commom_power4 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::fma(graph::pow(var_a, 2.0), var_b, var_c)); auto commom_power4_cast = graph::fma_cast(commom_power4); assert(commom_power4_cast.get() && "Expected a fma node."); assert(commom_power4_cast->get_left()->is_match(graph::pow(var_a, 2.0)) && "Expected a^2"); // fma(2,a^2,a) -> a*fma(2,a,1) auto common_power5 = graph::fma(2.0,var_a*var_a,var_a); auto commom_power5_cast = graph::multiply_cast(common_power5); assert(commom_power5_cast.get() && "Expected a multiply node."); assert(commom_power5_cast->get_left()->is_match(graph::fma(2.0,var_a,1.0)) && "Expected fma(2,a,1)."); assert(commom_power5_cast->get_right()->is_match(var_a) && "Expected a."); // fma(2,a,a^2) -> a*(2 + a) auto temp = var_a*var_a; auto common_power6 = graph::fma(2.0,var_a,temp); auto commom_power6_cast = graph::multiply_cast(common_power6); assert(commom_power6_cast.get() && "Expected a multiply node."); assert(commom_power6_cast->get_left()->is_match(2.0 + var_a) && "Expected (2 + a)."); assert(commom_power6_cast->get_right()->is_match(var_a) && "Expected a."); } //------------------------------------------------------------------------------ Loading
graph_tests/efit_test.cpp +2 −0 Original line number Diff line number Diff line Loading @@ -139,6 +139,8 @@ void run_test() { auto bvec = eq->get_magnetic_field(x, y, z); auto ne = eq->get_electron_density(x, y, z); ne->to_latex(); std::cout << std::endl << std::endl; auto te = eq->get_electron_temperature(x, y, z); workflow::manager<T> work(0); Loading