Commit 22ab0b6f authored by cianciosa's avatar cianciosa
Browse files

Combine and reduce more constants in nested fma nodes. This reduces complexity of splines.

parent f5f4b3a5
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -1363,6 +1363,7 @@
					"-lLLVMAArch64CodeGen",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMObjectYAML",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
@@ -1468,6 +1469,7 @@
					"-lLLVMAArch64CodeGen",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMObjectYAML",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
@@ -1848,6 +1850,7 @@
					"-lLLVMCodeGenData",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMObjectYAML",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
@@ -1947,6 +1950,7 @@
					"-lLLVMCodeGenData",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMObjectYAML",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
+143 −3
Original line number Diff line number Diff line
@@ -1903,6 +1903,33 @@ namespace graph {
                if (is_variable_promotable(rm->get_right(), this->left)) {
                    return (this->left*rm->get_left())*rm->get_right();
                }

                auto rmlfma = fma_cast(rm->get_left());
                if (rmlfma.get()) {
                    if (is_constant_combineable(this->left,
                                                rmlfma->get_left()) &&
                        is_constant_combineable(this->left,
                                                rmlfma->get_right())) {
                        return fma(this->left*rmlfma->get_left(),
                                   rmlfma->get_middle(),
                                   this->left*rmlfma->get_right())*rm->get_right();
                    }

                    auto rmlfmalfma = fma_cast(rmlfma->get_left());
                    if (rmlfmalfma.get()) {
                        if (is_constant_combineable(this->left,
                                                    rmlfmalfma->get_left()) &&
                            is_constant_combineable(this->left,
                                                    rmlfmalfma->get_right()) &&
                            is_constant_combineable(this->left, rmlfma->get_right())) {
                            return fma(fma(this->left*rmlfmalfma->get_left(),
                                           rmlfmalfma->get_middle(),
                                           this->left*rmlfmalfma->get_right()),
                                       rmlfma->get_middle(),
                                       this->left*rmlfma->get_right())*rm->get_right();
                        }
                    }
                }
            }

//  v1*(c*v2) -> c*(v1*v2)
@@ -2299,6 +2326,79 @@ namespace graph {
                }
            }

//  c3*fma(c1,a,c2) -> fma(c4,a,c5)
            auto rfma = fma_cast(this->right);
            if (rfma.get()) {
                if (is_constant_combineable(this->left, rfma->get_left()) &&
                    is_constant_combineable(this->left, rfma->get_right())) {
                    return fma(this->left*rfma->get_left(),
                               rfma->get_middle(),
                               this->left*rfma->get_right());
                }

                auto rfmalfma = fma_cast(rfma->get_left());
                if (rfmalfma.get()) {
                    if (is_constant_combineable(this->left, rfmalfma->get_left())  &&
                        is_constant_combineable(this->left, rfmalfma->get_right()) &&
                        is_constant_combineable(this->left, rfma->get_right())) {
                        return fma(fma(this->left*rfmalfma->get_left(),
                                       rfmalfma->get_middle(),
                                       this->left*rfmalfma->get_right()),
                                   rfma->get_middle(),
                                   this->left*rfma->get_right());
                    }
                    
                    auto rfmalfmalfma = fma_cast(rfmalfma->get_left());
                    if (rfmalfmalfma.get()) {
                        if (is_constant_combineable(this->left, rfmalfmalfma->get_left())  &&
                            is_constant_combineable(this->left, rfmalfmalfma->get_right()) &&
                            is_constant_combineable(this->left, rfmalfma->get_right())     &&
                            is_constant_combineable(this->left, rfma->get_right())) {
                            return fma(fma(fma(this->left*rfmalfmalfma->get_left(),
                                               rfmalfmalfma->get_middle(),
                                               this->left*rfmalfmalfma->get_right()),
                                           rfmalfma->get_middle(),
                                           this->left*rfmalfma->get_right()),
                                       rfma->get_middle(),
                                       this->left*rfma->get_right());
                        }
                    }
                }
            }

