diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 01d2ab34d72a5aa3f278dffbe50fd09f8fc2aabf..429d37d361d8010899d6b6c048576834246736ce 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -520,6 +520,20 @@ namespace graph { pow(prm->get_left(), pl->get_right())); } } + +// cos(x)^2 + sin(x)^2 -> 1 +// sin(x)^2 + cos(x)^2 -> 1 + auto plrc = constant_cast(pl->get_right()); + if (plrc.get() && plrc->is(static_cast (2.0))) { + auto pls = sin_cast(pl->get_left()); + auto prc = cos_cast(pr->get_left()); + auto plc = cos_cast(pl->get_left()); + auto prs = sin_cast(pr->get_left()); + if ((pls.get() && prc.get() && pls->get_arg()->is_match(prc->get_arg())) || + (plc.get() && prs.get() && plc->get_arg()->is_match(prs->get_arg()))) { + return one (); + } + } } // (a/y)^e + b/y^e -> (a^2 + b)/(y^e) diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 3417578940e052bffb84e17b3b0a4ab5f1941a5c..1d25ed429f53c319939edacdfe9349155e39250b 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -433,6 +433,19 @@ template void test_add() { assert(common_power_factor4_cast->get_left()->is_match(var_b*var_b + var_c*var_c) && "Expected b^2 + c^2 on the left."); + +// cos(x)^2 + sin(x)^2 -> 1 + auto trig = graph::cos(var_a)*graph::cos(var_a) + + graph::sin(var_a)*graph::sin(var_a); + auto trig_cast = graph::constant_cast(trig); + assert(trig_cast.get() && "Expected a constant node."); + assert(trig_cast->is(static_cast (1.0)) && "Expected 1."); +// sin(x)^2 + cos(x)^2 -> 1 + auto trig2 = graph::sin(var_a)*graph::sin(var_a) + + graph::cos(var_a)*graph::cos(var_a); + auto trig2_cast = graph::constant_cast(trig2); + assert(trig2_cast.get() && "Expected a constant node."); + assert(trig2_cast->is(static_cast (1.0)) && "Expected 1."); } //------------------------------------------------------------------------------