Commit f9035323 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add a common factor of divide multiply nodes.

parent 702f1c6c
Loading
Loading
Loading
Loading
+26 −3
Original line number Diff line number Diff line
@@ -102,11 +102,34 @@ namespace graph {
            auto ld = divide_cast(this->left);
            auto rd = divide_cast(this->right);

            if (ld.get() && rd.get() &&
                ld->get_right()->is_match(rd->get_right())) {
            if (ld.get() && rd.get()) {
                if (ld->get_right()->is_match(rd->get_right())) {
                    return (ld->get_left() + rd->get_left())/ld->get_right();
                }

//  (a/(c*b) + d/(e*c)) -> (a/b + d/e)/c
//  (a/(b*c) + d/(e*c)) -> (a/b + d/e)/c
//  (a/(c*b) + d/(c*e)) -> (a/b + d/e)/c
//  (a/(b*c) + d/(c*e)) -> (a/b + d/e)/c
                auto ldrm = multiply_cast(ld->get_right());
                auto rdrm = multiply_cast(rd->get_right());
                if (ldrm.get() && rdrm.get()) {
                    if (ldrm->get_right()->is_match(rdrm->get_right())) {
                        return (ld->get_left()/ldrm->get_left() +
                                rd->get_left()/rdrm->get_left())/ldrm->get_right();
                    } else if (ldrm->get_right()->is_match(rdrm->get_left())) {
                        return (ld->get_left()/ldrm->get_left() +
                                rd->get_left()/rdrm->get_right())/ldrm->get_right();
                    } else if (ldrm->get_left()->is_match(rdrm->get_right())) {
                        return (ld->get_left()/ldrm->get_right() +
                                rd->get_left()/rdrm->get_left())/ldrm->get_left();
                    } else if (ldrm->get_left()->is_match(rdrm->get_left())) {
                        return (ld->get_left()/ldrm->get_right() +
                                rd->get_left()/rdrm->get_right())/ldrm->get_left();
                    }
                }
            }

//  Chained addition reductions.
//  a + (a + b) = fma(2,a,b)
//  a + (b + a) = fma(2,a,b)
+17 −0
Original line number Diff line number Diff line
@@ -238,6 +238,23 @@ template<typename T> void test_add() {
           "Expected var_c in the second slot.");
    assert(graph::add_cast(add_fma_cast->get_right()) &&
           "Expected add_node in the third slot.");
    
//  (a/(b*c) + d/(e*c)) -> (a/b + d/e)/c
    auto muliply_divide_factor = var_a/(var_b*var_c) + var_d/(var_e*var_c);
    auto muliply_divide_factor_cast = divide_cast(muliply_divide_factor);
    assert(muliply_divide_factor_cast.get() && "Expected divide node.");
//  (a/(b*c) + d/(c*e)) -> (a/b + d/e)/c
    auto muliply_divide_factor2 = var_a/(var_b*var_c) + var_d/(var_c*var_e);
    auto muliply_divide_factor_cast2 = divide_cast(muliply_divide_factor2);
    assert(muliply_divide_factor_cast2.get() && "Expected divide node.");
//  (a/(c*b) + d/(e*c)) -> (a/b + d/e)/c
    auto muliply_divide_factor3 = var_a/(var_c*var_b) + var_d/(var_e*var_c);
    auto muliply_divide_factor_cast3 = divide_cast(muliply_divide_factor3);
    assert(muliply_divide_factor_cast3.get() && "Expected divide node.");
//  (a/(c*b) + d/(c*e)) -> (a/b + d/e)/c
    auto muliply_divide_factor4 = var_a/(var_c*var_b) + var_d/(var_c*var_e);
    auto muliply_divide_factor_cast4 = divide_cast(muliply_divide_factor4);
    assert(muliply_divide_factor_cast4.get() && "Expected divide node.");
}

//------------------------------------------------------------------------------
+2 −2

File changed.

Contains only whitespace changes.