Commit 739a5e15 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add remaining permutations to fma node reductions of exp.

parent f0f88666
Loading
Loading
Loading
Loading
+74 −0
Original line number Diff line number Diff line
@@ -2811,6 +2811,80 @@ namespace graph {
                }
            }

//  fma(exp(a)/c, exp(b)*d, e) -> fma(exp(a)*exp(b), d/c, e)
//  fma(exp(a)/c, d*exp(b), e) -> fma(exp(a)*exp(b), d/c, e)
//  fma(c/exp(a), exp(b)*d, e) -> fma(exp(b)/exp(a), c*d, e)
//  fma(c/exp(a), d*exp(b), e) -> fma(exp(b)/exp(a), c*d, e)
            if (ld.get() && mm.get()) {
                auto ldle = exp_cast(ld->get_left());
                if (ldle.get()) {
                    auto mmle = exp_cast(mm->get_left());
                    if (mmle.get()) {
                        return fma(ld->get_left()*mm->get_left(),
                                   mm->get_right()/ld->get_right(),
                                   this->right);
                    }
                    auto mmre = exp_cast(mm->get_right());
                    if (mmre.get()) {
                        return fma(ld->get_left()*mm->get_right(),
                                   mm->get_left()/ld->get_right(),
                                   this->right);
                    }
                }
                auto ldre = exp_cast(ld->get_right());
                if (ldre.get()) {
                    auto mmle = exp_cast(mm->get_left());
                    if (mmle.get()) {
                        return fma(mm->get_left()/ld->get_right(),
                                   ld->get_left()*mm->get_right(),
                                   this->right);
                    }
                    auto mmre = exp_cast(mm->get_right());
                    if (mmre.get()) {
                        return fma(mm->get_right()/ld->get_right(),
                                   ld->get_left()*mm->get_left(),
                                   this->right);
                    }
                }
            }

//  fma(exp(a)/c, exp(b)/d, e) -> (exp(a)*exp(b))/(c*d) + e
//  fma(exp(a)/c, d/exp(b), e) -> fma(exp(a)/exp(b), d/c, e)
//  fma(c/exp(a), exp(b)/d, e) -> fma(exp(b)/exp(a), c/d, e)
//  fma(c/exp(a), d/exp(b), e) -> (c*d)/(exp(a)*exp(b)) + e
            if (ld.get() && md.get()) {
                auto ldle = exp_cast(ld->get_left());
                if (ldle.get()) {
                    auto mdle = exp_cast(md->get_left());
                    if (mdle.get()) {
                        return ((ld->get_left()*md->get_left()) /
                                (ld->get_right()*md->get_right())) +
                               this->right;
                    }
                    auto mdre = exp_cast(md->get_right());
                    if (mdre.get()) {
                        return fma(ld->get_left()/md->get_right(),
                                   md->get_left()/ld->get_right(),
                                   this->right);
                    }
                }
                auto ldre = exp_cast(ld->get_right());
                if (ldre.get()) {
                    auto mdle = exp_cast(md->get_left());
                    if (mdle.get()) {
                        return fma(md->get_left()/ld->get_right(),
                                   ld->get_left()/md->get_right(),
                                   this->right);
                    }
                    auto mdre = exp_cast(md->get_right());
                    if (mdre.get()) {
                        return ((ld->get_left()*md->get_left()) /
                                (ld->get_right()*md->get_right())) +
                               this->right;
                    }
                }
            }

            return this->shared_from_this();
        }

+50 −0
Original line number Diff line number Diff line
@@ -2392,6 +2392,56 @@ template<jit::float_scalar T> void test_fma() {
    assert(fmaexp13_cast.get() && "Expected a fma node.");
    assert(graph::exp_cast(fmaexp13_cast->get_left()).get() &&
           "Expected a exp node on the left.");

//  fma(exp(a)/c, exp(b)*d, e) -> fma(exp(a + b), d/c, e)
    auto fmaexp14 = graph::fma(expa/exp_c, expb*exp_d, exp_e);
    auto fmaexp14_cast = graph::fma_cast(fmaexp14);
    assert(fmaexp14_cast.get() && "Expected a fma node.");
    assert(graph::exp_cast(fmaexp14_cast->get_left()).get() &&
           "Expected a exp node on the left.");
//  fma(exp(a)/c, d*exp(b), e) -> fma(exp(a + b), d/c, e)
    auto fmaexp15 = graph::fma(expa/exp_c, exp_d*expb, exp_e);
    auto fmaexp15_cast = graph::fma_cast(fmaexp15);
    assert(fmaexp15_cast.get() && "Expected a fma node.");
    assert(graph::exp_cast(fmaexp15_cast->get_left()).get() &&
           "Expected a exp node on the left.");
//  fma(c/exp(a), exp(b)*d, e) -> fma(exp(b - a), c*d, e)
    auto fmaexp16 = graph::fma(exp_c/expa, expb*exp_d, exp_e);
    auto fmaexp16_cast = graph::fma_cast(fmaexp16);
    assert(fmaexp16_cast.get() && "Expected a fma node.");
    assert(graph::exp_cast(fmaexp16_cast->get_left()).get() &&
           "Expected a exp node on the left.");
//  fma(c/exp(a), d*exp(b), e) -> fma(exp(b - a), c*d, e)
    auto fmaexp17 = graph::fma(exp_c/expa, exp_d*expb, exp_e);
    auto fmaexp17_cast = graph::fma_cast(fmaexp17);
    assert(fmaexp17_cast.get() && "Expected a fma node.");
    assert(graph::exp_cast(fmaexp17_cast->get_left()).get() &&
           "Expected a exp node on the left.");

//  fma(exp(a)/c, exp(b)/d, e) -> exp(a + b)/(c*d) + e
    auto fmaexp18 = graph::fma(expa/exp_c, expb/exp_d, exp_e);
    auto fmaexp18_cast = graph::add_cast(fmaexp18);
    assert(fmaexp18_cast.get() && "Expected an add node.");
    assert(graph::divide_cast(fmaexp18_cast->get_left()).get() &&
           "Expected a divide node on the left.");
//  fma(exp(a)/c, d/exp(b), e) -> fma(exp(a - b), d/c, e)
    auto fmaexp19 = graph::fma(expa/exp_c, exp_d/expb, exp_e);
    auto fmaexp19_cast = graph::fma_cast(fmaexp19);
    assert(fmaexp19_cast.get() && "Expected a fma node.");
    assert(graph::exp_cast(fmaexp19_cast->get_left()).get() &&
           "Expected a exp node on the left.");
//  fma(c/exp(a), exp(b)/d, e) -> fma(exp(b - a), c/d, e)
    auto fmaexp20 = graph::fma(exp_c/expa, expb/exp_d, exp_e);
    auto fmaexp20_cast = graph::fma_cast(fmaexp20);
    assert(fmaexp20_cast.get() && "Expected a fma node.");
    assert(graph::exp_cast(fmaexp20_cast->get_left()).get() &&
           "Expected a exp node on the left.");
//  fma(c/exp(a), d/exp(b), e) -> (c*d)/exp(a + b) + e
    auto fmaexp21 = graph::fma(exp_c/expa, exp_d/expb, exp_e);
    auto fmaexp21_cast = graph::add_cast(fmaexp21);
    assert(fmaexp21_cast.get() && "Expected an add node.");
    assert(graph::divide_cast(fmaexp21_cast->get_left()).get() &&
           "Expected a dive node on the left.");
}

//------------------------------------------------------------------------------