//  fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5)
            auto lfma = fma_cast(this->left);
            auto ra = add_cast(this->right);
            if (lfma.get() && ra.get()) {
                if (ra->get_right()->is_match(lfma->get_middle())             &&
                    is_constant_combineable(ra->get_left(), lfma->get_left()) &&
                    is_constant_combineable(ra->get_left(), lfma->get_right())) {
                    return fma(fma(lfma->get_left(),
                                   ra->get_right(),
                                   ra->get_left()*lfma->get_left() + lfma->get_right()),
                               ra->get_right(),
                               lfma->get_right()*ra->get_left());
                }

//  fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7)
                auto lfmalfma = fma_cast(lfma->get_left());
                if (ra->get_right()->is_match(lfma->get_middle())                  &&
                    ra->get_right()->is_match(lfmalfma->get_middle())              &&
                    is_constant_combineable(ra->get_left(), lfma->get_right())     &&
                    is_constant_combineable(ra->get_left(), lfmalfma->get_right()) &&
                    is_constant_combineable(ra->get_left(), lfmalfma->get_left())) {
                    return fma(fma(fma(lfmalfma->get_left(),
                                       ra->get_right(),
                                       ra->get_left()*lfmalfma->get_left() +
                                       lfmalfma->get_right()),
                                   ra->get_right(),
                                   ra->get_left()*lfmalfma->get_right() +
                                   lfma->get_right()),
                               ra->get_right(),
                               ra->get_left()*lfma->get_right());
                }
            }

//  Cases like
//  (c/exp(a))*(exp(b)/d) -> (c/d)*(exp(b)/exp(a))
//  (c/exp(a))*(d/exp(b)) -> (c*d)/(exp(b)*exp(a))
@@ -3673,13 +3773,53 @@ namespace graph {
                }
            }

//  fma(c1,c2 - a,c3) -> c4 - c5*a
//  fma(c1,c2 - a,c3) -> fma(-c1,a,c1*c2 + c3)
//  fma(c1,a - c2,c3) -> fma(c1,a,c3 - c1*c2)
            auto ms = subtract_cast(this->middle);
            if (ms.get()) {
                if (is_constant_combineable(this->left, ms->get_left()) &&
                    is_constant_combineable(this->left, this->right)) {
                    return fma(this->left, ms->get_left(), this->right) -
                           this->left*ms->get_right();
                    return fma(-this->left, ms->get_right(),
                               this->left*ms->get_left() + this->right);
                } else if (is_constant_combineable(this->left, ms->get_right()) &&
                           is_constant_combineable(this->left, this->right)) {
                    return fma(this->left, ms->get_left(),
                               this->right - this->left*ms->get_right());
                }

                auto lfma = fma_cast(this->left);
                if (lfma.get()) {
                    if (is_constant_combineable(ms->get_right(), lfma->get_left())  &&
                        is_constant_combineable(ms->get_right(), lfma->get_right()) &&
                        is_constant_combineable(this->right, lfma->get_right())     &&
                        lfma->get_middle()->is_match(ms->get_left())) {
                        return fma(fma(lfma->get_left(),
                                       ms->get_left(),
                                       lfma->get_right() - lfma->get_left()*ms->get_right()),
                                   ms->get_left(),
                                   this->right - lfma->get_right()*ms->get_right());
                    }

                    auto lfmalfma = fma_cast(lfma->get_left());
                    if (lfmalfma.get()) {
                        if (lfma->get_middle()->is_match(ms->get_left())                    &&
                            lfmalfma->get_middle()->is_match(ms->get_left())                &&
                            is_constant_combineable(ms->get_right(), lfmalfma->get_left())  &&
                            is_constant_combineable(ms->get_right(), lfmalfma->get_right()) &&
                            is_constant_combineable(ms->get_right(), lfma->get_right())     &&
                            is_constant_combineable(ms->get_right(), this->right)) {
                            return fma(fma(fma(lfmalfma->get_left(),
                                               ms->get_left(),
                                               lfmalfma->get_right() -
                                               lfmalfma->get_left()*ms->get_right()),
                                           ms->get_left(),
                                           lfma->get_right() -
                                           lfmalfma->get_right()*ms->get_right()),
                                       ms->get_left(),
                                       this->right -
                                       lfma->get_right()*ms->get_right());
                        }
                    }
                }
            }

