Commit a8d259ca authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add piecewise reductions.

parent d33632ca
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -83,7 +83,8 @@ int main(int argc, const char * argv[]) {
            }

            omega->set(static_cast<base> (500.0));
            x->set(static_cast<base> (-12.0));
            //x->set(static_cast<base> (-12.0));
            x->set(static_cast<base> (2.5));
            //x->set(static_cast<base> (0.0));
            y->set(static_cast<base> (0.0));
            z->set(static_cast<base> (0.0));
@@ -92,13 +93,13 @@ int main(int argc, const char * argv[]) {
            ky->set(static_cast<base> (0.0));
            kz->set(static_cast<base> (10.0));

            //auto eq = equilibrium::make_efit<base, use_safe_math> (NC_FILE, sync);
            auto eq = equilibrium::make_efit<base, use_safe_math> (NC_FILE, sync);
            //auto eq = equilibrium::make_slab_density<base, use_safe_math> ();
            auto eq = equilibrium::make_slab_field<base, use_safe_math> ();
            //auto eq = equilibrium::make_slab_field<base, use_safe_math> ();
            //auto eq = equilibrium::make_no_magnetic_field<base, use_safe_math> ();

            const base endtime = static_cast<base> (10.0);
            //const base endtime = static_cast<base> (0.25);
            //const base endtime = static_cast<base> (10.0);
            const base endtime = static_cast<base> (0.25);
            const base dt = endtime/static_cast<base> (num_times);

            //auto dt_var = graph::variable(num_rays, static_cast<base> (dt), "dt");
+120 −0
Original line number Diff line number Diff line
@@ -80,6 +80,30 @@ namespace graph {
                return this->right + this->left;
            }

            auto pl1 = piecewise_1D_cast(this->left);
            auto pr1 = piecewise_1D_cast(this->right);

            if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
            } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
            }

            auto pl2 = piecewise_2D_cast(this->left);
            auto pr2 = piecewise_2D_cast(this->right);

            if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
            } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
            }

//  Idenity reductions.
            if (this->left->is_match(this->right)) {
                return two<T, SAFE_MATH> ()*this->left;
@@ -462,6 +486,30 @@ namespace graph {
                return constant<T, SAFE_MATH> (this->evaluate());
            }

            auto pl1 = piecewise_1D_cast(this->left);
            auto pr1 = piecewise_1D_cast(this->right);

            if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
            } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
            }

            auto pl2 = piecewise_2D_cast(this->left);
            auto pr2 = piecewise_2D_cast(this->right);

            if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
            } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
            }

//  Common factor reduction. If the left and right are both muliply nodes check
//  for a common factor. So you can change a*b - a*c -> a*(b - c).
            auto lm = multiply_cast(this->left);
@@ -885,6 +933,30 @@ namespace graph {
                return constant<T, SAFE_MATH> (this->evaluate());
            }

            auto pl1 = piecewise_1D_cast(this->left);
            auto pr1 = piecewise_1D_cast(this->right);

            if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
            } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
            }

            auto pl2 = piecewise_2D_cast(this->left);
            auto pr2 = piecewise_2D_cast(this->right);

            if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
            } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
            }

