Commit 702f1c6c authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Regroup mixed multiply divides in chained subtract nodes.

parent 0e3776b1
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -618,8 +618,8 @@ namespace graph {
//  (a - b*c) - d*e -> a - (b*c + d*e)
//  (a - b/c) - d/e -> a - (b/c + d/e)
                auto lsrd = divide_cast(ls->get_right());
                if ((multiply_cast(ls->get_right()).get() && rm.get()) ||
                    (divide_cast(ls->get_right()).get() && rd.get())) {
                if ((multiply_cast(ls->get_right()).get() && (rm.get() || rd.get())) ||
                    (divide_cast(ls->get_right()).get()   && (rm.get() || rd.get()))) {
                    return ls->get_left() - (ls->get_right() + this->right);
                }
            }
+25 −13
Original line number Diff line number Diff line
@@ -496,6 +496,31 @@ template<typename T> void test_subtract() {
    auto common_factor9 = two*(var_b*var_c) - var_b*var_a;
    assert(graph::multiply_cast(common_factor9).get() &&
           "Expected multiply node.");

//  (a - b*c) - d*c -> a - (b + d)*c
    auto chained_subtract_multiply = (var_a - var_b*var_c) - var_d*var_c;
    auto chained_subtract_multiply_cast = graph::subtract_cast(chained_subtract_multiply);
    assert(chained_subtract_multiply_cast.get() && "Expected subtract node.");
    assert(graph::multiply_cast(chained_subtract_multiply_cast->get_right()).get() &&
           "Expected a multiply node on the left.");
//  (a - b*c) - c/d -> a - (b*c + c/d)
    auto chained_subtract_multiply2 = (var_a - var_b*var_c) - var_c/var_d;
    auto chained_subtract_multiply_cast2 = graph::subtract_cast(chained_subtract_multiply2);
    assert(chained_subtract_multiply_cast2.get() && "Expected subtract node.");
    assert(graph::fma_cast(chained_subtract_multiply_cast2->get_right()).get() &&
           "Expected a fused multiply add node on the left.");
//  (a - b/c) - d/c -> a - (b + d)*c
    auto chained_subtract_divide = (var_a - var_b/var_c) - var_d/var_c;
    auto chained_subtract_divide_cast = graph::subtract_cast(chained_subtract_divide);
    assert(chained_subtract_divide_cast.get() && "Expected subtract node.");
    assert(graph::divide_cast(chained_subtract_divide_cast->get_right()).get() &&
           "Expected a divide node on the left.");
//  (a - b/c) - d*c -> a - (d*c + b/c)
    auto chained_subtract_divide2 = (var_a - var_b/var_c) - var_d*var_c;
    auto chained_subtract_divide_cast2 = graph::subtract_cast(chained_subtract_divide2);
    assert(chained_subtract_divide_cast2.get() && "Expected subtract node.");
    assert(graph::fma_cast(chained_subtract_divide_cast2->get_right()).get() &&
           "Expected a fused multiply add node on the left.");
}

//------------------------------------------------------------------------------
@@ -1542,19 +1567,6 @@ template<typename T> void test_fma() {
    auto divide_factor4 = graph::fma(var_c/var_b, var_a, var_d/var_b);
    assert(graph::divide_cast(divide_factor4).get() &&
           "Expetced a divide node.");

//  (a - b*c) - d*c -> a - (b + d)*c
    auto chained_subtract_multiply = (var_a - var_b*var_c) - var_d*var_c;
    auto chained_subtract_multiply_cast = graph::subtract_cast(chained_subtract_multiply);
    assert(chained_subtract_multiply_cast.get() && "Expected subtract node.");
    assert(graph::multiply_cast(chained_subtract_multiply_cast->get_right()).get() &&
           "Expected a multiply node on the left.");
//  (a - b/c) - d/c -> a - (b + d)*c
    auto chained_subtract_divide = (var_a - var_b/var_c) - var_d/var_c;
    auto chained_subtract_divide_cast = graph::subtract_cast(chained_subtract_divide);
    assert(chained_subtract_divide_cast.get() && "Expected subtract node.");
    assert(graph::divide_cast(chained_subtract_divide_cast->get_right()).get() &&
           "Expected a divide node on the left.");
}

//------------------------------------------------------------------------------