From 7a485d11cf7fbc9d2f0aba04242e1ea5d60923b8 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 3 Jan 2025 16:57:36 -0500 Subject: [PATCH] Reduced chained constant subtractions. --- graph_framework/arithmetic.hpp | 43 ++++++++++++++++++++++++++++++--- graph_tests/arithmetic_test.cpp | 43 +++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index b124d8b..37d06b7 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -919,6 +919,42 @@ namespace graph { pl2->get_left(), pl2->get_right()); } +// (c1 + a) - c2 -> c3 + a +// c1 - (c2 + a) -> c3 + a + auto la = add_cast(this->left); + if (la.get()) { + if (is_constant_combineable(la->get_left(), this->right)) { + return (la->get_left() - this->right) + la->get_right(); + } + } + auto ra = add_cast(this->right); + if (ra.get()) { + if (is_constant_combineable(this->left, ra->get_left())) { + return (this->left - ra->get_left()) + ra->get_right(); + } + } + +// (c1 - a) - c2 -> c3 - a +// (a - c3) - c2 -> a + c3 + auto ls = subtract_cast(this->left); + if (ls.get()) { + if (is_constant_combineable(ls->get_left(), this->right)) { + return (ls->get_left() - this->right) - ls->get_right(); + } else if (is_constant_combineable(ls->get_right(), + this->right)) { + return -(ls->get_right() + this->right) - ls->get_left(); + } + } +// c1 - (c2 - a) -> c3 - a +// c1 - (a - c2) -> c3 - a + auto rs = subtract_cast(this->right); + if (rs.get()) { + if (is_constant_combineable(this->left, rs->get_left())) { + return (this->left - rs->get_left()) - rs->get_right(); + } else if (is_constant_combineable(this->left, rs->get_right())) { + return (this->left + rs->get_right()) - rs->get_left(); + } + } // Common factor reduction. If the left and right are both muliply nodes check // for a common factor. So you can change a*b - a*c -> a*(b - c). @@ -938,7 +974,7 @@ namespace graph { lm->get_left()*lmra->get_left() - this->right); } } - +// c1*(c2 - a) - c3 -> c4 - c1*a auto lmrs = subtract_cast(lm->get_right()); if (lmrs.get()) { if (is_constant_combineable(lm->get_left(), @@ -1077,7 +1113,6 @@ namespace graph { } // Chained subtraction reductions. - auto ls = subtract_cast(this->left); if (ls.get()) { auto lrm = multiply_cast(ls->get_right()); if (lrm.get() && rm.get()) { @@ -4666,8 +4701,8 @@ namespace graph { this->left->to_latex(); } std::cout << " "; - if (add_cast(this->right).get() || - subtract_cast(this->right).get()) { + if (add_cast(this->middle).get() || + subtract_cast(this->middle).get()) { std::cout << "\\left("; this->middle->to_latex(); std::cout << "\\right)"; diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 6b66d38..90915b9 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -872,6 +872,49 @@ template void test_subtract() { "Expected var_a"); assert(common_var5_cast->get_left()->is_match(2.0/var_a - 3.0/var_b) && "Expected 2/a - 3/b"); + + auto constant_combine = (1.0 - var_a) - 2.0; + auto constant_combine_cast = graph::subtract_cast(constant_combine); + assert(constant_combine_cast.get() && "Expected a subtract node."); + assert(constant_combine_cast->get_left()->evaluate().at(0) == static_cast (-1.0) && + "Expected -1 on the left."); + assert(constant_combine_cast->get_right()->is_match(var_a) && + "Expected a on the right."); + auto constant_combine2 = (1.0 + var_a) - 2.0; + auto constant_combine_cast2 = graph::add_cast(constant_combine2); + assert(constant_combine_cast2.get() && "Expected a add node."); + assert(constant_combine_cast2->get_left()->evaluate().at(0) == static_cast (-1.0) && + "Expected -1 on the left."); + assert(constant_combine_cast2->get_right()->is_match(var_a) && + "Expected a on the right."); + auto constant_combine3 = (var_a - 1.0) - 2.0; + auto constant_combine3_cast = graph::subtract_cast(constant_combine3); + assert(constant_combine3_cast.get() && "Expected a subtract node."); + assert(constant_combine3_cast->get_left()->evaluate().at(0) == static_cast (-3.0) && + "Expected -1 on the left."); + assert(constant_combine3_cast->get_right()->is_match(var_a) && + "Expected a on the right."); + auto constant_combine4 = 2.0 - (1.0 - var_a); + auto constant_combine4_cast = graph::subtract_cast(constant_combine4); + assert(constant_combine4_cast.get() && "Expected a subtract node."); + assert(constant_combine4_cast->get_left()->evaluate().at(0) == static_cast (1.0) && + "Expected 1 on the left."); + assert(constant_combine4_cast->get_right()->is_match(var_a) && + "Expected a on the right."); + auto constant_combine5 = 2.0 - (1.0 + var_a); + auto constant_combine5_cast = graph::add_cast(constant_combine5); + assert(constant_combine5_cast.get() && "Expected an add node."); + assert(constant_combine5_cast->get_left()->evaluate().at(0) == static_cast (1.0) && + "Expected 1 on the left."); + assert(constant_combine5_cast->get_right()->is_match(var_a) && + "Expected a on the right."); + auto constant_combine6 = 2.0 - (var_a - 1.0); + auto constant_combine6_cast = graph::subtract_cast(constant_combine6); + assert(constant_combine6_cast.get() && "Expected a subtract node."); + assert(constant_combine6_cast->get_left()->evaluate().at(0) == static_cast (3.0) && + "Expected 3 on the left."); + assert(constant_combine6_cast->get_right()->is_match(var_a) && + "Expected a on the right."); } //------------------------------------------------------------------------------ -- GitLab