//  Move constants to the left.
            if (r.get() && !l.get()) {
                return this->right*this->left;
@@ -1407,6 +1479,30 @@ namespace graph {
                return constant<T, SAFE_MATH> (this->evaluate());
            }

            auto pl1 = piecewise_1D_cast(this->left);
            auto pr1 = piecewise_1D_cast(this->right);

            if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg());
            } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pr1->get_arg());
            }

            auto pl2 = piecewise_2D_cast(this->left);
            auto pr2 = piecewise_2D_cast(this->right);

            if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right());
            } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pr2->get_num_columns(),
                                    pr2->get_left(),
                                    pr2->get_right());
            }

            if (this->left->is_match(this->right)) {
                if (l.get() && l->is(1)) {
                    return this->left;
@@ -1791,6 +1887,30 @@ namespace graph {
                return this->right - this->left;
            }

            auto pl1 = piecewise_1D_cast(this->left);
            auto pm1 = piecewise_1D_cast(this->middle);

            if (pl1.get() && (m.get() || pl1->is_arg_match(this->middle))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg()) + this->right;
            } else if (pm1.get() && (m.get() || pm1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pm1->get_arg()) + this->right;
            }

            auto pl2 = piecewise_2D_cast(this->left);
            auto pm2 = piecewise_2D_cast(this->middle);

            if (pl2.get() && (m.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right()) + this->right;
            } else if (pm2.get() && (l.get() || pm2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pm2->get_num_columns(),
                                    pm2->get_left(),
                                    pm2->get_right()) + this->right;
            }

//  Common factor reduction. If the left and right are both multiply nodes check
//  for a common factor. So you can change a*b + (a*c) -> a*(b + c).
            auto lm = multiply_cast(this->left);
+34 −38
Original line number Diff line number Diff line
@@ -887,8 +887,7 @@ namespace equilibrium {
        get_psi(graph::shared_leaf<T, SAFE_MATH> x,
                graph::shared_leaf<T, SAFE_MATH> y,
                graph::shared_leaf<T, SAFE_MATH> z) {
            auto r = graph::sqrt(x*x + y*y);
            return get_psi(r, z);
            return get_psi(graph::sqrt(x*x + y*y), z);
        }

//------------------------------------------------------------------------------
@@ -954,8 +953,7 @@ namespace equilibrium {
        get_electron_density(graph::shared_leaf<T, SAFE_MATH> x,
                             graph::shared_leaf<T, SAFE_MATH> y,
                             graph::shared_leaf<T, SAFE_MATH> z) {
            auto psi = get_psi(x, y, z);
            auto psi_norm = (psi - psimin)/dpsi;
            auto psi_norm = (get_psi(x, y, z) - psimin)/dpsi;

            auto c0_temp = graph::piecewise_1D(ne_c0, psi_norm);
            auto c1_temp = graph::piecewise_1D(ne_c1, psi_norm);
@@ -997,8 +995,7 @@ namespace equilibrium {
        get_pressure(graph::shared_leaf<T, SAFE_MATH> x,
                     graph::shared_leaf<T, SAFE_MATH> y,
                     graph::shared_leaf<T, SAFE_MATH> z) {
            auto psi = get_psi(x, y, z);
            auto psi_norm = (psi - psimin)/dpsi;
            auto psi_norm = (get_psi(x, y, z) - psimin)/dpsi;

            auto c0_temp = graph::piecewise_1D(pres_c0, psi_norm);
            auto c1_temp = graph::piecewise_1D(pres_c1, psi_norm);
@@ -1023,8 +1020,7 @@ namespace equilibrium {
        get_electron_temperature(graph::shared_leaf<T, SAFE_MATH> x,
                                 graph::shared_leaf<T, SAFE_MATH> y,
                                 graph::shared_leaf<T, SAFE_MATH> z) {
            auto psi = get_psi(x, y, z);
            auto psi_norm = (psi - psimin)/dpsi;
            auto psi_norm = (get_psi(x, y, z) - psimin)/dpsi;

            auto c0_temp = graph::piecewise_1D(te_c0, psi_norm);
            auto c1_temp = graph::piecewise_1D(te_c1, psi_norm);
@@ -1093,7 +1089,7 @@ namespace equilibrium {
            auto r = graph::sqrt(x*x + y*y);
            auto phi = graph::atan(x, y);
            auto none = graph::none<T, SAFE_MATH> ();
            auto psi = get_psi(x, y, z);
            auto psi = get_psi(r, z);

            auto br = psi->df(z)/r;
            auto bp = get_b_phi(r);
@@ -1304,35 +1300,35 @@ namespace equilibrium {

        const auto c00 = backend::buffer(std::vector<T> (psi_c00_buffer.begin(), psi_c00_buffer.end()));
        const auto c01 = backend::buffer(std::vector<T> (psi_c01_buffer.begin(), psi_c01_buffer.end()));
        auto c02 = backend::buffer(std::vector<T> (psi_c02_buffer.begin(), psi_c02_buffer.end()));
        auto c03 = backend::buffer(std::vector<T> (psi_c03_buffer.begin(), psi_c03_buffer.end()));
        auto c10 = backend::buffer(std::vector<T> (psi_c10_buffer.begin(), psi_c10_buffer.end()));
        auto c11 = backend::buffer(std::vector<T> (psi_c11_buffer.begin(), psi_c11_buffer.end()));
        auto c12 = backend::buffer(std::vector<T> (psi_c12_buffer.begin(), psi_c12_buffer.end()));
        auto c13 = backend::buffer(std::vector<T> (psi_c13_buffer.begin(), psi_c13_buffer.end()));
        auto c20 = backend::buffer(std::vector<T> (psi_c20_buffer.begin(), psi_c20_buffer.end()));
        auto c21 = backend::buffer(std::vector<T> (psi_c21_buffer.begin(), psi_c21_buffer.end()));
        auto c22 = backend::buffer(std::vector<T> (psi_c22_buffer.begin(), psi_c22_buffer.end()));
        auto c23 = backend::buffer(std::vector<T> (psi_c23_buffer.begin(), psi_c23_buffer.end()));
        auto c30 = backend::buffer(std::vector<T> (psi_c30_buffer.begin(), psi_c30_buffer.end()));
        auto c31 = backend::buffer(std::vector<T> (psi_c31_buffer.begin(), psi_c31_buffer.end()));
        auto c32 = backend::buffer(std::vector<T> (psi_c32_buffer.begin(), psi_c32_buffer.end()));
        auto c33 = backend::buffer(std::vector<T> (psi_c33_buffer.begin(), psi_c33_buffer.end()));

        auto pres_c0 = backend::buffer(std::vector<T> (pressure_c0_buffer.begin(), pressure_c0_buffer.end()));
        auto pres_c1 = backend::buffer(std::vector<T> (pressure_c1_buffer.begin(), pressure_c1_buffer.end()));
        auto pres_c2 = backend::buffer(std::vector<T> (pressure_c2_buffer.begin(), pressure_c2_buffer.end()));
        auto pres_c3 = backend::buffer(std::vector<T> (pressure_c3_buffer.begin(), pressure_c3_buffer.end()));

        auto te_c0 = backend::buffer(std::vector<T> (te_c0_buffer.begin(), te_c0_buffer.end()));
        auto te_c1 = backend::buffer(std::vector<T> (te_c1_buffer.begin(), te_c1_buffer.end()));
        auto te_c2 = backend::buffer(std::vector<T> (te_c2_buffer.begin(), te_c2_buffer.end()));
        auto te_c3 = backend::buffer(std::vector<T> (te_c3_buffer.begin(), te_c3_buffer.end()));

        auto ne_c0 = backend::buffer(std::vector<T> (ne_c0_buffer.begin(), ne_c0_buffer.end()));
        auto ne_c1 = backend::buffer(std::vector<T> (ne_c1_buffer.begin(), ne_c1_buffer.end()));
        auto ne_c2 = backend::buffer(std::vector<T> (ne_c2_buffer.begin(), ne_c2_buffer.end()));
        auto ne_c3 = backend::buffer(std::vector<T> (ne_c3_buffer.begin(), ne_c3_buffer.end()));
        const auto c02 = backend::buffer(std::vector<T> (psi_c02_buffer.begin(), psi_c02_buffer.end()));
        const auto c03 = backend::buffer(std::vector<T> (psi_c03_buffer.begin(), psi_c03_buffer.end()));
        const auto c10 = backend::buffer(std::vector<T> (psi_c10_buffer.begin(), psi_c10_buffer.end()));
        const auto c11 = backend::buffer(std::vector<T> (psi_c11_buffer.begin(), psi_c11_buffer.end()));
        const auto c12 = backend::buffer(std::vector<T> (psi_c12_buffer.begin(), psi_c12_buffer.end()));
        const auto c13 = backend::buffer(std::vector<T> (psi_c13_buffer.begin(), psi_c13_buffer.end()));
        const auto c20 = backend::buffer(std::vector<T> (psi_c20_buffer.begin(), psi_c20_buffer.end()));
        const auto c21 = backend::buffer(std::vector<T> (psi_c21_buffer.begin(), psi_c21_buffer.end()));
        const auto c22 = backend::buffer(std::vector<T> (psi_c22_buffer.begin(), psi_c22_buffer.end()));
        const auto c23 = backend::buffer(std::vector<T> (psi_c23_buffer.begin(), psi_c23_buffer.end()));
        const auto c30 = backend::buffer(std::vector<T> (psi_c30_buffer.begin(), psi_c30_buffer.end()));
        const auto c31 = backend::buffer(std::vector<T> (psi_c31_buffer.begin(), psi_c31_buffer.end()));
        const auto c32 = backend::buffer(std::vector<T> (psi_c32_buffer.begin(), psi_c32_buffer.end()));
        const auto c33 = backend::buffer(std::vector<T> (psi_c33_buffer.begin(), psi_c33_buffer.end()));

        const auto pres_c0 = backend::buffer(std::vector<T> (pressure_c0_buffer.begin(), pressure_c0_buffer.end()));
        const auto pres_c1 = backend::buffer(std::vector<T> (pressure_c1_buffer.begin(), pressure_c1_buffer.end()));
        const auto pres_c2 = backend::buffer(std::vector<T> (pressure_c2_buffer.begin(), pressure_c2_buffer.end()));
        const auto pres_c3 = backend::buffer(std::vector<T> (pressure_c3_buffer.begin(), pressure_c3_buffer.end()));

        const auto te_c0 = backend::buffer(std::vector<T> (te_c0_buffer.begin(), te_c0_buffer.end()));
        const auto te_c1 = backend::buffer(std::vector<T> (te_c1_buffer.begin(), te_c1_buffer.end()));
        const auto te_c2 = backend::buffer(std::vector<T> (te_c2_buffer.begin(), te_c2_buffer.end()));
        const auto te_c3 = backend::buffer(std::vector<T> (te_c3_buffer.begin(), te_c3_buffer.end()));

        const auto ne_c0 = backend::buffer(std::vector<T> (ne_c0_buffer.begin(), ne_c0_buffer.end()));
        const auto ne_c1 = backend::buffer(std::vector<T> (ne_c1_buffer.begin(), ne_c1_buffer.end()));
        const auto ne_c2 = backend::buffer(std::vector<T> (ne_c2_buffer.begin(), ne_c2_buffer.end()));
        const auto ne_c3 = backend::buffer(std::vector<T> (ne_c3_buffer.begin(), ne_c3_buffer.end()));

        return std::make_shared<efit<T, SAFE_MATH>> (psimin, dpsi,
                                                     te_c0, te_c1, te_c2, te_c3, te_scale,
+105 −18
Original line number Diff line number Diff line
@@ -60,6 +60,7 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> reduce() {
            auto ac = constant_cast(this->arg);
            
            if (ac.get()) {
                if (ac->is(0) || ac->is(1)) {
                    return this->arg;
@@ -67,6 +68,20 @@ namespace graph {
                return constant<T, SAFE_MATH> (this->evaluate());
            }

            auto ap1 = piecewise_1D_cast(this->arg);
            if (ap1.get()) {
                return piecewise_1D(this->evaluate(),
                                    ap1->get_arg());
            }

            auto ap2 = piecewise_2D_cast(this->arg);
            if (ap2.get()) {
                return piecewise_2D(this->evaluate(),
                                    ap2->get_num_columns(),
                                    ap2->get_left(),
                                    ap2->get_right());
            }

//  Handle casses like sqrt(a^b).
            auto ap = pow_cast(this->arg);
            if (ap.get()) {
@@ -85,8 +100,12 @@ namespace graph {
            if (am.get()) {
                if (pow_cast(am->get_left()).get()           ||
                    constant_cast(am->get_left()).get()      ||
                    piecewise_1D_cast(am->get_left()).get()  ||
                    piecewise_2D_cast(am->get_left()).get()  ||
                    pow_cast(am->get_right()).get()          ||
                    constant_cast(am->get_right()).get()) {
                    constant_cast(am->get_right()).get()     ||
                    piecewise_1D_cast(am->get_right()).get() ||
                    piecewise_2D_cast(am->get_right()).get()) {
                    return sqrt(am->get_left()) *
                           sqrt(am->get_right());
                }
@@ -98,8 +117,12 @@ namespace graph {
            if (ad.get()) {
                if (pow_cast(ad->get_left()).get()           ||
                    constant_cast(ad->get_left()).get()      ||
                    piecewise_1D_cast(ad->get_left()).get()  ||
                    piecewise_2D_cast(ad->get_left()).get()  ||
                    pow_cast(ad->get_right()).get()          ||
                    constant_cast(ad->get_right()).get()) {
                    constant_cast(ad->get_right()).get()     ||
                    piecewise_1D_cast(ad->get_right()).get() ||
                    piecewise_2D_cast(ad->get_right()).get()) {
                    return sqrt(ad->get_left()) /
                           sqrt(ad->get_right());
                }
@@ -306,6 +329,20 @@ namespace graph {
                return constant<T, SAFE_MATH> (this->evaluate());
            }

            auto ap1 = piecewise_1D_cast(this->arg);
            if (ap1.get()) {
                return piecewise_1D(this->evaluate(),
                                    ap1->get_arg());
            }

            auto ap2 = piecewise_2D_cast(this->arg);
            if (ap2.get()) {
                return piecewise_2D(this->evaluate(),
                                    ap2->get_num_columns(),
                                    ap2->get_left(),
                                    ap2->get_right());
            }

//  Reduce exp(log(x)) -> x
            auto a = log_cast(this->arg);
            if (a.get()) {
@@ -484,6 +521,20 @@ namespace graph {
                return constant<T, SAFE_MATH> (this->evaluate());
            }

            auto ap1 = piecewise_1D_cast(this->arg);
            if (ap1.get()) {
                return piecewise_1D(this->evaluate(),
                                    ap1->get_arg());
            }

            auto ap2 = piecewise_2D_cast(this->arg);
            if (ap2.get()) {
                return piecewise_2D(this->evaluate(),
                                    ap2->get_num_columns(),
                                    ap2->get_left(),
                                    ap2->get_right());
            }

//  Reduce log(exp(x)) -> x
            auto a = exp_cast(this->arg);
            if (a.get()) {
@@ -678,6 +729,20 @@ namespace graph {
                if (constant_cast(this->left).get()) {
                    return constant<T, SAFE_MATH> (this->evaluate());
                }

                auto pl1 = piecewise_1D_cast(this->left);
                if (pl1.get()) {
                    return piecewise_1D(this->evaluate(),
                                        pl1->get_arg());
                }

                auto pl2 = piecewise_2D_cast(this->left);
                if (pl2.get()) {
                    return piecewise_2D(this->evaluate(),
                                        pl2->get_num_columns(),
                                        pl2->get_left(),
                                        pl2->get_right());
                }
            }

            auto lp = pow_cast(this->left);
@@ -690,6 +755,10 @@ namespace graph {
            if (lm.get()) {
                if (constant_cast(lm->get_left()).get()      ||
                    constant_cast(lm->get_right()).get()     ||
                    piecewise_1D_cast(lm->get_left()).get()  ||
                    piecewise_1D_cast(lm->get_right()).get() ||
                    piecewise_2D_cast(lm->get_left()).get()  ||
                    piecewise_2D_cast(lm->get_right()).get() ||
                    sqrt_cast(lm->get_left()).get()          ||
                    sqrt_cast(lm->get_right()).get()         ||
                    pow_cast(lm->get_left()).get()           ||
@@ -704,6 +773,10 @@ namespace graph {
            if (ld.get()) {
                if (constant_cast(ld->get_left()).get()      ||
                    constant_cast(ld->get_right()).get()     ||
                    piecewise_1D_cast(ld->get_left()).get()  ||
                    piecewise_1D_cast(ld->get_right()).get() ||
                    piecewise_2D_cast(ld->get_left()).get()  ||
                    piecewise_2D_cast(ld->get_right()).get() ||
                    sqrt_cast(ld->get_left()).get()          ||
                    sqrt_cast(ld->get_right()).get()         ||
                    pow_cast(ld->get_left()).get()           ||
@@ -962,6 +1035,20 @@ namespace graph {
                return constant<T, SAFE_MATH> (this->evaluate());
            }

            auto ap1 = piecewise_1D_cast(this->arg);
            if (ap1.get()) {
                return piecewise_1D(this->evaluate(),
                                    ap1->get_arg());
            }

            auto ap2 = piecewise_2D_cast(this->arg);
            if (ap2.get()) {
                return piecewise_2D(this->evaluate(),
                                    ap2->get_num_columns(),
                                    ap2->get_left(),
                                    ap2->get_right());
            }

            return this->shared_from_this();
        }

+27 −1
Original line number Diff line number Diff line
@@ -303,6 +303,17 @@ namespace graph {
        virtual shared_leaf<T, SAFE_MATH> get_power_exponent() const {
            return one<T, SAFE_MATH> ();
        }

//------------------------------------------------------------------------------
///  @brief Check if the args match.
///
///  @param[in] x Node to match.
///  @returns True if the arguments match.
//------------------------------------------------------------------------------
        bool is_arg_match(shared_leaf<T, SAFE_MATH> x) {
            auto temp = piecewise_1D_cast(x);
            return temp.get() && this->arg->is_match(temp->get_arg());
        }
    };

//------------------------------------------------------------------------------
@@ -469,7 +480,7 @@ namespace graph {
///
///  @returns The number of columns in the constant.
//------------------------------------------------------------------------------
        size_t get_num_columns() {
        size_t get_num_columns() const {
            return num_columns;
        }

@@ -498,6 +509,7 @@ namespace graph {
            if (evaluate().is_same()) {
                return constant<T, SAFE_MATH> (evaluate().at(0));
            }

            return this->shared_from_this();
        }

@@ -687,6 +699,20 @@ namespace graph {
        virtual shared_leaf<T, SAFE_MATH> get_power_exponent() const {
            return one<T, SAFE_MATH> ();
        }
        
//------------------------------------------------------------------------------
///  @brief Check if the args match.
///
///  @param[in] x Node to match.
///  @returns True if the arguments match.
//------------------------------------------------------------------------------
        bool is_arg_match(shared_leaf<T, SAFE_MATH> x) {
            auto temp = piecewise_2D_cast(x);
            return temp.get()                               &&
                   this->left->is_match(temp->get_left())   &&
                   this->right->is_match(temp->get_right()) &&
                   (num_columns == this->get_num_columns());
        }
    };

//------------------------------------------------------------------------------
Loading