Commit a1e0909b authored by cianciosa's avatar cianciosa
Browse files

Disable constant conbinable reductions when a there is a common factor.

parent a227f6bd
Loading
Loading
Loading
Loading
+11 −3
Original line number Diff line number Diff line
@@ -335,9 +335,13 @@ namespace graph {

//  c1*a/b + c2*a/d = c3*(a/b + c4*a/d)
//  a*b/c + d*b/e -> (a/c + d/e)*b
//  Make sure we prevent combining constants when we just need to factor out a
//  common term.
//  c1*a/b + c2*a/d -> (c1/b + c2/d)*a
                if (ldlm.get() && rdlm.get()) {
                    if (is_constant_combineable(ldlm->get_left(),
                                                rdlm->get_left())) {
                                                rdlm->get_left()) &&
                        !ldlm->get_right()->is_match(rdlm->get_right())) {
                        return (ldlm->get_right()/ld->get_right() +
                                rdlm->get_left()/ldlm->get_left() *
                                rdlm->get_right()/rd->get_right())*ldlm->get_left();
@@ -1134,11 +1138,15 @@ namespace graph {
                    }
                }

//  c1*a/b - c2*a/d = c3*(a/b - c4*a/d)
//  c1*a/b - c2*e/d = c3*(a/b - c4*e/d)
//  a*b/c - d*b/e -> (a/c - d/e)*b
//  Make sure we prevent combining constants when we just need to factor out a
//  common term.
//  c1*a/b - c2*a/d -> (c1/b - c2/d)*a
                if (ldlm.get() && rdlm.get()) {
                    if (is_constant_combineable(ldlm->get_left(),
                                                rdlm->get_left())) {
                                                rdlm->get_left()) &&
                        !ldlm->get_right()->is_match(rdlm->get_right())) {
                        return (ldlm->get_right()/ld->get_right() -
                                rdlm->get_left()/ldlm->get_left() *
                                rdlm->get_right()/rd->get_right())*ldlm->get_left();
+16 −0
Original line number Diff line number Diff line
@@ -959,6 +959,22 @@ namespace graph {
                    return pow(lm->get_left(), this->right) *
                           pow(lm->get_right(), this->right);
                }

//  ((Sqrt(a)*b)*c)^d -> a^(d/2)*(b*c)^d
//  ((b*Sqrt(a))*c)^d -> a^(d/2)*(b*c)^d
                auto lmlm = multiply_cast(lm->get_left());
                if (lmlm.get()) {
                    if (lmlm->get_left()->is_constant()    ||
                        lmlm->get_right()->is_constant()   ||
                        sqrt_cast(lmlm->get_left()).get()  ||
                        sqrt_cast(lmlm->get_right()).get() ||
                        pow_cast(lmlm->get_left()).get()   ||
                        pow_cast(lmlm->get_right()).get()) {
                        return pow(lmlm->get_left(), this->right) *
                               pow(lmlm->get_right(), this->right) *
                               pow(lm->get_right(), this->right);
                    }
                }
            }

            auto ld = divide_cast(this->left);
+4 −0
Original line number Diff line number Diff line
add_tool_target (xplayground)

if (${USE_PCH})
    target_precompile_headers (xplayground REUSE_FROM xrays)
endif ()
+16 −0
Original line number Diff line number Diff line
@@ -384,6 +384,14 @@ template<jit::float_scalar T> void test_add() {
           "Expected var_a");
    assert(common_var4_cast->get_left()->is_match(var_c/var_b + 1.0/var_d) &&
           "Expected c/b + 1/d");

    auto common_var5 = 2.0*var_a/var_b + 3.0*var_a/var_c;
    auto common_var5_cast = graph::multiply_cast(common_var5);
    assert(common_var5_cast.get() && "Expected a multiply node.");
    assert(common_var5_cast->get_right()->is_match(var_a) &&
           "Expected var_a");
    assert(common_var5_cast->get_left()->is_match(2.0/var_b + 3.0/var_c) &&
           "Expected 2/b + 3/c");
}

//------------------------------------------------------------------------------
@@ -855,6 +863,14 @@ template<jit::float_scalar T> void test_subtract() {
           "Expected var_a");
    assert(common_var4_cast->get_left()->is_match(var_c/var_b - 1.0/var_d) &&
           "Expected c/b - 1/d");

    auto common_var5 = 2.0*var_c/var_a - 3.0*var_c/var_b;
    auto common_var5_cast = graph::multiply_cast(common_var5);
    assert(common_var5_cast.get() && "Expected a multiply node.");
    assert(common_var5_cast->get_right()->is_match(var_c) &&
           "Expected var_a");
    assert(common_var5_cast->get_left()->is_match(2.0/var_a - 3.0/var_b) &&
           "Expected 2/a - 3/b");
}

//------------------------------------------------------------------------------
+8 −1
Original line number Diff line number Diff line
@@ -350,7 +350,7 @@ void test_pow() {
           "Expected (a*b/c)^2.");

//  (a^b*c/Sqrt(a)d)^2 -> (a^(2*b - 1)*c^2)/d^2
    auto var_d = graph::variable<T> (1, "d");
    auto var_d = graph::variable<T> (1, "");
    auto pow_combine = graph::pow((graph::pow(var_a, var_b)*var_c) /
                                  (graph::sqrt(var_a)*var_d), 2.0);
    auto pow_combine_cast = graph::divide_cast(pow_combine);
@@ -402,6 +402,13 @@ void test_pow() {
           "Expected (b/c)^2.");
    assert(pow_combine6_cast->get_right()->is_match(graph::pow(expr_a, 4.0)) &&
           "Expected (b/c)^2.");

//  (Sqrt(a)*b*c)^d -> a^(d/2)*(b*c)^d
    auto sqrtpow = graph::pow(var_c*var_d*graph::sqrt(var_a), var_b);
    assert(sqrtpow.get()->is_match(graph::pow(var_a, var_b/2.0) *
                                   graph::pow(var_c, var_b) *
                                   graph::pow(var_d, var_b)) &&
           "Expected a^(d/2)*b^2*c^d.");
}

//------------------------------------------------------------------------------
Loading