+72 −0
Original line number Diff line number Diff line
@@ -1925,6 +1925,48 @@ template<jit::float_scalar T> void test_multiply() {
    assert(gather_power5_cast->get_right()->is_match(graph::pow(var_a*var_b,
                                                                var_c)) &&
           "Expected (a*b)^c.");

//  c3*fma(c1,a,c2) -> fma(c4,a,c5)
    auto constant_reduction = 0.25*fma(2.0, v1, 3.0);
    assert(constant_reduction->is_match(fma(2.0*0.25, v1, 3.0*0.25)) &&
           "Expected (0.5*a + 0.75)");
//  c3*(fma(c1,a,c2)*b) -> fma(c4,a,c5)*b
    auto constant_reduction2 = 0.25*(fma(2.0, v1, 3.0)*v2);
    assert(constant_reduction2->is_match(fma(2.0*0.25, v1, 3.0*0.25)*v2) &&
           "Expected (0.5*a + 0.75)*b");
//  c1*(fma(c2,a,c3)*b + c4) -> fma(c5,a,c6)*b + c7
    auto constant_reduction3 = 0.25*(fma(fma(2.0, v1, 3.0),v2,2.0));
    assert(constant_reduction3->is_match(fma(fma(2.0*0.25, v1, 3.0*0.25),v2,0.5)) &&
           "Expected (0.5*a + 0.75)*b + 0.5");
//  c1*((fma(c2,a,c3)*b + c4)*c) -> (fma(c5,a,c6)*b + c7)*c
    auto constant_reduction4 = 0.25*(fma(fma(2.0, v1, 3.0),v2,2.0)*v1);
    assert(constant_reduction4->is_match(fma(fma(2.0*0.25, v1, 3.0*0.25),v2,0.5)*v1) &&
           "Expected ((0.5*a + 0.75)*b + 0.5)*c");

//  fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5)
    auto expand = graph::fma(0.2, v1, 3.0)*(4.0 + v1);
    assert(expand->is_match(graph::fma(graph::fma(0.2, v1, 3.8), v1, 12.0)));
//  fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7)
    auto expand2 = graph::fma(fma(0.2,v1,2.3),v1,3.0)*(4.0 + v1);
    assert(expand2->is_match(graph::fma(graph::fma(graph::fma(0.2,
                                                              v1,
                                                              0.2*4.0 + 2.3),
                                                   v1,
                                                   12.2),
                                        v1,
                                        12.0)) &&
           "Exptected (((0.2*x + 3.1)*x + 12.2)*x + 12");

//  c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9)
    auto consume = 10.0*(graph::fma(graph::fma(graph::fma(0.2,v1,2.3),v1,3.0),v1,0.1));
    assert(consume->is_match(graph::fma(graph::fma(graph::fma(2.0,
                                                              v1,
                                                              23.0),
                                                   v1,
                                                   30.0),
                                        v1,
                                        1.0)) &&
           "Expected fma(fma(fma(2,x,23),x,30,x,1))");
}

//------------------------------------------------------------------------------
@@ -3705,6 +3747,36 @@ template<jit::float_scalar T> void test_fma() {
           "Expected b.");
    assert(factorize6_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) &&
           "Expected a*c + d.");

//  fma(c1,a - c2,c3) -> fma(c1,a,c4)
    auto consume = graph::fma(2.0,var_a - 3.0,20.0);
    assert(consume->is_match(graph::fma(2.0,var_a,14.0)) &&
           "Expected fma(2,x,14)");
//  fma(c1,c2 - a,c3) -> fma(-c1,a,c4)
    auto consume2 = graph::fma(2.0,3.0 - var_a,20.0);
    assert(consume2->is_match(graph::fma(-2.0,var_a,26.0)) &&
           "Expected fma(-2,x,26)");

//  fma(fma(c1,a,c2),a - c3,c4) -> fma(fma(c1,a,c5),a,c6)
    auto gather = graph::fma(graph::fma(2.0,var_a,20.0),var_a - 2.0,30.0);
    assert(gather->is_match(graph::fma(graph::fma(2.0,var_a,16.0),var_a,-10.0)) &&
           "Expected fma(fma(2,x,16),x,-10)");
