Commit 1459f3ef authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Merge branch 'powfix' into 'main'

Add trig idenitry reduction.

See merge request !50
parents 77d02d39 e842004b
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -520,6 +520,20 @@ namespace graph {
                                pow(prm->get_left(), pl->get_right()));
                    }
                }

//  cos(x)^2 + sin(x)^2 -> 1
//  sin(x)^2 + cos(x)^2 -> 1
                auto plrc = constant_cast(pl->get_right());
                if (plrc.get() && plrc->is(static_cast<T> (2.0))) {
                    auto pls = sin_cast(pl->get_left());
                    auto prc = cos_cast(pr->get_left());
                    auto plc = cos_cast(pl->get_left());
                    auto prs = sin_cast(pr->get_left());
                    if ((pls.get() && prc.get() && pls->get_arg()->is_match(prc->get_arg())) ||
                        (plc.get() && prs.get() && plc->get_arg()->is_match(prs->get_arg()))) {
                        return one<T, SAFE_MATH> ();
                    }
                }
            }

//  (a/y)^e + b/y^e -> (a^2 + b)/(y^e)
+13 −0
Original line number Diff line number Diff line
@@ -433,6 +433,19 @@ template<jit::float_scalar T> void test_add() {
    assert(common_power_factor4_cast->get_left()->is_match(var_b*var_b +
                                                           var_c*var_c) &&
           "Expected b^2 + c^2 on the left.");

//  cos(x)^2 + sin(x)^2 -> 1
    auto trig = graph::cos(var_a)*graph::cos(var_a)
              + graph::sin(var_a)*graph::sin(var_a);
    auto trig_cast = graph::constant_cast(trig);
    assert(trig_cast.get() && "Expected a constant node.");
    assert(trig_cast->is(static_cast<T> (1.0)) && "Expected 1.");
//  sin(x)^2 + cos(x)^2 -> 1
    auto trig2 = graph::sin(var_a)*graph::sin(var_a)
               + graph::cos(var_a)*graph::cos(var_a);
    auto trig2_cast = graph::constant_cast(trig2);
    assert(trig2_cast.get() && "Expected a constant node.");
    assert(trig2_cast->is(static_cast<T> (1.0)) && "Expected 1.");
}

//------------------------------------------------------------------------------