Loading graph_framework/arithmetic.hpp +14 −0 Original line number Diff line number Diff line Loading @@ -520,6 +520,20 @@ namespace graph { pow(prm->get_left(), pl->get_right())); } } // cos(x)^2 + sin(x)^2 -> 1 // sin(x)^2 + cos(x)^2 -> 1 auto plrc = constant_cast(pl->get_right()); if (plrc.get() && plrc->is(static_cast<T> (2.0))) { auto pls = sin_cast(pl->get_left()); auto prc = cos_cast(pr->get_left()); auto plc = cos_cast(pl->get_left()); auto prs = sin_cast(pr->get_left()); if ((pls.get() && prc.get() && pls->get_arg()->is_match(prc->get_arg())) || (plc.get() && prs.get() && plc->get_arg()->is_match(prs->get_arg()))) { return one<T, SAFE_MATH> (); } } } // (a/y)^e + b/y^e -> (a^2 + b)/(y^e) Loading graph_tests/arithmetic_test.cpp +13 −0 Original line number Diff line number Diff line Loading @@ -433,6 +433,19 @@ template<jit::float_scalar T> void test_add() { assert(common_power_factor4_cast->get_left()->is_match(var_b*var_b + var_c*var_c) && "Expected b^2 + c^2 on the left."); // cos(x)^2 + sin(x)^2 -> 1 auto trig = graph::cos(var_a)*graph::cos(var_a) + graph::sin(var_a)*graph::sin(var_a); auto trig_cast = graph::constant_cast(trig); assert(trig_cast.get() && "Expected a constant node."); assert(trig_cast->is(static_cast<T> (1.0)) && "Expected 1."); // sin(x)^2 + cos(x)^2 -> 1 auto trig2 = graph::sin(var_a)*graph::sin(var_a) + graph::cos(var_a)*graph::cos(var_a); auto trig2_cast = graph::constant_cast(trig2); assert(trig2_cast.get() && "Expected a constant node."); assert(trig2_cast->is(static_cast<T> (1.0)) && "Expected 1."); } //------------------------------------------------------------------------------ Loading Loading
graph_framework/arithmetic.hpp +14 −0 Original line number Diff line number Diff line Loading @@ -520,6 +520,20 @@ namespace graph { pow(prm->get_left(), pl->get_right())); } } // cos(x)^2 + sin(x)^2 -> 1 // sin(x)^2 + cos(x)^2 -> 1 auto plrc = constant_cast(pl->get_right()); if (plrc.get() && plrc->is(static_cast<T> (2.0))) { auto pls = sin_cast(pl->get_left()); auto prc = cos_cast(pr->get_left()); auto plc = cos_cast(pl->get_left()); auto prs = sin_cast(pr->get_left()); if ((pls.get() && prc.get() && pls->get_arg()->is_match(prc->get_arg())) || (plc.get() && prs.get() && plc->get_arg()->is_match(prs->get_arg()))) { return one<T, SAFE_MATH> (); } } } // (a/y)^e + b/y^e -> (a^2 + b)/(y^e) Loading
graph_tests/arithmetic_test.cpp +13 −0 Original line number Diff line number Diff line Loading @@ -433,6 +433,19 @@ template<jit::float_scalar T> void test_add() { assert(common_power_factor4_cast->get_left()->is_match(var_b*var_b + var_c*var_c) && "Expected b^2 + c^2 on the left."); // cos(x)^2 + sin(x)^2 -> 1 auto trig = graph::cos(var_a)*graph::cos(var_a) + graph::sin(var_a)*graph::sin(var_a); auto trig_cast = graph::constant_cast(trig); assert(trig_cast.get() && "Expected a constant node."); assert(trig_cast->is(static_cast<T> (1.0)) && "Expected 1."); // sin(x)^2 + cos(x)^2 -> 1 auto trig2 = graph::sin(var_a)*graph::sin(var_a) + graph::cos(var_a)*graph::cos(var_a); auto trig2_cast = graph::constant_cast(trig2); assert(trig2_cast.get() && "Expected a constant node."); assert(trig2_cast->is(static_cast<T> (1.0)) && "Expected 1."); } //------------------------------------------------------------------------------ Loading