Loading graph_framework/arithmetic.hpp +2 −2 Original line number Diff line number Diff line Loading @@ -618,8 +618,8 @@ namespace graph { // (a - b*c) - d*e -> a - (b*c + d*e) // (a - b/c) - d/e -> a - (b/c + d/e) auto lsrd = divide_cast(ls->get_right()); if ((multiply_cast(ls->get_right()).get() && rm.get()) || (divide_cast(ls->get_right()).get() && rd.get())) { if ((multiply_cast(ls->get_right()).get() && (rm.get() || rd.get())) || (divide_cast(ls->get_right()).get() && (rm.get() || rd.get()))) { return ls->get_left() - (ls->get_right() + this->right); } } Loading graph_tests/arithmetic_test.cpp +25 −13 Original line number Diff line number Diff line Loading @@ -496,6 +496,31 @@ template<typename T> void test_subtract() { auto common_factor9 = two*(var_b*var_c) - var_b*var_a; assert(graph::multiply_cast(common_factor9).get() && "Expected multiply node."); // (a - b*c) - d*c -> a - (b + d)*c auto chained_subtract_multiply = (var_a - var_b*var_c) - var_d*var_c; auto chained_subtract_multiply_cast = graph::subtract_cast(chained_subtract_multiply); assert(chained_subtract_multiply_cast.get() && "Expected subtract node."); assert(graph::multiply_cast(chained_subtract_multiply_cast->get_right()).get() && "Expected a multiply node on the left."); // (a - b*c) - c/d -> a - (b*c + c/d) auto chained_subtract_multiply2 = (var_a - var_b*var_c) - var_c/var_d; auto chained_subtract_multiply_cast2 = graph::subtract_cast(chained_subtract_multiply2); assert(chained_subtract_multiply_cast2.get() && "Expected subtract node."); assert(graph::fma_cast(chained_subtract_multiply_cast2->get_right()).get() && "Expected a fused multiply add node on the left."); // (a - b/c) - d/c -> a - (b + d)*c auto chained_subtract_divide = (var_a - var_b/var_c) - var_d/var_c; auto chained_subtract_divide_cast = graph::subtract_cast(chained_subtract_divide); assert(chained_subtract_divide_cast.get() && "Expected subtract node."); assert(graph::divide_cast(chained_subtract_divide_cast->get_right()).get() && "Expected a divide node on the left."); // (a - b/c) - d*c -> a - (d*c + b/c) auto chained_subtract_divide2 = (var_a - var_b/var_c) - var_d*var_c; auto chained_subtract_divide_cast2 = graph::subtract_cast(chained_subtract_divide2); assert(chained_subtract_divide_cast2.get() && "Expected subtract node."); assert(graph::fma_cast(chained_subtract_divide_cast2->get_right()).get() && "Expected a fused multiply add node on the left."); } //------------------------------------------------------------------------------ Loading Loading @@ -1542,19 +1567,6 @@ template<typename T> void test_fma() { auto divide_factor4 = graph::fma(var_c/var_b, var_a, var_d/var_b); assert(graph::divide_cast(divide_factor4).get() && "Expetced a divide node."); // (a - b*c) - d*c -> a - (b + d)*c auto chained_subtract_multiply = (var_a - var_b*var_c) - var_d*var_c; auto chained_subtract_multiply_cast = graph::subtract_cast(chained_subtract_multiply); assert(chained_subtract_multiply_cast.get() && "Expected subtract node."); assert(graph::multiply_cast(chained_subtract_multiply_cast->get_right()).get() && "Expected a multiply node on the left."); // (a - b/c) - d/c -> a - (b + d)*c auto chained_subtract_divide = (var_a - var_b/var_c) - var_d/var_c; auto chained_subtract_divide_cast = graph::subtract_cast(chained_subtract_divide); assert(chained_subtract_divide_cast.get() && "Expected subtract node."); assert(graph::divide_cast(chained_subtract_divide_cast->get_right()).get() && "Expected a divide node on the left."); } //------------------------------------------------------------------------------ Loading Loading
graph_framework/arithmetic.hpp +2 −2 Original line number Diff line number Diff line Loading @@ -618,8 +618,8 @@ namespace graph { // (a - b*c) - d*e -> a - (b*c + d*e) // (a - b/c) - d/e -> a - (b/c + d/e) auto lsrd = divide_cast(ls->get_right()); if ((multiply_cast(ls->get_right()).get() && rm.get()) || (divide_cast(ls->get_right()).get() && rd.get())) { if ((multiply_cast(ls->get_right()).get() && (rm.get() || rd.get())) || (divide_cast(ls->get_right()).get() && (rm.get() || rd.get()))) { return ls->get_left() - (ls->get_right() + this->right); } } Loading
graph_tests/arithmetic_test.cpp +25 −13 Original line number Diff line number Diff line Loading @@ -496,6 +496,31 @@ template<typename T> void test_subtract() { auto common_factor9 = two*(var_b*var_c) - var_b*var_a; assert(graph::multiply_cast(common_factor9).get() && "Expected multiply node."); // (a - b*c) - d*c -> a - (b + d)*c auto chained_subtract_multiply = (var_a - var_b*var_c) - var_d*var_c; auto chained_subtract_multiply_cast = graph::subtract_cast(chained_subtract_multiply); assert(chained_subtract_multiply_cast.get() && "Expected subtract node."); assert(graph::multiply_cast(chained_subtract_multiply_cast->get_right()).get() && "Expected a multiply node on the left."); // (a - b*c) - c/d -> a - (b*c + c/d) auto chained_subtract_multiply2 = (var_a - var_b*var_c) - var_c/var_d; auto chained_subtract_multiply_cast2 = graph::subtract_cast(chained_subtract_multiply2); assert(chained_subtract_multiply_cast2.get() && "Expected subtract node."); assert(graph::fma_cast(chained_subtract_multiply_cast2->get_right()).get() && "Expected a fused multiply add node on the left."); // (a - b/c) - d/c -> a - (b + d)*c auto chained_subtract_divide = (var_a - var_b/var_c) - var_d/var_c; auto chained_subtract_divide_cast = graph::subtract_cast(chained_subtract_divide); assert(chained_subtract_divide_cast.get() && "Expected subtract node."); assert(graph::divide_cast(chained_subtract_divide_cast->get_right()).get() && "Expected a divide node on the left."); // (a - b/c) - d*c -> a - (d*c + b/c) auto chained_subtract_divide2 = (var_a - var_b/var_c) - var_d*var_c; auto chained_subtract_divide_cast2 = graph::subtract_cast(chained_subtract_divide2); assert(chained_subtract_divide_cast2.get() && "Expected subtract node."); assert(graph::fma_cast(chained_subtract_divide_cast2->get_right()).get() && "Expected a fused multiply add node on the left."); } //------------------------------------------------------------------------------ Loading Loading @@ -1542,19 +1567,6 @@ template<typename T> void test_fma() { auto divide_factor4 = graph::fma(var_c/var_b, var_a, var_d/var_b); assert(graph::divide_cast(divide_factor4).get() && "Expetced a divide node."); // (a - b*c) - d*c -> a - (b + d)*c auto chained_subtract_multiply = (var_a - var_b*var_c) - var_d*var_c; auto chained_subtract_multiply_cast = graph::subtract_cast(chained_subtract_multiply); assert(chained_subtract_multiply_cast.get() && "Expected subtract node."); assert(graph::multiply_cast(chained_subtract_multiply_cast->get_right()).get() && "Expected a multiply node on the left."); // (a - b/c) - d/c -> a - (b + d)*c auto chained_subtract_divide = (var_a - var_b/var_c) - var_d/var_c; auto chained_subtract_divide_cast = graph::subtract_cast(chained_subtract_divide); assert(chained_subtract_divide_cast.get() && "Expected subtract node."); assert(graph::divide_cast(chained_subtract_divide_cast->get_right()).get() && "Expected a divide node on the left."); } //------------------------------------------------------------------------------ Loading