Commit c29e0f38 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Merge branch 'arthamtic_reduction' into 'main'

Reduced chained constant subtractions.

See merge request !47
parents 8475dfe5 7a485d11
Loading
Loading
Loading
Loading
+39 −4
Original line number Diff line number Diff line
@@ -919,6 +919,42 @@ namespace graph {
                                    pl2->get_left(),
                                    pl2->get_right());
            }
// (c1 + a) - c2 -> c3 + a
// c1 - (c2 + a) -> c3 + a
            auto la = add_cast(this->left);
            if (la.get()) {
                if (is_constant_combineable(la->get_left(), this->right)) {
                    return (la->get_left() - this->right) + la->get_right();
                }
            }
            auto ra = add_cast(this->right);
            if (ra.get()) {
                if (is_constant_combineable(this->left, ra->get_left())) {
                    return (this->left - ra->get_left()) + ra->get_right();
                }
            }

// (c1 - a) - c2 -> c3 - a
// (a - c3) - c2 -> a + c3
            auto ls = subtract_cast(this->left);
            if (ls.get()) {
                if (is_constant_combineable(ls->get_left(), this->right)) {
                    return (ls->get_left() - this->right) - ls->get_right();
                } else if (is_constant_combineable(ls->get_right(),
                                                   this->right)) {
                    return -(ls->get_right() + this->right) - ls->get_left();
                }
            }
// c1 - (c2 - a) -> c3 - a
// c1 - (a - c2) -> c3 - a
            auto rs = subtract_cast(this->right);
            if (rs.get()) {
                if (is_constant_combineable(this->left, rs->get_left())) {
                    return (this->left - rs->get_left()) - rs->get_right();
                } else if (is_constant_combineable(this->left, rs->get_right())) {
                    return (this->left + rs->get_right()) - rs->get_left();
                }
            }

//  Common factor reduction. If the left and right are both muliply nodes check
//  for a common factor. So you can change a*b - a*c -> a*(b - c).
@@ -938,7 +974,7 @@ namespace graph {
                                   lm->get_left()*lmra->get_left() - this->right);
                    }
                }

//  c1*(c2 - a) - c3 -> c4 - c1*a
                auto lmrs = subtract_cast(lm->get_right());
                if (lmrs.get()) {
                    if (is_constant_combineable(lm->get_left(),
@@ -1077,7 +1113,6 @@ namespace graph {
            }

//  Chained subtraction reductions.
            auto ls = subtract_cast(this->left);
            if (ls.get()) {
                auto lrm = multiply_cast(ls->get_right());
                if (lrm.get() && rm.get()) {
@@ -4666,8 +4701,8 @@ namespace graph {
                this->left->to_latex();
            }
            std::cout << " ";
            if (add_cast(this->right).get() ||
                subtract_cast(this->right).get()) {
            if (add_cast(this->middle).get() ||
                subtract_cast(this->middle).get()) {
                std::cout << "\\left(";
                this->middle->to_latex();
                std::cout << "\\right)";
+43 −0
Original line number Diff line number Diff line
@@ -872,6 +872,49 @@ template<jit::float_scalar T> void test_subtract() {
           "Expected var_a");
    assert(common_var5_cast->get_left()->is_match(2.0/var_a - 3.0/var_b) &&
           "Expected 2/a - 3/b");

    auto constant_combine = (1.0 - var_a) - 2.0;
    auto constant_combine_cast = graph::subtract_cast(constant_combine);
    assert(constant_combine_cast.get() && "Expected a subtract node.");
    assert(constant_combine_cast->get_left()->evaluate().at(0) == static_cast<T> (-1.0) &&
           "Expected -1 on the left.");
    assert(constant_combine_cast->get_right()->is_match(var_a) &&
           "Expected a on the right.");
    auto constant_combine2 = (1.0 + var_a) - 2.0;
    auto constant_combine_cast2 = graph::add_cast(constant_combine2);
    assert(constant_combine_cast2.get() && "Expected a add node.");
    assert(constant_combine_cast2->get_left()->evaluate().at(0) == static_cast<T> (-1.0) &&
           "Expected -1 on the left.");
    assert(constant_combine_cast2->get_right()->is_match(var_a) &&
           "Expected a on the right.");
    auto constant_combine3 = (var_a - 1.0) - 2.0;
    auto constant_combine3_cast = graph::subtract_cast(constant_combine3);
    assert(constant_combine3_cast.get() && "Expected a subtract node.");
    assert(constant_combine3_cast->get_left()->evaluate().at(0) == static_cast<T> (-3.0) &&
           "Expected -1 on the left.");
    assert(constant_combine3_cast->get_right()->is_match(var_a) &&
           "Expected a on the right.");
    auto constant_combine4 = 2.0 - (1.0 - var_a);
    auto constant_combine4_cast = graph::subtract_cast(constant_combine4);
    assert(constant_combine4_cast.get() && "Expected a subtract node.");
    assert(constant_combine4_cast->get_left()->evaluate().at(0) == static_cast<T> (1.0) &&
           "Expected 1 on the left.");
    assert(constant_combine4_cast->get_right()->is_match(var_a) &&
           "Expected a on the right.");
    auto constant_combine5 = 2.0 - (1.0 + var_a);
    auto constant_combine5_cast = graph::add_cast(constant_combine5);
    assert(constant_combine5_cast.get() && "Expected an add node.");
    assert(constant_combine5_cast->get_left()->evaluate().at(0) == static_cast<T> (1.0) &&
           "Expected 1 on the left.");
    assert(constant_combine5_cast->get_right()->is_match(var_a) &&
           "Expected a on the right.");
    auto constant_combine6 = 2.0 - (var_a - 1.0);
    auto constant_combine6_cast = graph::subtract_cast(constant_combine6);
    assert(constant_combine6_cast.get() && "Expected a subtract node.");
    assert(constant_combine6_cast->get_left()->evaluate().at(0) == static_cast<T> (3.0) &&
           "Expected 3 on the left.");
    assert(constant_combine6_cast->get_right()->is_match(var_a) &&
           "Expected a on the right.");
}

//------------------------------------------------------------------------------