Commit d20627d2 authored by cianciosa's avatar cianciosa
Browse files

Refactor chained fma reductions and expansion to handle arbitrary length chains.

parent 22ab0b6f
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -2530,6 +2530,7 @@
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lLLVMObjectYAML",
					"-lLLVMAArch64CodeGen",
					"-lclangFrontend",
					"-lclangBasic",
@@ -2630,6 +2631,7 @@
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lLLVMObjectYAML",
					"-lLLVMAArch64CodeGen",
					"-lclangFrontend",
					"-lclangBasic",
+166 −124
Original line number Diff line number Diff line
@@ -1685,6 +1685,112 @@ namespace graph {
    class multiply_node final : public branch_node<T, SAFE_MATH> {
    private:
//------------------------------------------------------------------------------
///  @brief Try to reduce patterns of constant times nested fma nodes.
///
///  c1*fma(c2,x,c3) -> fma(c4,x,c5)
///  c1*fma(fma(c2,x,c3),x,c4) -> fma(fma(c5,x,c6),x,c7)
///  etc...
///
///  The multiplier is the left operand of this multiply node. A level of the
///  chain is only rewritten when its right operand is constant combineable
///  with the multiplier, and the chain must terminate in a left operand that
///  is constant combineable as well. If any level fails these checks nothing
///  is reduced.
///
///  @param[in] trial The fma node to try to reduce.
///  @returns The reduced node or null if it could not reduce the node.
//------------------------------------------------------------------------------
        shared_leaf<T, SAFE_MATH>
        reduce_nested_fma_times_constant(shared_leaf<T, SAFE_MATH> trial) {
            auto temp = fma_cast(trial);
//  Every level must be an fma node ending in a combineable constant term. The
//  check on the right operand is applied at every depth, not only at the
//  innermost level, so a non constant term anywhere in the chain stops the
//  reduction.
            if (!temp.get() ||
                !is_constant_combineable(this->left, temp->get_right())) {
                return null_leaf<T, SAFE_MATH> ();
            }

//  Innermost level: c1*fma(c2,x,c3) -> fma(c1*c2,x,c1*c3)
            if (is_constant_combineable(this->left, temp->get_left())) {
                return fma(this->left*temp->get_left(),
                           temp->get_middle(),
                           this->left*temp->get_right());
            }

//  Otherwise the left operand must itself be a reducible chain.
            auto temp2 = this->reduce_nested_fma_times_constant(temp->get_left());
            if (temp2.get()) {
                return fma(temp2,
                           temp->get_middle(),
                           this->left*temp->get_right());
            }

            return null_leaf<T, SAFE_MATH> ();
        }

//------------------------------------------------------------------------------
///  @brief Try to expand a nested fma node times an add node.
///
///  fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c3*c1 + c2),x,c2*c3)
///
///  This handles the outermost fma node of the chain. Deeper levels are
///  delegated to expand_nested_fma_times_add2.
///
///  @param[in] trial The fma node to try to expand.
///  @param[in] add   The add node to try to expand.
///  @returns The expanded node or null if it could not expand the node.
//------------------------------------------------------------------------------
        shared_leaf<T, SAFE_MATH>
        expand_nested_fma_times_add(shared_leaf<T, SAFE_MATH> trial,
                                    shared_add<T, SAFE_MATH> add) {
            auto outer = fma_cast(trial);
            if (!outer.get()) {
                return null_leaf<T, SAFE_MATH> ();
            }

//  The outer level must share the add node's right operand as its middle and
//  end in a term that combines with the add node's left operand.
            if (!add->get_right()->is_match(outer->get_middle()) ||
                !is_constant_combineable(add->get_left(), outer->get_right())) {
                return null_leaf<T, SAFE_MATH> ();
            }

//  Try the nested chain first.
            auto inner = expand_nested_fma_times_add2(outer->get_left(),
                                                      outer, add);
            if (inner.get()) {
                return fma(inner,
                           add->get_right(),
                           outer->get_right()*add->get_left());
            }

//  Fall back to a single level fma with a combineable left operand.
            if (!is_constant_combineable(add->get_left(), outer->get_left())) {
                return null_leaf<T, SAFE_MATH> ();
            }

            return fma(fma(outer->get_left(),
                           add->get_right(),
                           add->get_left()*outer->get_left() + outer->get_right()),
                       add->get_right(),
                       outer->get_right()*add->get_left());
        }

//------------------------------------------------------------------------------
///  @brief Try to expand the inner levels of a nested fma node.
///
///  For an inner level fma(q,x,t) whose enclosing fma node ends in r and an add
///  node (c + x), this builds the replacement for the enclosing node's left
///  operand
///
///  fma(q,x,t)*(c + x) + r -> fma(q*(c + x) + t,x,c*t + r)
///
///  where q*(c + x) + t is fma(q,x,c*q + t) when q is constant combineable
///  with c and is otherwise produced by recursing on q.
///
///  The rewrite is only valid when this level's middle operand matches the add
///  node's right operand, so that check and the check on this level's right
///  operand are applied before either branch is taken.
///
///  @param[in] trial The fma node to try to expand.
///  @param[in] last  The fma node that encloses trial.
///  @param[in] add   The add node to try to expand.
///  @returns The expanded node or null if it could not expand the node.
//------------------------------------------------------------------------------
        shared_leaf<T, SAFE_MATH>
        expand_nested_fma_times_add2(shared_leaf<T, SAFE_MATH> trial,
                                     shared_leaf<T, SAFE_MATH> last,
                                     shared_add<T, SAFE_MATH> add) {
            auto temp = fma_cast(trial);
            auto temp2 = fma_cast(last);
            assert(temp2.get() && "Assumed a fma node.");
//  These conditions must hold at every level. Without the middle operand check
//  a chain that mixes variables would be rebuilt using only the add node's
//  variable.
            if (temp.get()                                                 &&
                add->get_right()->is_match(temp->get_middle())             &&
                is_constant_combineable(add->get_left(), temp->get_right())) {
//  Innermost level.
                if (is_constant_combineable(add->get_left(), temp->get_left())) {
                    return fma(fma(temp->get_left(),
                                   add->get_right(),
                                   add->get_left()*temp->get_left() +
                                   temp->get_right()),
                               add->get_right(),
                               add->get_left()*temp->get_right() +
                               temp2->get_right());
                }

//  Otherwise the left operand must itself be an expandable chain.
                auto temp3 = expand_nested_fma_times_add2(temp->get_left(),
                                                          temp, add);
                if (temp3.get()) {
                    return fma(temp3,
                               add->get_right(),
                               add->get_left()*temp->get_right() +
                               temp2->get_right());
                }
            }
            return null_leaf<T, SAFE_MATH> ();
        }

//------------------------------------------------------------------------------
///  @brief Convert node pointer to a string.
///
///  @param[in] l Left node pointer.
@@ -1904,31 +2010,13 @@ namespace graph {
                    return (this->left*rm->get_left())*rm->get_right();
                }

                auto rmlfma = fma_cast(rm->get_left());
                if (rmlfma.get()) {
                    if (is_constant_combineable(this->left,
                                                rmlfma->get_left()) &&
                        is_constant_combineable(this->left,
                                                rmlfma->get_right())) {
                        return fma(this->left*rmlfma->get_left(),
                                   rmlfma->get_middle(),
                                   this->left*rmlfma->get_right())*rm->get_right();
                    }

                    auto rmlfmalfma = fma_cast(rmlfma->get_left());
                    if (rmlfmalfma.get()) {
                        if (is_constant_combineable(this->left,
                                                    rmlfmalfma->get_left()) &&
                            is_constant_combineable(this->left,
                                                    rmlfmalfma->get_right()) &&
                            is_constant_combineable(this->left, rmlfma->get_right())) {
                            return fma(fma(this->left*rmlfmalfma->get_left(),
                                           rmlfmalfma->get_middle(),
                                           this->left*rmlfmalfma->get_right()),
                                       rmlfma->get_middle(),
                                       this->left*rmlfma->get_right())*rm->get_right();
                        }
                    }
//  c1*(fma(c2,x,c3)*y)-> fma(c4,x,c5)*y
//  c1*(fma(fma(c2,x,c3),x,c4)*y)-> fma(fma(c5,x,c6),x,c7)*y
//  c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y)-> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y
//  etc...
                auto temp = this->reduce_nested_fma_times_constant(rm->get_left());
                if (temp.get()) {
                    return temp*rm->get_right();
                }
            }

@@ -2326,76 +2414,24 @@ namespace graph {
                }
            }

//  c3*fma(c1,a,c2) -> fma(c4,a,c5)
            auto rfma = fma_cast(this->right);
            if (rfma.get()) {
                if (is_constant_combineable(this->left, rfma->get_left()) &&
                    is_constant_combineable(this->left, rfma->get_right())) {
                    return fma(this->left*rfma->get_left(),
                               rfma->get_middle(),
                               this->left*rfma->get_right());
                }

                auto rfmalfma = fma_cast(rfma->get_left());
                if (rfmalfma.get()) {
                    if (is_constant_combineable(this->left, rfmalfma->get_left())  &&
                        is_constant_combineable(this->left, rfmalfma->get_right()) &&
                        is_constant_combineable(this->left, rfma->get_right())) {
                        return fma(fma(this->left*rfmalfma->get_left(),
                                       rfmalfma->get_middle(),
                                       this->left*rfmalfma->get_right()),
                                   rfma->get_middle(),
                                   this->left*rfma->get_right());
                    }
                    
                    auto rfmalfmalfma = fma_cast(rfmalfma->get_left());
                    if (rfmalfmalfma.get()) {
                        if (is_constant_combineable(this->left, rfmalfmalfma->get_left())  &&
                            is_constant_combineable(this->left, rfmalfmalfma->get_right()) &&
                            is_constant_combineable(this->left, rfmalfma->get_right())     &&
                            is_constant_combineable(this->left, rfma->get_right())) {
                            return fma(fma(fma(this->left*rfmalfmalfma->get_left(),
                                               rfmalfmalfma->get_middle(),
                                               this->left*rfmalfmalfma->get_right()),
                                           rfmalfma->get_middle(),
                                           this->left*rfmalfma->get_right()),
                                       rfma->get_middle(),
                                       this->left*rfma->get_right());
                        }
                    }
                }
//  c1*fma(c2,x,c3) -> fma(c4,x,c5)
//  c1*fma(fma(c2,x,c3),x,c4) -> fma(fma(c5,x,c6),x,c7)
//  c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9)
//  etc...
            auto fma_reduce = this->reduce_nested_fma_times_constant(this->right);
            if (fma_reduce.get()) {
                return fma_reduce;
            }

//  fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5)
            auto lfma = fma_cast(this->left);
            auto ra = add_cast(this->right);
            if (lfma.get() && ra.get()) {
                if (ra->get_right()->is_match(lfma->get_middle())             &&
                    is_constant_combineable(ra->get_left(), lfma->get_left()) &&
                    is_constant_combineable(ra->get_left(), lfma->get_right())) {
                    return fma(fma(lfma->get_left(),
                                   ra->get_right(),
                                   ra->get_left()*lfma->get_left() + lfma->get_right()),
                               ra->get_right(),
                               lfma->get_right()*ra->get_left());
                }

//  fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7)
                auto lfmalfma = fma_cast(lfma->get_left());
                if (ra->get_right()->is_match(lfma->get_middle())                  &&
                    ra->get_right()->is_match(lfmalfma->get_middle())              &&
                    is_constant_combineable(ra->get_left(), lfma->get_right())     &&
                    is_constant_combineable(ra->get_left(), lfmalfma->get_right()) &&
                    is_constant_combineable(ra->get_left(), lfmalfma->get_left())) {
                    return fma(fma(fma(lfmalfma->get_left(),
                                       ra->get_right(),
                                       ra->get_left()*lfmalfma->get_left() +
                                       lfmalfma->get_right()),
                                   ra->get_right(),
                                   ra->get_left()*lfmalfma->get_right() +
                                   lfma->get_right()),
                               ra->get_right(),
                               ra->get_left()*lfma->get_right());
//  etc...
            auto ra = add_cast(this->right);
            if (ra.get()) {
                auto fma_expand = this->expand_nested_fma_times_add(this->left,
                                                                    ra);
                if (fma_expand.get()) {
                    return fma_expand;
                }
            }

@@ -3661,6 +3697,42 @@ namespace graph {
    class fma_node final : public triple_node<T, SAFE_MATH> {
    private:
//------------------------------------------------------------------------------
///  @brief Reduce nested fma nodes.
///
///  fma(fma(c1,a,c2),a - c3,c4) -> fma(fma(c1,a,c2 - c1*c3),a,c4 - c2*c3)
///  etc...
///
///  The subtract node is the middle operand of this node written as a - c3.
///  The left operand of this node must be a chain of fma nodes that all use a
///  as their middle operand.
///
///  @param[in] sub The sub node to try to expand.
///  @returns The reduced node or this node if it could not reduce the node.
///           Callers detect failure by comparing the result against the node
///           the method was called on.
//------------------------------------------------------------------------------
        shared_leaf<T, SAFE_MATH>
        reduce_nested_fma(shared_subtract<T, SAFE_MATH> sub) {
            auto temp = fma_cast(this->left);
            if (temp.get() &&
                temp->get_middle()->is_match(sub->get_left())) {
//  Innermost level.
                if (is_constant_combineable(sub->get_right(), temp->get_left())  &&
                    is_constant_combineable(sub->get_right(), temp->get_right()) &&
                    is_constant_combineable(this->right, temp->get_right())) {
                    return fma(fma(temp->get_left(),
                                   sub->get_left(),
                                   temp->get_right() - temp->get_left()*sub->get_right()),
                               sub->get_left(),
                               this->right - temp->get_right()*sub->get_right());
                }

                if (is_constant_combineable(sub->get_right(), this->right)) {
                    auto temp2 = temp->reduce_nested_fma(sub);
//  A failed reduction returns the node it was called on, which is never null.
//  Compare against that node so an unreduced chain is not wrapped as if it
//  had been reduced.
                    if (temp2.get() && temp2.get() != temp.get()) {
                        return fma(temp2,
                                   sub->get_left(),
                                   this->right - temp->get_right()*sub->get_right());
                    }
                }
            }
            return this->shared_from_this();
        }
            
//------------------------------------------------------------------------------
///  @brief Convert node pointer to a string.
///
///  @param[in] l Left node pointer.
@@ -3787,39 +3859,9 @@ namespace graph {
                               this->right - this->left*ms->get_right());
                }

                auto lfma = fma_cast(this->left);
                if (lfma.get()) {
                    if (is_constant_combineable(ms->get_right(), lfma->get_left())  &&
                        is_constant_combineable(ms->get_right(), lfma->get_right()) &&
                        is_constant_combineable(this->right, lfma->get_right())     &&
                        lfma->get_middle()->is_match(ms->get_left())) {
                        return fma(fma(lfma->get_left(),
                                       ms->get_left(),
                                       lfma->get_right() - lfma->get_left()*ms->get_right()),
                                   ms->get_left(),
                                   this->right - lfma->get_right()*ms->get_right());
                    }

                    auto lfmalfma = fma_cast(lfma->get_left());
                    if (lfmalfma.get()) {
                        if (lfma->get_middle()->is_match(ms->get_left())                    &&
                            lfmalfma->get_middle()->is_match(ms->get_left())                &&
                            is_constant_combineable(ms->get_right(), lfmalfma->get_left())  &&
                            is_constant_combineable(ms->get_right(), lfmalfma->get_right()) &&
                            is_constant_combineable(ms->get_right(), lfma->get_right())     &&
                            is_constant_combineable(ms->get_right(), this->right)) {
                            return fma(fma(fma(lfmalfma->get_left(),
                                               ms->get_left(),
                                               lfmalfma->get_right() -
                                               lfmalfma->get_left()*ms->get_right()),
                                           ms->get_left(),
                                           lfma->get_right() -
                                           lfmalfma->get_right()*ms->get_right()),
                                       ms->get_left(),
                                       this->right -
                                       lfma->get_right()*ms->get_right());
                        }
                    }
                auto temp = this->reduce_nested_fma(ms);
                if (temp.get() != this) {
                    return temp;
                }
            }

+12 −0
Original line number Diff line number Diff line
@@ -335,6 +335,18 @@ namespace graph {
///  Convenience type alias for shared leaf nodes.
    template<jit::float_scalar T, bool SAFE_MATH=false>
    using shared_leaf = std::shared_ptr<leaf_node<T, SAFE_MATH>>;
//------------------------------------------------------------------------------
///  @brief Build an empty leaf handle.
///
///  Used as the "no result" value by helpers that return a shared_leaf.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns A shared_leaf that holds no node.
//------------------------------------------------------------------------------
    template<jit::float_scalar T, bool SAFE_MATH=false>
    constexpr shared_leaf<T, SAFE_MATH> null_leaf() {
        return nullptr;
    }
///  Convenience type alias for a vector of output nodes.
    template<jit::float_scalar T, bool SAFE_MATH=false>
    using output_nodes = std::vector<shared_leaf<T, SAFE_MATH>>;
+57 −0
Original line number Diff line number Diff line
@@ -1967,6 +1967,17 @@ template<jit::float_scalar T> void test_multiply() {
                                        v1,
                                        1.0)) &&
           "Expected fma(fma(fma(2,x,23),x,30,x,1))");

//  c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y) -> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y
    auto consume2 = 10.0*(graph::fma(graph::fma(graph::fma(5.0,v1,0.4),v1,0.3),v1,0.3)*v2);
    assert(consume2->is_match(graph::fma(graph::fma(graph::fma(50.0,
                                                               v1,
                                                               4.0),
                                                    v1,
                                                    3.0),
                                         v1,
                                         3.0)*v2) &&
           "Expected fma(fma(fma(50,x,4),x,3),x,3)*y");
}

//------------------------------------------------------------------------------
@@ -3777,6 +3788,52 @@ template<jit::float_scalar T> void test_fma() {
                                        var_a,
                                        -10.0)) &&
           "Expected fma(fma(fma(2,x,16),x,-10),x,-10)");
/*
//  fma(fma(c1,a,c2),b - c3,fma(c4,a,c5)) -> fma(fma(c6,a,c8),b,fma(c9,a,c10))
    auto gather3 = graph::fma(graph::fma(2.0,
                                         var_a,
                                         20.0),
                              var_b - 2.0,
                              graph::fma(2.0,
                                         var_a,
                                         21.0));
    assert(gather3->is_match(graph::fma(graph::fma(2.0,var_a,20.0),var_b,graph::fma(2.0,var_a,-19.0))) &&
           "Expected fma(fma(2,x,20),y,fma(2,x,-19))");

//  fma(fma(fma(fma(c1,a,c2),a,c3),a,c4),b - c5,fma(fma(fma(c6,a,c7),a,c8),a,c9)) ->
//  fma(fma(fma(fma(c10,a,c11),a,c12),a,c13),b,fma(fma(fma(c14,a,c15),a,c16),a,c17))
    auto gather4 = graph::fma(graph::fma(graph::fma(graph::fma(2.0,
                                                               var_a,
                                                               20.0),
                                                    var_a,
                                                    30.0),
                                         var_a,
                                         50.0),
                              var_b - 2.0,
                              graph::fma(graph::fma(graph::fma(2.0,
                                                               var_a,
                                                               21.0),
                                                    var_a,
                                                    31.0),
                                         var_a,
                                         51.0));
    assert(gather3->is_match(graph::fma(graph::fma(graph::fma(graph::fma(2.0,
                                                                         var_a,
                                                                         20.0),
                                                              var_a,
                                                              30.0),
                                                   var_a,
                                                   50.0),
                                        var_b ,
                                        graph::fma(graph::fma(graph::fma(2.0,
                                                                         var_a,
                                                                         -19.0),
                                                              var_a,
                                                              -29.0),
                                                   var_a,
                                                   -49.0))) &&
           "Expected fma(fma(fma(fma(2,x,20),x,30),x,50),b,fma(fma(fma(2,x,-19),-29),-49)");
 */
}

//------------------------------------------------------------------------------