//  fma(fma(fma(c1,a,c2),a,c3),a - c4,c5) -> fma(fma(fma(c1,a,c6),a,c7),a,c8)
    auto gather2 = graph::fma(graph::fma(graph::fma(2.0,
                                                    var_a,
                                                    20.0),
                                         var_a,
                                         30.0),
                              var_a - 2.0,
                              50.0);
    assert(gather2->is_match(graph::fma(graph::fma(graph::fma(2.0,
                                                              var_a,
                                                              16.0),
                                                   var_a,
                                                   -10.0),
                                        var_a,
                                        -10.0)) &&
           "Expected fma(fma(fma(2,x,16),x,-10),x,-10)");
}

//------------------------------------------------------------------------------
+3 −3
Original line number Diff line number Diff line
@@ -157,11 +157,11 @@ void run_test() {
                    "Expected a match in bx.");
        check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20,
                    "Expected a match in by.");
        check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 4.0E-12,
        check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 5.0E-12,
                    "Expected a match in bz.");
        check_error(work.check_value(i, ne), gold.ne_grid[i], 8.0E-13,
        check_error(work.check_value(i, ne), gold.ne_grid[i], 2.1E-12,
                    "Expected a match in ne.");
        check_error(work.check_value(i, te), gold.te_grid[i], 8.0E-13,
        check_error(work.check_value(i, te), gold.te_grid[i], 2.1E-12,
                    "Expected a match in te.");
    }
}
+16 −12
Original line number Diff line number Diff line
@@ -241,12 +241,14 @@ template<jit::float_scalar T> void piecewise_1D() {
           "Expected p1 + p3 on the right.");
//  fma(p1,c1 - a,p2) -> fma(-p1,a,p3)
    auto fma_combine2 = fma(p1,1.0 - a,p3);
    auto fma_combine2_cast = graph::subtract_cast(fma_combine2);
    assert(fma_combine2_cast.get() && "Expected an subtract node.");
    assert(fma_combine2_cast->get_right()->is_match(p1*a) &&
           "Expected p1*a on the right.");
    assert(fma_combine2_cast->get_left()->is_match(p1 + p3) &&
           "Expected p1 + p3 on the left.");
    auto fma_combine2_cast = graph::fma_cast(fma_combine2);
    assert(fma_combine2_cast.get() && "Expected a fma node.");
    assert(fma_combine2_cast->get_left()->is_match(-p1) &&
           "Expected -p1 on the left.");
    assert(fma_combine2_cast->get_middle()->is_match(a) &&
           "Expected a in the middle.");
    assert(fma_combine2_cast->get_right()->is_match(p1 + p3) &&
           "Expected p1 + p3 on the right.");
//  p1*(c1 + a) - p2 -> fma(p1,a,p3)
    auto fma_combine3 = p1*(1.0 + a) - p3;
    auto fma_combine3_cast = graph::fma_cast(fma_combine3);
@@ -651,12 +653,14 @@ template<jit::float_scalar T> void piecewise_2D() {
           "Expected p1 + p3 on the right.");
//  fma(p1,c1 - a,p2) -> fma(-p1,a,p3)
    auto fma_combine2 = fma(p1,1.0 - ax,p3);
    auto fma_combine2_cast = graph::subtract_cast(fma_combine2);
    assert(fma_combine2_cast.get() && "Expected an subtract node.");
    assert(fma_combine2_cast->get_right()->is_match(p1*ax) &&
           "Expected p1*a on the right.");
    assert(fma_combine2_cast->get_left()->is_match(p1 + p3) &&
           "Expected p1 + p3 on the left.");
    auto fma_combine2_cast = graph::fma_cast(fma_combine2);
    assert(fma_combine2_cast.get() && "Expected a fma node.");
    assert(fma_combine2_cast->get_left()->is_match(-p1) &&
           "Expected -p1 on the right.");
    assert(fma_combine2_cast->get_middle()->is_match(ax) &&
           "Expected a in the middle.");
    assert(fma_combine2_cast->get_right()->is_match(p1 + p3) &&
           "Expected p1 + p3 on the right.");
//  p1*(c1 + a) - p2 -> fma(p1,a,p3)
    auto fma_combine3 = p1*(1.0 + ax) - p3;
    auto fma_combine3_cast = graph::fma_cast(fma_combine3);
+1 −1

File changed.

Contains only whitespace changes.

Loading