Loading graph_framework/arithmetic.hpp +74 −0 Original line number Diff line number Diff line Loading @@ -2811,6 +2811,80 @@ namespace graph { } } // fma(exp(a)/c, exp(b)*d, e) -> fma(exp(a)*exp(b), d/c, e) // fma(exp(a)/c, d*exp(b), e) -> fma(exp(a)*exp(b), d/c, e) // fma(c/exp(a), exp(b)*d, e) -> fma(exp(b)/exp(a), c*d, e) // fma(c/exp(a), d*exp(b), e) -> fma(exp(b)/exp(a), c*d, e) if (ld.get() && mm.get()) { auto ldle = exp_cast(ld->get_left()); if (ldle.get()) { auto mmle = exp_cast(mm->get_left()); if (mmle.get()) { return fma(ld->get_left()*mm->get_left(), mm->get_right()/ld->get_right(), this->right); } auto mmre = exp_cast(mm->get_right()); if (mmre.get()) { return fma(ld->get_left()*mm->get_right(), mm->get_left()/ld->get_right(), this->right); } } auto ldre = exp_cast(ld->get_right()); if (ldre.get()) { auto mmle = exp_cast(mm->get_left()); if (mmle.get()) { return fma(mm->get_left()/ld->get_right(), ld->get_left()*mm->get_right(), this->right); } auto mmre = exp_cast(mm->get_right()); if (mmre.get()) { return fma(mm->get_right()/ld->get_right(), ld->get_left()*mm->get_left(), this->right); } } } // fma(exp(a)/c, exp(b)/d, e) -> (exp(a)*exp(b))/(c*d) + e // fma(exp(a)/c, d/exp(b), e) -> fma(exp(a)/exp(b), d/c, e) // fma(c/exp(a), exp(b)/d, e) -> fma(exp(b)/exp(a), c/d, e) // fma(c/exp(a), d/exp(b), e) -> (c*d)/(exp(a)*exp(b)) + e if (ld.get() && md.get()) { auto ldle = exp_cast(ld->get_left()); if (ldle.get()) { auto mdle = exp_cast(md->get_left()); if (mdle.get()) { return ((ld->get_left()*md->get_left()) / (ld->get_right()*md->get_right())) + this->right; } auto mdre = exp_cast(md->get_right()); if (mdre.get()) { return fma(ld->get_left()/md->get_right(), md->get_left()/ld->get_right(), this->right); } } auto ldre = exp_cast(ld->get_right()); if (ldre.get()) { auto mdle = exp_cast(md->get_left()); if (mdle.get()) { return fma(md->get_left()/ld->get_right(), ld->get_left()/md->get_right(), this->right); } auto mdre = exp_cast(md->get_right()); if (mdre.get()) { return ((ld->get_left()*md->get_left()) / (ld->get_right()*md->get_right())) + this->right; } } } return this->shared_from_this(); } Loading graph_tests/arithmetic_test.cpp +50 −0 Original line number Diff line number Diff line Loading @@ -2392,6 +2392,56 @@ template<jit::float_scalar T> void test_fma() { assert(fmaexp13_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp13_cast->get_left()).get() && "Expected a exp node on the left."); // fma(exp(a)/c, exp(b)*d, e) -> fma(exp(a + b), d/c, e) auto fmaexp14 = graph::fma(expa/exp_c, expb*exp_d, exp_e); auto fmaexp14_cast = graph::fma_cast(fmaexp14); assert(fmaexp14_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp14_cast->get_left()).get() && "Expected a exp node on the left."); // fma(exp(a)/c, d*exp(b), e) -> fma(exp(a + b), d/c, e) auto fmaexp15 = graph::fma(expa/exp_c, exp_d*expb, exp_e); auto fmaexp15_cast = graph::fma_cast(fmaexp15); assert(fmaexp15_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp15_cast->get_left()).get() && "Expected a exp node on the left."); // fma(c/exp(a), exp(b)*d, e) -> fma(exp(b - a), c*d, e) auto fmaexp16 = graph::fma(exp_c/expa, expb*exp_d, exp_e); auto fmaexp16_cast = graph::fma_cast(fmaexp16); assert(fmaexp16_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp16_cast->get_left()).get() && "Expected a exp node on the left."); // fma(c/exp(a), d*exp(b), e) -> fma(exp(b - a), c*d, e) auto fmaexp17 = graph::fma(exp_c/expa, exp_d*expb, exp_e); auto fmaexp17_cast = graph::fma_cast(fmaexp17); assert(fmaexp17_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp17_cast->get_left()).get() && "Expected a exp node on the left."); // fma(exp(a)/c, exp(b)/d, e) -> exp(a + b)/(c*d) + e auto fmaexp18 = graph::fma(expa/exp_c, expb/exp_d, exp_e); auto fmaexp18_cast = graph::add_cast(fmaexp18); assert(fmaexp18_cast.get() && "Expected an add node."); assert(graph::divide_cast(fmaexp18_cast->get_left()).get() && "Expected a divide node on the left."); // fma(exp(a)/c, d/exp(b), e) -> fma(exp(a - b), d/c, e) auto fmaexp19 = graph::fma(expa/exp_c, exp_d/expb, exp_e); auto fmaexp19_cast = graph::fma_cast(fmaexp19); assert(fmaexp19_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp19_cast->get_left()).get() && "Expected a exp node on the left."); // fma(c/exp(a), exp(b)/d, e) -> fma(exp(b - a), c/d, e) auto fmaexp20 = graph::fma(exp_c/expa, expb/exp_d, exp_e); auto fmaexp20_cast = graph::fma_cast(fmaexp20); assert(fmaexp20_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp20_cast->get_left()).get() && "Expected a exp node on the left."); // fma(c/exp(a), d/exp(b), e) -> (c*d)/exp(a + b) + e auto fmaexp21 = graph::fma(exp_c/expa, exp_d/expb, exp_e); auto fmaexp21_cast = graph::add_cast(fmaexp21); assert(fmaexp21_cast.get() && "Expected an add node."); assert(graph::divide_cast(fmaexp21_cast->get_left()).get() && "Expected a dive node on the left."); } //------------------------------------------------------------------------------ Loading Loading
graph_framework/arithmetic.hpp +74 −0 Original line number Diff line number Diff line Loading @@ -2811,6 +2811,80 @@ namespace graph { } } // fma(exp(a)/c, exp(b)*d, e) -> fma(exp(a)*exp(b), d/c, e) // fma(exp(a)/c, d*exp(b), e) -> fma(exp(a)*exp(b), d/c, e) // fma(c/exp(a), exp(b)*d, e) -> fma(exp(b)/exp(a), c*d, e) // fma(c/exp(a), d*exp(b), e) -> fma(exp(b)/exp(a), c*d, e) if (ld.get() && mm.get()) { auto ldle = exp_cast(ld->get_left()); if (ldle.get()) { auto mmle = exp_cast(mm->get_left()); if (mmle.get()) { return fma(ld->get_left()*mm->get_left(), mm->get_right()/ld->get_right(), this->right); } auto mmre = exp_cast(mm->get_right()); if (mmre.get()) { return fma(ld->get_left()*mm->get_right(), mm->get_left()/ld->get_right(), this->right); } } auto ldre = exp_cast(ld->get_right()); if (ldre.get()) { auto mmle = exp_cast(mm->get_left()); if (mmle.get()) { return fma(mm->get_left()/ld->get_right(), ld->get_left()*mm->get_right(), this->right); } auto mmre = exp_cast(mm->get_right()); if (mmre.get()) { return fma(mm->get_right()/ld->get_right(), ld->get_left()*mm->get_left(), this->right); } } } // fma(exp(a)/c, exp(b)/d, e) -> (exp(a)*exp(b))/(c*d) + e // fma(exp(a)/c, d/exp(b), e) -> fma(exp(a)/exp(b), d/c, e) // fma(c/exp(a), exp(b)/d, e) -> fma(exp(b)/exp(a), c/d, e) // fma(c/exp(a), d/exp(b), e) -> (c*d)/(exp(a)*exp(b)) + e if (ld.get() && md.get()) { auto ldle = exp_cast(ld->get_left()); if (ldle.get()) { auto mdle = exp_cast(md->get_left()); if (mdle.get()) { return ((ld->get_left()*md->get_left()) / (ld->get_right()*md->get_right())) + this->right; } auto mdre = exp_cast(md->get_right()); if (mdre.get()) { return fma(ld->get_left()/md->get_right(), md->get_left()/ld->get_right(), this->right); } } auto ldre = exp_cast(ld->get_right()); if (ldre.get()) { auto mdle = exp_cast(md->get_left()); if (mdle.get()) { return fma(md->get_left()/ld->get_right(), ld->get_left()/md->get_right(), this->right); } auto mdre = exp_cast(md->get_right()); if (mdre.get()) { return ((ld->get_left()*md->get_left()) / (ld->get_right()*md->get_right())) + this->right; } } } return this->shared_from_this(); } Loading
graph_tests/arithmetic_test.cpp +50 −0 Original line number Diff line number Diff line Loading @@ -2392,6 +2392,56 @@ template<jit::float_scalar T> void test_fma() { assert(fmaexp13_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp13_cast->get_left()).get() && "Expected a exp node on the left."); // fma(exp(a)/c, exp(b)*d, e) -> fma(exp(a + b), d/c, e) auto fmaexp14 = graph::fma(expa/exp_c, expb*exp_d, exp_e); auto fmaexp14_cast = graph::fma_cast(fmaexp14); assert(fmaexp14_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp14_cast->get_left()).get() && "Expected a exp node on the left."); // fma(exp(a)/c, d*exp(b), e) -> fma(exp(a + b), d/c, e) auto fmaexp15 = graph::fma(expa/exp_c, exp_d*expb, exp_e); auto fmaexp15_cast = graph::fma_cast(fmaexp15); assert(fmaexp15_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp15_cast->get_left()).get() && "Expected a exp node on the left."); // fma(c/exp(a), exp(b)*d, e) -> fma(exp(b - a), c*d, e) auto fmaexp16 = graph::fma(exp_c/expa, expb*exp_d, exp_e); auto fmaexp16_cast = graph::fma_cast(fmaexp16); assert(fmaexp16_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp16_cast->get_left()).get() && "Expected a exp node on the left."); // fma(c/exp(a), d*exp(b), e) -> fma(exp(b - a), c*d, e) auto fmaexp17 = graph::fma(exp_c/expa, exp_d*expb, exp_e); auto fmaexp17_cast = graph::fma_cast(fmaexp17); assert(fmaexp17_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp17_cast->get_left()).get() && "Expected a exp node on the left."); // fma(exp(a)/c, exp(b)/d, e) -> exp(a + b)/(c*d) + e auto fmaexp18 = graph::fma(expa/exp_c, expb/exp_d, exp_e); auto fmaexp18_cast = graph::add_cast(fmaexp18); assert(fmaexp18_cast.get() && "Expected an add node."); assert(graph::divide_cast(fmaexp18_cast->get_left()).get() && "Expected a divide node on the left."); // fma(exp(a)/c, d/exp(b), e) -> fma(exp(a - b), d/c, e) auto fmaexp19 = graph::fma(expa/exp_c, exp_d/expb, exp_e); auto fmaexp19_cast = graph::fma_cast(fmaexp19); assert(fmaexp19_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp19_cast->get_left()).get() && "Expected a exp node on the left."); // fma(c/exp(a), exp(b)/d, e) -> fma(exp(b - a), c/d, e) auto fmaexp20 = graph::fma(exp_c/expa, expb/exp_d, exp_e); auto fmaexp20_cast = graph::fma_cast(fmaexp20); assert(fmaexp20_cast.get() && "Expected a fma node."); assert(graph::exp_cast(fmaexp20_cast->get_left()).get() && "Expected a exp node on the left."); // fma(c/exp(a), d/exp(b), e) -> (c*d)/exp(a + b) + e auto fmaexp21 = graph::fma(exp_c/expa, exp_d/expb, exp_e); auto fmaexp21_cast = graph::add_cast(fmaexp21); assert(fmaexp21_cast.get() && "Expected an add node."); assert(graph::divide_cast(fmaexp21_cast->get_left()).get() && "Expected a dive node on the left."); } //------------------------------------------------------------------------------ Loading