Commit 596cf102 authored by cianciosa's avatar cianciosa
Browse files

Avoid creating new piecewise constants by prefactoring the scale and offsets out.

parent d20627d2
Loading
Loading
Loading
Loading
+64 −56
Original line number Diff line number Diff line
@@ -193,9 +193,11 @@ namespace graph {
            auto pr1 = piecewise_1D_cast(this->right);

            if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
                return piecewise_1D(this->evaluate(), pl1->get_arg(),
                                    pl1->get_scale(), pl1->get_offset());
            } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
                return piecewise_1D(this->evaluate(), pr1->get_arg(),
                                    pr1->get_scale(), pr1->get_offset());
            }

            auto pl2 = piecewise_2D_cast(this->left);
@@ -204,13 +206,13 @@ namespace graph {
            if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            }

//  Combine 2D and 1D piecewise constants if a row or column matches.
@@ -219,29 +221,29 @@ namespace graph {
                result.add_row(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pr2.get() && pr2->is_col_match(this->left)) {
                backend::buffer<T> result = pl1->evaluate();
                result.add_col(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pl2.get() && pl2->is_row_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.add_row(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pl2.get() && pl2->is_col_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.add_col(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            }

//  Idenity reductions.
@@ -916,9 +918,11 @@ namespace graph {
            auto pr1 = piecewise_1D_cast(this->right);

            if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
                return piecewise_1D(this->evaluate(), pl1->get_arg(),
                                    pl1->get_scale(), pl1->get_offset());
            } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
                return piecewise_1D(this->evaluate(), pr1->get_arg(),
                                    pr1->get_scale(), pr1->get_offset());
            }

            auto pl2 = piecewise_2D_cast(this->left);
@@ -927,13 +931,13 @@ namespace graph {
            if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            }

//  Combine 2D and 1D piecewise constants if a row or column matches.
@@ -942,29 +946,29 @@ namespace graph {
                result.subtract_row(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pr2.get() && pr2->is_col_match(this->left)) {
                backend::buffer<T> result = pl1->evaluate();
                result.subtract_col(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pl2.get() && pl2->is_row_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.subtract_row(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pl2.get() && pl2->is_col_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.subtract_col(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            }
// (c1 + a) - c2 -> c3 + a
// c1 - (c2 + a) -> c3 + a
@@ -1866,9 +1870,11 @@ namespace graph {
            auto pr1 = piecewise_1D_cast(this->right);

            if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
                return piecewise_1D(this->evaluate(), pl1->get_arg(),
                                    pl1->get_scale(), pl1->get_offset());
            } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
                return piecewise_1D(this->evaluate(), pr1->get_arg(),
                                    pr1->get_scale(), pr1->get_offset());
            }

            auto pl2 = piecewise_2D_cast(this->left);
@@ -1877,13 +1883,13 @@ namespace graph {
            if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            }

//  Combine 2D and 1D piecewise constants if a row or column matches.
@@ -1892,29 +1898,29 @@ namespace graph {
                result.multiply_row(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pr2.get() && pr2->is_col_match(this->left)) {
                backend::buffer<T> result = pl1->evaluate();
                result.multiply_col(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pl2.get() && pl2->is_row_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.multiply_row(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pl2.get() && pl2->is_col_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.multiply_col(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            }

//  Move constants to the left.
@@ -2798,9 +2804,11 @@ namespace graph {
            auto pr1 = piecewise_1D_cast(this->right);

            if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
                return piecewise_1D(this->evaluate(), pl1->get_arg(),
                                    pl1->get_scale(), pl1->get_offset());
            } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
                return piecewise_1D(this->evaluate(), pr1->get_arg(),
                                    pr1->get_scale(), pr1->get_offset());
            }

            auto pl2 = piecewise_2D_cast(this->left);
@@ -2809,13 +2817,13 @@ namespace graph {
            if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            }

//  Combine 2D and 1D piecewise constants if a row or column matches.
@@ -2824,29 +2832,29 @@ namespace graph {
                result.divide_row(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pr2.get() && pr2->is_col_match(this->left)) {
                backend::buffer<T> result = pl1->evaluate();
                result.divide_col(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pl2.get() && pl2->is_row_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.divide_row(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pl2.get() && pl2->is_col_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.divide_col(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            }

            if (this->left->is_match(this->right)) {
+143 −119

File changed.

Preview size limit exceeded, changes collapsed.

+36 −26
Original line number Diff line number Diff line
@@ -75,15 +75,17 @@ namespace graph {
            auto ap1 = piecewise_1D_cast(this->arg);
            if (ap1.get()) {
                return piecewise_1D(this->evaluate(),
                                    ap1->get_arg());
                                    ap1->get_arg(),
                                    ap1->get_scale(),
                                    ap1->get_offset());
            }

            auto ap2 = piecewise_2D_cast(this->arg);
            if (ap2.get()) {
                return piecewise_2D(this->evaluate(),
                                    ap2->get_num_columns(),
                                    ap2->get_left(),
                                    ap2->get_right());
                                    ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(),
                                    ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset());
            }

//  Handle casses like sqrt(c*x) where c is constant or cases like
@@ -371,15 +373,17 @@ namespace graph {
            auto ap1 = piecewise_1D_cast(this->arg);
            if (ap1.get()) {
                return piecewise_1D(this->evaluate(),
                                    ap1->get_arg());
                                    ap1->get_arg(),
                                    ap1->get_scale(),
                                    ap1->get_offset());
            }

            auto ap2 = piecewise_2D_cast(this->arg);
            if (ap2.get()) {
                return piecewise_2D(this->evaluate(),
                                    ap2->get_num_columns(),
                                    ap2->get_left(),
                                    ap2->get_right());
                                    ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(),
                                    ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset());
            }

//  Reduce exp(log(x)) -> x
@@ -638,15 +642,17 @@ namespace graph {
            auto ap1 = piecewise_1D_cast(this->arg);
            if (ap1.get()) {
                return piecewise_1D(this->evaluate(),
                                    ap1->get_arg());
                                    ap1->get_arg(),
                                    ap1->get_scale(),
                                    ap1->get_offset());
            }

            auto ap2 = piecewise_2D_cast(this->arg);
            if (ap2.get()) {
                return piecewise_2D(this->evaluate(),
                                    ap2->get_num_columns(),
                                    ap2->get_left(),
                                    ap2->get_right());
                                    ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(),
                                    ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset());
            }

//  Reduce log(exp(x)) -> x
@@ -900,9 +906,11 @@ namespace graph {
            auto pl1 = piecewise_1D_cast(this->left);
            auto pr1 = piecewise_1D_cast(this->right);
            if (pl1.get() && (rc.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
                return piecewise_1D(this->evaluate(), pl1->get_arg(),
                                    pl1->get_scale(), pl1->get_offset());
            } else if (pr1.get() && (lc.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
                return piecewise_1D(this->evaluate(), pr1->get_arg(),
                                    pr1->get_scale(), pr1->get_offset());
            }
            
            auto pl2 = piecewise_2D_cast(this->left);
@@ -910,13 +918,13 @@ namespace graph {
            if (pl2.get() && (rc.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pr2.get() && (lc.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            }

//  Combine 2D and 1D piecewise constants if a row or column matches.
@@ -925,29 +933,29 @@ namespace graph {
                result.pow_row(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pr2.get() && pr2->is_col_match(this->left)) {
                backend::buffer<T> result = pl1->evaluate();
                result.pow_col(pr2->evaluate());
                return piecewise_2D(result,
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
                                    pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
                                    pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
            } else if (pl2.get() && pl2->is_row_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.pow_row(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            } else if (pl2.get() && pl2->is_col_match(this->right)) {
                backend::buffer<T> result = pl2->evaluate();
                result.pow_col(pr1->evaluate());
                return piecewise_2D(result,
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
                                    pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
                                    pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
            }

            auto lp = pow_cast(this->left);
@@ -1472,15 +1480,17 @@ namespace graph {
            auto ap1 = piecewise_1D_cast(this->arg);
            if (ap1.get()) {
                return piecewise_1D(this->evaluate(),
                                    ap1->get_arg());
                                    ap1->get_arg(),
                                    ap1->get_scale(),
                                    ap1->get_offset());
            }

            auto ap2 = piecewise_2D_cast(this->arg);
            if (ap2.get()) {
                return piecewise_2D(this->evaluate(),
                                    ap2->get_num_columns(),
                                    ap2->get_left(),
                                    ap2->get_right());
                                    ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(),
                                    ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset());
            }

            return this->shared_from_this();
+178 −48

File changed.

Preview size limit exceeded, changes collapsed.

+26 −20

File changed.

Preview size limit exceeded, changes collapsed.

Loading