Loading graph_framework/arithmetic.hpp +11 −3 Original line number Diff line number Diff line Loading @@ -335,9 +335,13 @@ namespace graph { // c1*a/b + c2*a/d = c3*(a/b + c4*a/d) // a*b/c + d*b/e -> (a/c + d/e)*b // Make sure we prevent combining constants when we just need to factor out a // common term. // c1*a/b + c2*a/d -> (c1/b + c2/d)*a if (ldlm.get() && rdlm.get()) { if (is_constant_combineable(ldlm->get_left(), rdlm->get_left())) { rdlm->get_left()) && !ldlm->get_right()->is_match(rdlm->get_right())) { return (ldlm->get_right()/ld->get_right() + rdlm->get_left()/ldlm->get_left() * rdlm->get_right()/rd->get_right())*ldlm->get_left(); Loading Loading @@ -1134,11 +1138,15 @@ namespace graph { } } // c1*a/b - c2*a/d = c3*(a/b - c4*a/d) // c1*a/b - c2*e/d = c3*(a/b - c4*e/d) // a*b/c - d*b/e -> (a/c - d/e)*b // Make sure we prevent combining constants when we just need to factor out a // common term. // c1*a/b - c2*a/d -> (c1/b - c2/d)*a if (ldlm.get() && rdlm.get()) { if (is_constant_combineable(ldlm->get_left(), rdlm->get_left())) { rdlm->get_left()) && !ldlm->get_right()->is_match(rdlm->get_right())) { return (ldlm->get_right()/ld->get_right() - rdlm->get_left()/ldlm->get_left() * rdlm->get_right()/rd->get_right())*ldlm->get_left(); Loading graph_framework/math.hpp +16 −0 Original line number Diff line number Diff line Loading @@ -959,6 +959,22 @@ namespace graph { return pow(lm->get_left(), this->right) * pow(lm->get_right(), this->right); } // ((Sqrt(a)*b)*c)^d -> a^(d/2)*(b*c)^d // ((b*Sqrt(a))*c)^d -> a^(d/2)*(b*c)^d auto lmlm = multiply_cast(lm->get_left()); if (lmlm.get()) { if (lmlm->get_left()->is_constant() || lmlm->get_right()->is_constant() || sqrt_cast(lmlm->get_left()).get() || sqrt_cast(lmlm->get_right()).get() || pow_cast(lmlm->get_left()).get() || pow_cast(lmlm->get_right()).get()) { return pow(lmlm->get_left(), this->right) * pow(lmlm->get_right(), this->right) * pow(lm->get_right(), this->right); } } } auto ld = divide_cast(this->left); Loading graph_playground/CMakeLists.txt +4 −0 Original line number Diff line number Diff line add_tool_target (xplayground) if (${USE_PCH}) target_precompile_headers (xplayground REUSE_FROM xrays) endif () graph_tests/arithmetic_test.cpp +16 −0 Original line number Diff line number Diff line Loading @@ -384,6 +384,14 @@ template<jit::float_scalar T> void test_add() { "Expected var_a"); assert(common_var4_cast->get_left()->is_match(var_c/var_b + 1.0/var_d) && "Expected c/b + 1/d"); auto common_var5 = 2.0*var_a/var_b + 3.0*var_a/var_c; auto common_var5_cast = graph::multiply_cast(common_var5); assert(common_var5_cast.get() && "Expected a multiply node."); assert(common_var5_cast->get_right()->is_match(var_a) && "Expected var_a"); assert(common_var5_cast->get_left()->is_match(2.0/var_b + 3.0/var_c) && "Expected 2/b + 3/c"); } //------------------------------------------------------------------------------ Loading Loading @@ -855,6 +863,14 @@ template<jit::float_scalar T> void test_subtract() { "Expected var_a"); assert(common_var4_cast->get_left()->is_match(var_c/var_b - 1.0/var_d) && "Expected c/b - 1/d"); auto common_var5 = 2.0*var_c/var_a - 3.0*var_c/var_b; auto common_var5_cast = graph::multiply_cast(common_var5); assert(common_var5_cast.get() && "Expected a multiply node."); assert(common_var5_cast->get_right()->is_match(var_c) && "Expected var_a"); assert(common_var5_cast->get_left()->is_match(2.0/var_a - 3.0/var_b) && "Expected 2/a - 3/b"); } //------------------------------------------------------------------------------ Loading graph_tests/math_test.cpp +8 −1 Original line number Diff line number Diff line Loading @@ -350,7 +350,7 @@ void test_pow() { "Expected (a*b/c)^2."); // (a^b*c/Sqrt(a)d)^2 -> (a^(2*b - 1)*c^2)/d^2 auto var_d = graph::variable<T> (1, "d"); auto var_d = graph::variable<T> (1, ""); auto pow_combine = graph::pow((graph::pow(var_a, var_b)*var_c) / (graph::sqrt(var_a)*var_d), 2.0); auto pow_combine_cast = graph::divide_cast(pow_combine); Loading Loading @@ -402,6 +402,13 @@ void test_pow() { "Expected (b/c)^2."); assert(pow_combine6_cast->get_right()->is_match(graph::pow(expr_a, 4.0)) && "Expected (b/c)^2."); // (Sqrt(a)*b*c)^d -> a^(d/2)*(b*c)^d auto sqrtpow = graph::pow(var_c*var_d*graph::sqrt(var_a), var_b); assert(sqrtpow.get()->is_match(graph::pow(var_a, var_b/2.0) * graph::pow(var_c, var_b) * graph::pow(var_d, var_b)) && "Expected a^(d/2)*b^2*c^d."); } //------------------------------------------------------------------------------ Loading Loading
graph_framework/arithmetic.hpp +11 −3 Original line number Diff line number Diff line Loading @@ -335,9 +335,13 @@ namespace graph { // c1*a/b + c2*a/d = c3*(a/b + c4*a/d) // a*b/c + d*b/e -> (a/c + d/e)*b // Make sure we prevent combining constants when we just need to factor out a // common term. // c1*a/b + c2*a/d -> (c1/b + c2/d)*a if (ldlm.get() && rdlm.get()) { if (is_constant_combineable(ldlm->get_left(), rdlm->get_left())) { rdlm->get_left()) && !ldlm->get_right()->is_match(rdlm->get_right())) { return (ldlm->get_right()/ld->get_right() + rdlm->get_left()/ldlm->get_left() * rdlm->get_right()/rd->get_right())*ldlm->get_left(); Loading Loading @@ -1134,11 +1138,15 @@ namespace graph { } } // c1*a/b - c2*a/d = c3*(a/b - c4*a/d) // c1*a/b - c2*e/d = c3*(a/b - c4*e/d) // a*b/c - d*b/e -> (a/c - d/e)*b // Make sure we prevent combining constants when we just need to factor out a // common term. // c1*a/b - c2*a/d -> (c1/b - c2/d)*a if (ldlm.get() && rdlm.get()) { if (is_constant_combineable(ldlm->get_left(), rdlm->get_left())) { rdlm->get_left()) && !ldlm->get_right()->is_match(rdlm->get_right())) { return (ldlm->get_right()/ld->get_right() - rdlm->get_left()/ldlm->get_left() * rdlm->get_right()/rd->get_right())*ldlm->get_left(); Loading
graph_framework/math.hpp +16 −0 Original line number Diff line number Diff line Loading @@ -959,6 +959,22 @@ namespace graph { return pow(lm->get_left(), this->right) * pow(lm->get_right(), this->right); } // ((Sqrt(a)*b)*c)^d -> a^(d/2)*(b*c)^d // ((b*Sqrt(a))*c)^d -> a^(d/2)*(b*c)^d auto lmlm = multiply_cast(lm->get_left()); if (lmlm.get()) { if (lmlm->get_left()->is_constant() || lmlm->get_right()->is_constant() || sqrt_cast(lmlm->get_left()).get() || sqrt_cast(lmlm->get_right()).get() || pow_cast(lmlm->get_left()).get() || pow_cast(lmlm->get_right()).get()) { return pow(lmlm->get_left(), this->right) * pow(lmlm->get_right(), this->right) * pow(lm->get_right(), this->right); } } } auto ld = divide_cast(this->left); Loading
graph_playground/CMakeLists.txt +4 −0 Original line number Diff line number Diff line add_tool_target (xplayground) if (${USE_PCH}) target_precompile_headers (xplayground REUSE_FROM xrays) endif ()
graph_tests/arithmetic_test.cpp +16 −0 Original line number Diff line number Diff line Loading @@ -384,6 +384,14 @@ template<jit::float_scalar T> void test_add() { "Expected var_a"); assert(common_var4_cast->get_left()->is_match(var_c/var_b + 1.0/var_d) && "Expected c/b + 1/d"); auto common_var5 = 2.0*var_a/var_b + 3.0*var_a/var_c; auto common_var5_cast = graph::multiply_cast(common_var5); assert(common_var5_cast.get() && "Expected a multiply node."); assert(common_var5_cast->get_right()->is_match(var_a) && "Expected var_a"); assert(common_var5_cast->get_left()->is_match(2.0/var_b + 3.0/var_c) && "Expected 2/b + 3/c"); } //------------------------------------------------------------------------------ Loading Loading @@ -855,6 +863,14 @@ template<jit::float_scalar T> void test_subtract() { "Expected var_a"); assert(common_var4_cast->get_left()->is_match(var_c/var_b - 1.0/var_d) && "Expected c/b - 1/d"); auto common_var5 = 2.0*var_c/var_a - 3.0*var_c/var_b; auto common_var5_cast = graph::multiply_cast(common_var5); assert(common_var5_cast.get() && "Expected a multiply node."); assert(common_var5_cast->get_right()->is_match(var_c) && "Expected var_a"); assert(common_var5_cast->get_left()->is_match(2.0/var_a - 3.0/var_b) && "Expected 2/a - 3/b"); } //------------------------------------------------------------------------------ Loading
graph_tests/math_test.cpp +8 −1 Original line number Diff line number Diff line Loading @@ -350,7 +350,7 @@ void test_pow() { "Expected (a*b/c)^2."); // (a^b*c/Sqrt(a)d)^2 -> (a^(2*b - 1)*c^2)/d^2 auto var_d = graph::variable<T> (1, "d"); auto var_d = graph::variable<T> (1, ""); auto pow_combine = graph::pow((graph::pow(var_a, var_b)*var_c) / (graph::sqrt(var_a)*var_d), 2.0); auto pow_combine_cast = graph::divide_cast(pow_combine); Loading Loading @@ -402,6 +402,13 @@ void test_pow() { "Expected (b/c)^2."); assert(pow_combine6_cast->get_right()->is_match(graph::pow(expr_a, 4.0)) && "Expected (b/c)^2."); // (Sqrt(a)*b*c)^d -> a^(d/2)*(b*c)^d auto sqrtpow = graph::pow(var_c*var_d*graph::sqrt(var_a), var_b); assert(sqrtpow.get()->is_match(graph::pow(var_a, var_b/2.0) * graph::pow(var_c, var_b) * graph::pow(var_d, var_b)) && "Expected a^(d/2)*b^2*c^d."); } //------------------------------------------------------------------------------ Loading