Loading graph_framework.xcodeproj/project.pbxproj +2 −0 Original line number Diff line number Diff line Loading @@ -2530,6 +2530,7 @@ "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lLLVMObjectYAML", "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", Loading Loading @@ -2630,6 +2631,7 @@ "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lLLVMObjectYAML", "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", Loading graph_framework/arithmetic.hpp +166 −124 Original line number Diff line number Diff line Loading @@ -1685,6 +1685,112 @@
namespace graph {
    class multiply_node final : public branch_node<T, SAFE_MATH> {
    private:
//------------------------------------------------------------------------------
/// @brief Try to reduce patterns of constant times nested fma nodes.
///
/// c1*fma(...,x,c2) -> fma(...,x,c3)
///
/// The left operand of this multiply node plays the role of c1. It is pushed
/// through every nesting level of the fma chain found in the left operand of
/// @p trial. A level is only rewritten when its right operand is constant
/// combinable with c1, and the innermost level must also have a constant
/// combinable left operand. Otherwise the chain is left alone.
///
/// @param[in] trial The fma node to try to reduce.
/// @returns The reduced node or null if it could not reduce the node.
//------------------------------------------------------------------------------
        shared_leaf<T, SAFE_MATH> reduce_nested_fma_times_constant(shared_leaf<T, SAFE_MATH> trial) {
            auto temp = fma_cast(trial);

            // Both the terminating and the recursive rewrite multiply this
            // level's right operand by c1, so it has to be combinable in
            // either case.
            if (temp.get() &&
                is_constant_combineable(this->left, temp->get_right())) {
                if (is_constant_combineable(this->left, temp->get_left())) {
                    // Innermost level: c1*fma(c2,x,c3) -> fma(c1*c2,x,c1*c3)
                    return fma(this->left*temp->get_left(),
                               temp->get_middle(),
                               this->left*temp->get_right());
                }

                // Outer level: reduce the nested chain first and only rebuild
                // this level when the whole inner chain could be reduced.
                auto temp2 = reduce_nested_fma_times_constant(temp->get_left());
                if (temp2.get()) {
                    return fma(temp2,
                               temp->get_middle(),
                               this->left*temp->get_right());
                }
            }

            return null_leaf<T, SAFE_MATH> ();
        }

//------------------------------------------------------------------------------
/// @brief Try to expand nested fma node.
///
/// fma(...,x,c2)*(c3 + x) -> fma(...,x,c4)
///
/// @param[in] trial The fma node to try to expand.
/// @param[in] add The add node to try to expand.
/// @returns The expanded node or null if it could not expand the node.
//------------------------------------------------------------------------------ shared_leaf<T, SAFE_MATH> expand_nested_fma_times_add(shared_leaf<T, SAFE_MATH> trial, shared_add<T, SAFE_MATH> add) { auto temp = fma_cast(trial); if (temp.get()) { if (add->get_right()->is_match(temp->get_middle()) && is_constant_combineable(add->get_left(), temp->get_right())) { auto temp2 = expand_nested_fma_times_add2(temp->get_left(), temp, add); if (temp2.get()) { return fma(temp2, add->get_right(), temp->get_right()*add->get_left()); } else if (is_constant_combineable(add->get_left(), temp->get_left())) { return fma(fma(temp->get_left(), add->get_right(), add->get_left()*temp->get_left() + temp->get_right()), add->get_right(), temp->get_right()*add->get_left()); } } } return null_leaf<T, SAFE_MATH> (); } //------------------------------------------------------------------------------ /// @brief Try to expand nested fma node. /// /// fma(...,x,c2)*(c3 + x)* -> fma(...,x,c4) /// /// @param[in] trial The fma node to try to reduce. /// @param[in] last The last fma node. /// @param[in] add The add node to try to expand. /// @returns The expanded node or null if it could not expanded the node. 
//------------------------------------------------------------------------------ shared_leaf<T, SAFE_MATH> expand_nested_fma_times_add2(shared_leaf<T, SAFE_MATH> trial, shared_leaf<T, SAFE_MATH> last, shared_add<T, SAFE_MATH> add) { auto temp = fma_cast(trial); auto temp2 = fma_cast(last); assert(temp2.get() && "Assumed a fma node."); if (temp.get()) { if (add->get_right()->is_match(temp->get_middle()) && is_constant_combineable(add->get_left(), temp->get_left()) && is_constant_combineable(add->get_left(), temp->get_right())) { return fma(fma(temp->get_left(), add->get_right(), add->get_left()*temp->get_left() + temp->get_right()), add->get_right(), add->get_left()*temp->get_right() + temp2->get_right()); } else { auto temp3 = expand_nested_fma_times_add2(temp->get_left(), temp, add); if (temp3.get()) { return fma(temp3, add->get_right(), add->get_left()*temp->get_right() + temp2->get_right()); } } } return null_leaf<T, SAFE_MATH> (); } //------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. /// /// @param[in] l Left node pointer. 
Loading Loading @@ -1904,31 +2010,13 @@ namespace graph { return (this->left*rm->get_left())*rm->get_right(); } auto rmlfma = fma_cast(rm->get_left()); if (rmlfma.get()) { if (is_constant_combineable(this->left, rmlfma->get_left()) && is_constant_combineable(this->left, rmlfma->get_right())) { return fma(this->left*rmlfma->get_left(), rmlfma->get_middle(), this->left*rmlfma->get_right())*rm->get_right(); } auto rmlfmalfma = fma_cast(rmlfma->get_left()); if (rmlfmalfma.get()) { if (is_constant_combineable(this->left, rmlfmalfma->get_left()) && is_constant_combineable(this->left, rmlfmalfma->get_right()) && is_constant_combineable(this->left, rmlfma->get_right())) { return fma(fma(this->left*rmlfmalfma->get_left(), rmlfmalfma->get_middle(), this->left*rmlfmalfma->get_right()), rmlfma->get_middle(), this->left*rmlfma->get_right())*rm->get_right(); } } // c1*(fma(c2,x,c3)*y)-> fma(c4,x,c5)*y // c1*(fma(fma(c2,x,c3),x,c4)*y)-> fma(fma(c5,x,c6),x,c7)*y // c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y)-> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y // etc... 
auto temp = this->reduce_nested_fma_times_constant(rm->get_left()); if (temp.get()) { return temp*rm->get_right(); } } Loading Loading @@ -2326,76 +2414,24 @@ namespace graph { } } // c3*fma(c1,a,c2) -> fma(c4,a,c5) auto rfma = fma_cast(this->right); if (rfma.get()) { if (is_constant_combineable(this->left, rfma->get_left()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(this->left*rfma->get_left(), rfma->get_middle(), this->left*rfma->get_right()); } auto rfmalfma = fma_cast(rfma->get_left()); if (rfmalfma.get()) { if (is_constant_combineable(this->left, rfmalfma->get_left()) && is_constant_combineable(this->left, rfmalfma->get_right()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(fma(this->left*rfmalfma->get_left(), rfmalfma->get_middle(), this->left*rfmalfma->get_right()), rfma->get_middle(), this->left*rfma->get_right()); } auto rfmalfmalfma = fma_cast(rfmalfma->get_left()); if (rfmalfmalfma.get()) { if (is_constant_combineable(this->left, rfmalfmalfma->get_left()) && is_constant_combineable(this->left, rfmalfmalfma->get_right()) && is_constant_combineable(this->left, rfmalfma->get_right()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(fma(fma(this->left*rfmalfmalfma->get_left(), rfmalfmalfma->get_middle(), this->left*rfmalfmalfma->get_right()), rfmalfma->get_middle(), this->left*rfmalfma->get_right()), rfma->get_middle(), this->left*rfma->get_right()); } } } // c1*fma(c2,x,c3) -> fma(c4,x,c5) // c1*fma(fma(c2,x,c3),x,c4) -> fma(fma(c5,x,c6),x,c7) // c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9) // etc... 
auto fma_reduce = this->reduce_nested_fma_times_constant(this->right); if (fma_reduce.get()) { return fma_reduce; } // fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) auto lfma = fma_cast(this->left); auto ra = add_cast(this->right); if (lfma.get() && ra.get()) { if (ra->get_right()->is_match(lfma->get_middle()) && is_constant_combineable(ra->get_left(), lfma->get_left()) && is_constant_combineable(ra->get_left(), lfma->get_right())) { return fma(fma(lfma->get_left(), ra->get_right(), ra->get_left()*lfma->get_left() + lfma->get_right()), ra->get_right(), lfma->get_right()*ra->get_left()); } // fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) auto lfmalfma = fma_cast(lfma->get_left()); if (ra->get_right()->is_match(lfma->get_middle()) && ra->get_right()->is_match(lfmalfma->get_middle()) && is_constant_combineable(ra->get_left(), lfma->get_right()) && is_constant_combineable(ra->get_left(), lfmalfma->get_right()) && is_constant_combineable(ra->get_left(), lfmalfma->get_left())) { return fma(fma(fma(lfmalfma->get_left(), ra->get_right(), ra->get_left()*lfmalfma->get_left() + lfmalfma->get_right()), ra->get_right(), ra->get_left()*lfmalfma->get_right() + lfma->get_right()), ra->get_right(), ra->get_left()*lfma->get_right()); // etc... auto ra = add_cast(this->right); if (ra.get()) { auto fma_expand = this->expand_nested_fma_times_add(this->left, ra); if (fma_expand.get()) { return fma_expand; } } Loading Loading @@ -3661,6 +3697,42 @@ namespace graph { class fma_node final : public triple_node<T, SAFE_MATH> { private: //------------------------------------------------------------------------------ /// @brief Reduced nested fma nodes. /// /// fma(...,a - c1,c2) -> fma(...,a,c3) /// /// @param[in] sub The sub node to try to expand. /// @returns The reduced node or null if it could not reduce the node. 
//------------------------------------------------------------------------------ shared_leaf<T, SAFE_MATH> reduce_nested_fma(shared_subtract<T, SAFE_MATH> sub) { auto temp = fma_cast(this->left); if (temp.get()) { if (is_constant_combineable(sub->get_right(), temp->get_left()) && is_constant_combineable(sub->get_right(), temp->get_right()) && is_constant_combineable(this->right, temp->get_right()) && temp->get_middle()->is_match(sub->get_left())) { return fma(fma(temp->get_left(), sub->get_left(), temp->get_right() - temp->get_left()*sub->get_right()), sub->get_left(), this->right - temp->get_right()*sub->get_right()); } else { if (temp->get_middle()->is_match(sub->get_left()) && is_constant_combineable(sub->get_right(), this->right)) { auto temp2 = temp->reduce_nested_fma(sub); if (temp2.get()) { return fma(temp2, sub->get_left(), this->right - temp->get_right()*sub->get_right()); } } } } return this->shared_from_this(); } //------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. /// /// @param[in] l Left node pointer. 
Loading Loading @@ -3787,39 +3859,9 @@ namespace graph { this->right - this->left*ms->get_right()); } auto lfma = fma_cast(this->left); if (lfma.get()) { if (is_constant_combineable(ms->get_right(), lfma->get_left()) && is_constant_combineable(ms->get_right(), lfma->get_right()) && is_constant_combineable(this->right, lfma->get_right()) && lfma->get_middle()->is_match(ms->get_left())) { return fma(fma(lfma->get_left(), ms->get_left(), lfma->get_right() - lfma->get_left()*ms->get_right()), ms->get_left(), this->right - lfma->get_right()*ms->get_right()); } auto lfmalfma = fma_cast(lfma->get_left()); if (lfmalfma.get()) { if (lfma->get_middle()->is_match(ms->get_left()) && lfmalfma->get_middle()->is_match(ms->get_left()) && is_constant_combineable(ms->get_right(), lfmalfma->get_left()) && is_constant_combineable(ms->get_right(), lfmalfma->get_right()) && is_constant_combineable(ms->get_right(), lfma->get_right()) && is_constant_combineable(ms->get_right(), this->right)) { return fma(fma(fma(lfmalfma->get_left(), ms->get_left(), lfmalfma->get_right() - lfmalfma->get_left()*ms->get_right()), ms->get_left(), lfma->get_right() - lfmalfma->get_right()*ms->get_right()), ms->get_left(), this->right - lfma->get_right()*ms->get_right()); } } auto temp = this->reduce_nested_fma(ms); if (temp.get() != this) { return temp; } } Loading graph_framework/node.hpp +12 −0 Original line number Diff line number Diff line Loading @@ -335,6 +335,18 @@ namespace graph { /// Convenience type alias for shared leaf nodes. template<jit::float_scalar T, bool SAFE_MATH=false> using shared_leaf = std::shared_ptr<leaf_node<T, SAFE_MATH>>; //------------------------------------------------------------------------------ /// @brief Create a null leaf. /// /// @tparam T Base type of the calculation. /// @tparam SAFE_MATH Use safe math operations. /// /// @returns A null leaf. 
//------------------------------------------------------------------------------ template<jit::float_scalar T, bool SAFE_MATH=false> constexpr shared_leaf<T, SAFE_MATH> null_leaf() { return shared_leaf<T, SAFE_MATH> (); } /// Convenience type alias for a vector of output nodes. template<jit::float_scalar T, bool SAFE_MATH=false> using output_nodes = std::vector<shared_leaf<T, SAFE_MATH>>; Loading graph_tests/arithmetic_test.cpp +57 −0 Original line number Diff line number Diff line Loading @@ -1967,6 +1967,17 @@ template<jit::float_scalar T> void test_multiply() { v1, 1.0)) && "Expected fma(fma(fma(2,x,23),x,30,x,1))"); // c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y) -> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y auto consume2 = 10.0*(graph::fma(graph::fma(graph::fma(5.0,v1,0.4),v1,0.3),v1,0.3)*v2); assert(consume2->is_match(graph::fma(graph::fma(graph::fma(50.0, v1, 4.0), v1, 3.0), v1, 3.0)*v2) && "Expected fma(fma(fma(50,x,4),x,3),x,3)*y"); } //------------------------------------------------------------------------------ Loading Loading @@ -3777,6 +3788,52 @@ template<jit::float_scalar T> void test_fma() { var_a, -10.0)) && "Expected fma(fma(fma(2,x,16),x,-10),x,-10)"); /* // fma(fma(c1,a,c2),b - c3,fma(c4,a,c5) -> fma(fma(c6,a,c8),b,fma(c9,a,c10)) auto gather3 = graph::fma(graph::fma(2.0, var_a, 20.0), var_b - 2.0, graph::fma(2.0, var_a, 21.0)); assert(gather3->is_match(graph::fma(graph::fma(2.0,var_a,20.0),var_b,graph::fma(2.0,var_a,-19.0))) && "Expected fma(fma(2,x,20),y,fma(2,x,-19))"); // fma(fma(fma(fma(c1,a,c2),a,c3),a,c4),b - c5,fma(fma(fma(c6,a,c7),a,c8),a,c9)) -> // fma(fma(fma(fma(c10,a,c11),a,c12),a,c13),b,fma(fma(fma(c14,a,c15),a,c16),a,c17)) auto gather4 = graph::fma(graph::fma(graph::fma(graph::fma(2.0, var_a, 20.0), var_a, 30.0), var_a, 50.0), var_b - 2.0, graph::fma(graph::fma(graph::fma(2.0, var_a, 21.0), var_a, 31.0), var_a, 51.0)); assert(gather3->is_match(graph::fma(graph::fma(graph::fma(graph::fma(2.0, var_a, 20.0), var_a, 30.0), var_a, 50.0), var_b , 
graph::fma(graph::fma(graph::fma(2.0, var_a, -19.0), var_a, -29.0), var_a, -49.0))) && "Expected fma(fma(fma(fma(2,x,20),x,30),x,50),b,fma(fma(fma(2,x,-19),-29),-49)"); */ } //------------------------------------------------------------------------------ Loading Loading
graph_framework.xcodeproj/project.pbxproj +2 −0 Original line number Diff line number Diff line Loading @@ -2530,6 +2530,7 @@ "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lLLVMObjectYAML", "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", Loading Loading @@ -2630,6 +2631,7 @@ "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", "-lLLVMObjectYAML", "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", Loading
graph_framework/arithmetic.hpp +166 −124 Original line number Diff line number Diff line Loading @@ -1685,6 +1685,112 @@ namespace graph { class multiply_node final : public branch_node<T, SAFE_MATH> { private: //------------------------------------------------------------------------------ /// @brief Try to reduce paterns of constant times nested fma nodes. /// /// c1*fma(...,x,c2) -> fma(...,x,c3) /// /// @param[in] trial The fma node to try to reduce. /// @returns The reduced node or null if it could not reduce the node. //------------------------------------------------------------------------------ shared_leaf<T, SAFE_MATH> reduce_nested_fma_times_constant(shared_leaf<T, SAFE_MATH> trial) { auto temp = fma_cast(trial); if (temp.get()) { if (is_constant_combineable(this->left, temp->get_left()) && is_constant_combineable(this->left, temp->get_right())) { return fma(this->left*temp->get_left(), temp->get_middle(), this->left*temp->get_right()); } else { auto temp2 = reduce_nested_fma_times_constant(temp->get_left()); if (temp2.get()) { return fma(temp2, temp->get_middle(), this->left*temp->get_right()); } } } return null_leaf<T, SAFE_MATH> (); } //------------------------------------------------------------------------------ /// @brief Try to expand nested fma node. /// /// fma(...,x,c2)*(c3 + x)* -> fma(...,x,c4) /// /// @param[in] trial The fma node to try to expand. /// @param[in] add The add node to try to expand. /// @returns The expanded node or null if it could not expanded the node. 
//------------------------------------------------------------------------------ shared_leaf<T, SAFE_MATH> expand_nested_fma_times_add(shared_leaf<T, SAFE_MATH> trial, shared_add<T, SAFE_MATH> add) { auto temp = fma_cast(trial); if (temp.get()) { if (add->get_right()->is_match(temp->get_middle()) && is_constant_combineable(add->get_left(), temp->get_right())) { auto temp2 = expand_nested_fma_times_add2(temp->get_left(), temp, add); if (temp2.get()) { return fma(temp2, add->get_right(), temp->get_right()*add->get_left()); } else if (is_constant_combineable(add->get_left(), temp->get_left())) { return fma(fma(temp->get_left(), add->get_right(), add->get_left()*temp->get_left() + temp->get_right()), add->get_right(), temp->get_right()*add->get_left()); } } } return null_leaf<T, SAFE_MATH> (); } //------------------------------------------------------------------------------ /// @brief Try to expand nested fma node. /// /// fma(...,x,c2)*(c3 + x)* -> fma(...,x,c4) /// /// @param[in] trial The fma node to try to reduce. /// @param[in] last The last fma node. /// @param[in] add The add node to try to expand. /// @returns The expanded node or null if it could not expanded the node. 
//------------------------------------------------------------------------------ shared_leaf<T, SAFE_MATH> expand_nested_fma_times_add2(shared_leaf<T, SAFE_MATH> trial, shared_leaf<T, SAFE_MATH> last, shared_add<T, SAFE_MATH> add) { auto temp = fma_cast(trial); auto temp2 = fma_cast(last); assert(temp2.get() && "Assumed a fma node."); if (temp.get()) { if (add->get_right()->is_match(temp->get_middle()) && is_constant_combineable(add->get_left(), temp->get_left()) && is_constant_combineable(add->get_left(), temp->get_right())) { return fma(fma(temp->get_left(), add->get_right(), add->get_left()*temp->get_left() + temp->get_right()), add->get_right(), add->get_left()*temp->get_right() + temp2->get_right()); } else { auto temp3 = expand_nested_fma_times_add2(temp->get_left(), temp, add); if (temp3.get()) { return fma(temp3, add->get_right(), add->get_left()*temp->get_right() + temp2->get_right()); } } } return null_leaf<T, SAFE_MATH> (); } //------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. /// /// @param[in] l Left node pointer. 
Loading Loading @@ -1904,31 +2010,13 @@ namespace graph { return (this->left*rm->get_left())*rm->get_right(); } auto rmlfma = fma_cast(rm->get_left()); if (rmlfma.get()) { if (is_constant_combineable(this->left, rmlfma->get_left()) && is_constant_combineable(this->left, rmlfma->get_right())) { return fma(this->left*rmlfma->get_left(), rmlfma->get_middle(), this->left*rmlfma->get_right())*rm->get_right(); } auto rmlfmalfma = fma_cast(rmlfma->get_left()); if (rmlfmalfma.get()) { if (is_constant_combineable(this->left, rmlfmalfma->get_left()) && is_constant_combineable(this->left, rmlfmalfma->get_right()) && is_constant_combineable(this->left, rmlfma->get_right())) { return fma(fma(this->left*rmlfmalfma->get_left(), rmlfmalfma->get_middle(), this->left*rmlfmalfma->get_right()), rmlfma->get_middle(), this->left*rmlfma->get_right())*rm->get_right(); } } // c1*(fma(c2,x,c3)*y)-> fma(c4,x,c5)*y // c1*(fma(fma(c2,x,c3),x,c4)*y)-> fma(fma(c5,x,c6),x,c7)*y // c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y)-> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y // etc... 
auto temp = this->reduce_nested_fma_times_constant(rm->get_left()); if (temp.get()) { return temp*rm->get_right(); } } Loading Loading @@ -2326,76 +2414,24 @@ namespace graph { } } // c3*fma(c1,a,c2) -> fma(c4,a,c5) auto rfma = fma_cast(this->right); if (rfma.get()) { if (is_constant_combineable(this->left, rfma->get_left()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(this->left*rfma->get_left(), rfma->get_middle(), this->left*rfma->get_right()); } auto rfmalfma = fma_cast(rfma->get_left()); if (rfmalfma.get()) { if (is_constant_combineable(this->left, rfmalfma->get_left()) && is_constant_combineable(this->left, rfmalfma->get_right()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(fma(this->left*rfmalfma->get_left(), rfmalfma->get_middle(), this->left*rfmalfma->get_right()), rfma->get_middle(), this->left*rfma->get_right()); } auto rfmalfmalfma = fma_cast(rfmalfma->get_left()); if (rfmalfmalfma.get()) { if (is_constant_combineable(this->left, rfmalfmalfma->get_left()) && is_constant_combineable(this->left, rfmalfmalfma->get_right()) && is_constant_combineable(this->left, rfmalfma->get_right()) && is_constant_combineable(this->left, rfma->get_right())) { return fma(fma(fma(this->left*rfmalfmalfma->get_left(), rfmalfmalfma->get_middle(), this->left*rfmalfmalfma->get_right()), rfmalfma->get_middle(), this->left*rfmalfma->get_right()), rfma->get_middle(), this->left*rfma->get_right()); } } } // c1*fma(c2,x,c3) -> fma(c4,x,c5) // c1*fma(fma(c2,x,c3),x,c4) -> fma(fma(c5,x,c6),x,c7) // c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9) // etc... 
auto fma_reduce = this->reduce_nested_fma_times_constant(this->right); if (fma_reduce.get()) { return fma_reduce; } // fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) auto lfma = fma_cast(this->left); auto ra = add_cast(this->right); if (lfma.get() && ra.get()) { if (ra->get_right()->is_match(lfma->get_middle()) && is_constant_combineable(ra->get_left(), lfma->get_left()) && is_constant_combineable(ra->get_left(), lfma->get_right())) { return fma(fma(lfma->get_left(), ra->get_right(), ra->get_left()*lfma->get_left() + lfma->get_right()), ra->get_right(), lfma->get_right()*ra->get_left()); } // fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) auto lfmalfma = fma_cast(lfma->get_left()); if (ra->get_right()->is_match(lfma->get_middle()) && ra->get_right()->is_match(lfmalfma->get_middle()) && is_constant_combineable(ra->get_left(), lfma->get_right()) && is_constant_combineable(ra->get_left(), lfmalfma->get_right()) && is_constant_combineable(ra->get_left(), lfmalfma->get_left())) { return fma(fma(fma(lfmalfma->get_left(), ra->get_right(), ra->get_left()*lfmalfma->get_left() + lfmalfma->get_right()), ra->get_right(), ra->get_left()*lfmalfma->get_right() + lfma->get_right()), ra->get_right(), ra->get_left()*lfma->get_right()); // etc... auto ra = add_cast(this->right); if (ra.get()) { auto fma_expand = this->expand_nested_fma_times_add(this->left, ra); if (fma_expand.get()) { return fma_expand; } } Loading Loading @@ -3661,6 +3697,42 @@ namespace graph { class fma_node final : public triple_node<T, SAFE_MATH> { private: //------------------------------------------------------------------------------ /// @brief Reduced nested fma nodes. /// /// fma(...,a - c1,c2) -> fma(...,a,c3) /// /// @param[in] sub The sub node to try to expand. /// @returns The reduced node or null if it could not reduce the node. 
//------------------------------------------------------------------------------ shared_leaf<T, SAFE_MATH> reduce_nested_fma(shared_subtract<T, SAFE_MATH> sub) { auto temp = fma_cast(this->left); if (temp.get()) { if (is_constant_combineable(sub->get_right(), temp->get_left()) && is_constant_combineable(sub->get_right(), temp->get_right()) && is_constant_combineable(this->right, temp->get_right()) && temp->get_middle()->is_match(sub->get_left())) { return fma(fma(temp->get_left(), sub->get_left(), temp->get_right() - temp->get_left()*sub->get_right()), sub->get_left(), this->right - temp->get_right()*sub->get_right()); } else { if (temp->get_middle()->is_match(sub->get_left()) && is_constant_combineable(sub->get_right(), this->right)) { auto temp2 = temp->reduce_nested_fma(sub); if (temp2.get()) { return fma(temp2, sub->get_left(), this->right - temp->get_right()*sub->get_right()); } } } } return this->shared_from_this(); } //------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. /// /// @param[in] l Left node pointer. 
Loading Loading @@ -3787,39 +3859,9 @@ namespace graph { this->right - this->left*ms->get_right()); } auto lfma = fma_cast(this->left); if (lfma.get()) { if (is_constant_combineable(ms->get_right(), lfma->get_left()) && is_constant_combineable(ms->get_right(), lfma->get_right()) && is_constant_combineable(this->right, lfma->get_right()) && lfma->get_middle()->is_match(ms->get_left())) { return fma(fma(lfma->get_left(), ms->get_left(), lfma->get_right() - lfma->get_left()*ms->get_right()), ms->get_left(), this->right - lfma->get_right()*ms->get_right()); } auto lfmalfma = fma_cast(lfma->get_left()); if (lfmalfma.get()) { if (lfma->get_middle()->is_match(ms->get_left()) && lfmalfma->get_middle()->is_match(ms->get_left()) && is_constant_combineable(ms->get_right(), lfmalfma->get_left()) && is_constant_combineable(ms->get_right(), lfmalfma->get_right()) && is_constant_combineable(ms->get_right(), lfma->get_right()) && is_constant_combineable(ms->get_right(), this->right)) { return fma(fma(fma(lfmalfma->get_left(), ms->get_left(), lfmalfma->get_right() - lfmalfma->get_left()*ms->get_right()), ms->get_left(), lfma->get_right() - lfmalfma->get_right()*ms->get_right()), ms->get_left(), this->right - lfma->get_right()*ms->get_right()); } } auto temp = this->reduce_nested_fma(ms); if (temp.get() != this) { return temp; } } Loading
graph_framework/node.hpp +12 −0 Original line number Diff line number Diff line Loading @@ -335,6 +335,18 @@
namespace graph {
/// Convenience type alias for shared leaf nodes.
    template<jit::float_scalar T, bool SAFE_MATH=false>
    using shared_leaf = std::shared_ptr<leaf_node<T, SAFE_MATH>>;

//------------------------------------------------------------------------------
/// @brief Create a null leaf.
///
/// Gives call sites a named way to produce an empty shared leaf, for example
/// to signal that no node was produced.
///
/// @tparam T Base type of the calculation.
/// @tparam SAFE_MATH Use safe math operations.
///
/// @returns A null leaf.
//------------------------------------------------------------------------------
    template<jit::float_scalar T, bool SAFE_MATH=false>
    constexpr shared_leaf<T, SAFE_MATH> null_leaf() {
        // An empty shared pointer owns nothing and compares equal to nullptr.
        return nullptr;
    }

/// Convenience type alias for a vector of output nodes.
    template<jit::float_scalar T, bool SAFE_MATH=false>
    using output_nodes = std::vector<shared_leaf<T, SAFE_MATH>>;
Loading
graph_tests/arithmetic_test.cpp +57 −0 Original line number Diff line number Diff line Loading @@ -1967,6 +1967,17 @@ template<jit::float_scalar T> void test_multiply() { v1, 1.0)) && "Expected fma(fma(fma(2,x,23),x,30,x,1))"); // c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y) -> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y auto consume2 = 10.0*(graph::fma(graph::fma(graph::fma(5.0,v1,0.4),v1,0.3),v1,0.3)*v2); assert(consume2->is_match(graph::fma(graph::fma(graph::fma(50.0, v1, 4.0), v1, 3.0), v1, 3.0)*v2) && "Expected fma(fma(fma(50,x,4),x,3),x,3)*y"); } //------------------------------------------------------------------------------ Loading Loading @@ -3777,6 +3788,52 @@ template<jit::float_scalar T> void test_fma() { var_a, -10.0)) && "Expected fma(fma(fma(2,x,16),x,-10),x,-10)"); /* // fma(fma(c1,a,c2),b - c3,fma(c4,a,c5) -> fma(fma(c6,a,c8),b,fma(c9,a,c10)) auto gather3 = graph::fma(graph::fma(2.0, var_a, 20.0), var_b - 2.0, graph::fma(2.0, var_a, 21.0)); assert(gather3->is_match(graph::fma(graph::fma(2.0,var_a,20.0),var_b,graph::fma(2.0,var_a,-19.0))) && "Expected fma(fma(2,x,20),y,fma(2,x,-19))"); // fma(fma(fma(fma(c1,a,c2),a,c3),a,c4),b - c5,fma(fma(fma(c6,a,c7),a,c8),a,c9)) -> // fma(fma(fma(fma(c10,a,c11),a,c12),a,c13),b,fma(fma(fma(c14,a,c15),a,c16),a,c17)) auto gather4 = graph::fma(graph::fma(graph::fma(graph::fma(2.0, var_a, 20.0), var_a, 30.0), var_a, 50.0), var_b - 2.0, graph::fma(graph::fma(graph::fma(2.0, var_a, 21.0), var_a, 31.0), var_a, 51.0)); assert(gather3->is_match(graph::fma(graph::fma(graph::fma(graph::fma(2.0, var_a, 20.0), var_a, 30.0), var_a, 50.0), var_b , graph::fma(graph::fma(graph::fma(2.0, var_a, -19.0), var_a, -29.0), var_a, -49.0))) && "Expected fma(fma(fma(fma(2,x,20),x,30),x,50),b,fma(fma(fma(2,x,-19),-29),-49)"); */ } //------------------------------------------------------------------------------ Loading