Loading graph_framework/arithmetic.hpp +39 −4 Original line number Diff line number Diff line Loading @@ -919,6 +919,42 @@ namespace graph { pl2->get_left(), pl2->get_right()); } // (c1 + a) - c2 -> c3 + a // c1 - (c2 + a) -> c3 + a auto la = add_cast(this->left); if (la.get()) { if (is_constant_combineable(la->get_left(), this->right)) { return (la->get_left() - this->right) + la->get_right(); } } auto ra = add_cast(this->right); if (ra.get()) { if (is_constant_combineable(this->left, ra->get_left())) { return (this->left - ra->get_left()) + ra->get_right(); } } // (c1 - a) - c2 -> c3 - a // (a - c3) - c2 -> a + c3 auto ls = subtract_cast(this->left); if (ls.get()) { if (is_constant_combineable(ls->get_left(), this->right)) { return (ls->get_left() - this->right) - ls->get_right(); } else if (is_constant_combineable(ls->get_right(), this->right)) { return -(ls->get_right() + this->right) - ls->get_left(); } } // c1 - (c2 - a) -> c3 - a // c1 - (a - c2) -> c3 - a auto rs = subtract_cast(this->right); if (rs.get()) { if (is_constant_combineable(this->left, rs->get_left())) { return (this->left - rs->get_left()) - rs->get_right(); } else if (is_constant_combineable(this->left, rs->get_right())) { return (this->left + rs->get_right()) - rs->get_left(); } } // Common factor reduction. If the left and right are both muliply nodes check // for a common factor. So you can change a*b - a*c -> a*(b - c). Loading @@ -938,7 +974,7 @@ namespace graph { lm->get_left()*lmra->get_left() - this->right); } } // c1*(c2 - a) - c3 -> c4 - c1*a auto lmrs = subtract_cast(lm->get_right()); if (lmrs.get()) { if (is_constant_combineable(lm->get_left(), Loading Loading @@ -1077,7 +1113,6 @@ namespace graph { } // Chained subtraction reductions. auto ls = subtract_cast(this->left); if (ls.get()) { auto lrm = multiply_cast(ls->get_right()); if (lrm.get() && rm.get()) { Loading Loading @@ -4666,8 +4701,8 @@ namespace graph { this->left->to_latex(); } std::cout << " "; if (add_cast(this->right).get() || subtract_cast(this->right).get()) { if (add_cast(this->middle).get() || subtract_cast(this->middle).get()) { std::cout << "\\left("; this->middle->to_latex(); std::cout << "\\right)"; Loading graph_tests/arithmetic_test.cpp +43 −0 Original line number Diff line number Diff line Loading @@ -872,6 +872,49 @@ template<jit::float_scalar T> void test_subtract() { "Expected var_a"); assert(common_var5_cast->get_left()->is_match(2.0/var_a - 3.0/var_b) && "Expected 2/a - 3/b"); auto constant_combine = (1.0 - var_a) - 2.0; auto constant_combine_cast = graph::subtract_cast(constant_combine); assert(constant_combine_cast.get() && "Expected a subtract node."); assert(constant_combine_cast->get_left()->evaluate().at(0) == static_cast<T> (-1.0) && "Expected -1 on the left."); assert(constant_combine_cast->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine2 = (1.0 + var_a) - 2.0; auto constant_combine_cast2 = graph::add_cast(constant_combine2); assert(constant_combine_cast2.get() && "Expected a add node."); assert(constant_combine_cast2->get_left()->evaluate().at(0) == static_cast<T> (-1.0) && "Expected -1 on the left."); assert(constant_combine_cast2->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine3 = (var_a - 1.0) - 2.0; auto constant_combine3_cast = graph::subtract_cast(constant_combine3); assert(constant_combine3_cast.get() && "Expected a subtract node."); assert(constant_combine3_cast->get_left()->evaluate().at(0) == static_cast<T> (-3.0) && "Expected -1 on the left."); assert(constant_combine3_cast->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine4 = 2.0 - (1.0 - var_a); auto constant_combine4_cast = graph::subtract_cast(constant_combine4); assert(constant_combine4_cast.get() && "Expected a subtract node."); assert(constant_combine4_cast->get_left()->evaluate().at(0) == static_cast<T> (1.0) && "Expected 1 on the left."); assert(constant_combine4_cast->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine5 = 2.0 - (1.0 + var_a); auto constant_combine5_cast = graph::add_cast(constant_combine5); assert(constant_combine5_cast.get() && "Expected an add node."); assert(constant_combine5_cast->get_left()->evaluate().at(0) == static_cast<T> (1.0) && "Expected 1 on the left."); assert(constant_combine5_cast->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine6 = 2.0 - (var_a - 1.0); auto constant_combine6_cast = graph::subtract_cast(constant_combine6); assert(constant_combine6_cast.get() && "Expected a subtract node."); assert(constant_combine6_cast->get_left()->evaluate().at(0) == static_cast<T> (3.0) && "Expected 3 on the left."); assert(constant_combine6_cast->get_right()->is_match(var_a) && "Expected a on the right."); } //------------------------------------------------------------------------------ Loading Loading
graph_framework/arithmetic.hpp +39 −4 Original line number Diff line number Diff line Loading @@ -919,6 +919,42 @@ namespace graph { pl2->get_left(), pl2->get_right()); } // (c1 + a) - c2 -> c3 + a // c1 - (c2 + a) -> c3 + a auto la = add_cast(this->left); if (la.get()) { if (is_constant_combineable(la->get_left(), this->right)) { return (la->get_left() - this->right) + la->get_right(); } } auto ra = add_cast(this->right); if (ra.get()) { if (is_constant_combineable(this->left, ra->get_left())) { return (this->left - ra->get_left()) + ra->get_right(); } } // (c1 - a) - c2 -> c3 - a // (a - c3) - c2 -> a + c3 auto ls = subtract_cast(this->left); if (ls.get()) { if (is_constant_combineable(ls->get_left(), this->right)) { return (ls->get_left() - this->right) - ls->get_right(); } else if (is_constant_combineable(ls->get_right(), this->right)) { return -(ls->get_right() + this->right) - ls->get_left(); } } // c1 - (c2 - a) -> c3 - a // c1 - (a - c2) -> c3 - a auto rs = subtract_cast(this->right); if (rs.get()) { if (is_constant_combineable(this->left, rs->get_left())) { return (this->left - rs->get_left()) - rs->get_right(); } else if (is_constant_combineable(this->left, rs->get_right())) { return (this->left + rs->get_right()) - rs->get_left(); } } // Common factor reduction. If the left and right are both muliply nodes check // for a common factor. So you can change a*b - a*c -> a*(b - c). Loading @@ -938,7 +974,7 @@ namespace graph { lm->get_left()*lmra->get_left() - this->right); } } // c1*(c2 - a) - c3 -> c4 - c1*a auto lmrs = subtract_cast(lm->get_right()); if (lmrs.get()) { if (is_constant_combineable(lm->get_left(), Loading Loading @@ -1077,7 +1113,6 @@ namespace graph { } // Chained subtraction reductions. auto ls = subtract_cast(this->left); if (ls.get()) { auto lrm = multiply_cast(ls->get_right()); if (lrm.get() && rm.get()) { Loading Loading @@ -4666,8 +4701,8 @@ namespace graph { this->left->to_latex(); } std::cout << " "; if (add_cast(this->right).get() || subtract_cast(this->right).get()) { if (add_cast(this->middle).get() || subtract_cast(this->middle).get()) { std::cout << "\\left("; this->middle->to_latex(); std::cout << "\\right)"; Loading
graph_tests/arithmetic_test.cpp +43 −0 Original line number Diff line number Diff line Loading @@ -872,6 +872,49 @@ template<jit::float_scalar T> void test_subtract() { "Expected var_a"); assert(common_var5_cast->get_left()->is_match(2.0/var_a - 3.0/var_b) && "Expected 2/a - 3/b"); auto constant_combine = (1.0 - var_a) - 2.0; auto constant_combine_cast = graph::subtract_cast(constant_combine); assert(constant_combine_cast.get() && "Expected a subtract node."); assert(constant_combine_cast->get_left()->evaluate().at(0) == static_cast<T> (-1.0) && "Expected -1 on the left."); assert(constant_combine_cast->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine2 = (1.0 + var_a) - 2.0; auto constant_combine_cast2 = graph::add_cast(constant_combine2); assert(constant_combine_cast2.get() && "Expected a add node."); assert(constant_combine_cast2->get_left()->evaluate().at(0) == static_cast<T> (-1.0) && "Expected -1 on the left."); assert(constant_combine_cast2->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine3 = (var_a - 1.0) - 2.0; auto constant_combine3_cast = graph::subtract_cast(constant_combine3); assert(constant_combine3_cast.get() && "Expected a subtract node."); assert(constant_combine3_cast->get_left()->evaluate().at(0) == static_cast<T> (-3.0) && "Expected -1 on the left."); assert(constant_combine3_cast->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine4 = 2.0 - (1.0 - var_a); auto constant_combine4_cast = graph::subtract_cast(constant_combine4); assert(constant_combine4_cast.get() && "Expected a subtract node."); assert(constant_combine4_cast->get_left()->evaluate().at(0) == static_cast<T> (1.0) && "Expected 1 on the left."); assert(constant_combine4_cast->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine5 = 2.0 - (1.0 + var_a); auto constant_combine5_cast = graph::add_cast(constant_combine5); assert(constant_combine5_cast.get() && "Expected an add node."); assert(constant_combine5_cast->get_left()->evaluate().at(0) == static_cast<T> (1.0) && "Expected 1 on the left."); assert(constant_combine5_cast->get_right()->is_match(var_a) && "Expected a on the right."); auto constant_combine6 = 2.0 - (var_a - 1.0); auto constant_combine6_cast = graph::subtract_cast(constant_combine6); assert(constant_combine6_cast.get() && "Expected a subtract node."); assert(constant_combine6_cast->get_left()->evaluate().at(0) == static_cast<T> (3.0) && "Expected 3 on the left."); assert(constant_combine6_cast->get_right()->is_match(var_a) && "Expected a on the right."); } //------------------------------------------------------------------------------ Loading