Commit 0d5847be authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Simplify by moving these out of the constant_node class.

parent 8e3924e0
Loading
Loading
Loading
Loading
+19 −19
Original line number Diff line number Diff line
@@ -106,7 +106,7 @@ namespace graph {
#ifdef USE_REDUCE
//  Idenity reductions.
            if (this->left->is_match(this->right)) {
                return constant_node<typename LN::base>::two()*this->left;
                return two<typename LN::base> ()*this->left;
            }

//  Constant reductions.
@@ -129,7 +129,7 @@ namespace graph {
//  Assume constants are on the left.
//  v1 + -c*v2 -> v1 - c*v2
//  -c*v1 + v2 -> v2 - c*v1
            auto none = constant_node<typename LN::base>::none();
            auto none = graph::none<typename LN::base> ();
            if (rm.get()) {
                auto rmc = constant_cast(rm->get_left());
                if (rmc.get() && rmc->evaluate().is_negative()) {
@@ -269,7 +269,7 @@ namespace graph {
        virtual shared_leaf<typename LN::base>
        df(shared_leaf<typename LN::base> x) final {
            if (this->is_match(x)) {
                return constant_node<typename LN::base>::one();
                return one<typename LN::base> ();
            } else {
                return this->left->df(x) + this->right->df(x);
            }
@@ -438,7 +438,7 @@ namespace graph {
                    return this->left;
                }

                return constant_node<typename LN::base>::zero();
                return zero<typename LN::base> ();
            }

//  Constant reductions.
@@ -446,7 +446,7 @@ namespace graph {
            auto r = constant_cast(this->right);

            if (l.get() && l->is(0)) {
                return constant_node<typename LN::base>::none()*this->right;
                return none<typename LN::base> ()*this->right;
            } else if (r.get() && r->is(0)) {
                return this->left;
            } else if (l.get() && r.get()) {
@@ -463,8 +463,8 @@ namespace graph {
            if (rm.get()) {
                auto rmc = constant_cast(rm->get_left());
                if (rmc.get() && rmc->evaluate().is_negative()) {
                    auto none = constant_node<typename LN::base>::none();
                    return this->left + none*rm->get_left()*rm->get_right();
                    return this->left +
                           none<typename LN::base> ()*rm->get_left()*rm->get_right();
                }
            }

@@ -580,7 +580,7 @@ namespace graph {
        virtual shared_leaf<typename LN::base>
        df(shared_leaf<typename LN::base> x) final {
            if (this->is_match(x)) {
                return constant_node<typename LN::base>::one();
                return one<typename LN::base> ();
            } else {
                return this->left->df(x) - this->right->df(x);
            }
@@ -776,7 +776,7 @@ namespace graph {

//  Reduce x*x to x^2
            if (this->left->is_match(this->right)) {
                return pow(this->left, constant_node<typename LN::base>::two());
                return pow(this->left, two<typename LN::base> ());
            }

//  Gather common terms.
@@ -898,7 +898,7 @@ namespace graph {
//  a^b*a -> a^(b + 1)
                if (lp->get_left()->is_match(this->right)) {
                    return pow(lp->get_left(),
                               lp->get_right() + constant_node<typename LN::base>::one());
                               lp->get_right() + one<typename LN::base> ());
                }

//  a^b*a^c -> a^(b + c)
@@ -924,7 +924,7 @@ namespace graph {
//  a*a^b -> a^(1 + b)
                if (rp->get_left()->is_match(this->left)) {
                    return pow(rp->get_left(),
                               rp->get_right() + constant_node<typename LN::base>::one());
                               rp->get_right() + one<typename LN::base> ());
                }

//  sqrt(a)*a^b -> a^(b + 1)
@@ -955,7 +955,7 @@ namespace graph {
        virtual shared_leaf<typename LN::base>
        df(shared_leaf<typename LN::base> x) final {
            if (this->is_match(x)) {
                return constant_node<typename LN::base>::one();
                return one<typename LN::base> ();
            }

            return this->left->df(x)*this->right +
@@ -1128,12 +1128,12 @@ namespace graph {
                    return this->left;
                }

                return constant_node<typename LN::base>::one();
                return one<typename LN::base> ();
            }

//  Reduce cases of a/c1 -> c2*a
            if (r.get()) {
                return (constant_node<typename LN::base>::one()/this->right) *
                return (one<typename LN::base> ()/this->right) *
                       this->left;
            }

@@ -1215,7 +1215,7 @@ namespace graph {
//  a^b/a -> a^(b - 1)
                if (lp->get_left()->is_match(this->right)) {
                    return pow(lp->get_left(),
                               lp->get_right() - constant_node<typename LN::base>::one());
                               lp->get_right() - one<typename LN::base> ());
                }

//  a^b/a^c -> a^(b - c)
@@ -1241,7 +1241,7 @@ namespace graph {
//  a/a^b -> a^(1 - b)
                if (rp->get_left()->is_match(this->left)) {
                    return pow(rp->get_left(),
                               constant_node<typename LN::base>::one() - rp->get_right());
                               one<typename LN::base> () - rp->get_right());
                }

//  sqrt(a)/a^b -> a^(1/2 - b)
@@ -1254,7 +1254,7 @@ namespace graph {
//  sqrt(a)/a -> 1.0/sqrt(a)
                auto lsq = sqrt_cast(this->left);
                if (lsq.get() && this->right->is_match(lsq->get_arg())) {
                    return constant_node<typename LN::base>::one()/this->left;
                    return one<typename LN::base> ()/this->left;
                }
            }
#endif
@@ -1272,7 +1272,7 @@ namespace graph {
        virtual shared_leaf<typename LN::base>
        df(shared_leaf<typename LN::base> x) final {
            if (this->is_match(x)) {
                return constant_node<typename LN::base>::one();
                return one<typename LN::base> ();
            }

            return this->left->df(x)/this->right -
@@ -1517,7 +1517,7 @@ namespace graph {
        virtual shared_leaf<typename LN::base>
        df(shared_leaf<typename LN::base> x) final {
            if (this->is_match(x)) {
                return constant_node<typename LN::base>::one();
                return one<typename LN::base> ();
            }

            auto temp_right = fma(this->left,
+12 −13
Original line number Diff line number Diff line
@@ -101,7 +101,7 @@ namespace dispersion {
            auto dDdy = this->D->df(y)->reduce();
            auto dDdz = this->D->df(z)->reduce();

            auto neg_one = graph::constant_node<typename DISPERSION_FUNCTION::base>::none();
            auto neg_one = graph::none<typename DISPERSION_FUNCTION::base> ();
            dxdt = neg_one*dDdkx/dDdw;
            dydt = neg_one*dDdky/dDdw;
            dzdt = neg_one*dDdkz/dDdw;
@@ -410,7 +410,7 @@ namespace dispersion {
                                        graph::shared_leaf<T> z,
                                        graph::shared_leaf<T> t,
                                        equilibrium::unique_equilibrium<T> &eq) final {
            auto none = graph::constant_node<T>::none();
            auto none = graph::none<T> ();
            auto c = graph::constant(static_cast<T> (1.0E3));
            return (c*(x - graph::exp(none*t)) - graph::exp(none*t))*kx + w;
        }
@@ -446,7 +446,7 @@ namespace dispersion {
                                        graph::shared_leaf<T> z,
                                        graph::shared_leaf<T> t,
                                        equilibrium::unique_equilibrium<T> &eq) final {
            auto c = graph::constant_node<T>::one();
            auto c = graph::one<T> ();

            auto npar2 = kz*kz*c*c/(w*w);
            auto nperp2 = (kx*kx + ky*ky)*c*c/(w*w);
@@ -470,8 +470,7 @@ namespace dispersion {
///  Electron mass.
        graph::shared_leaf<T> me = graph::constant(static_cast<T> (9.1093837015E-31));
/// Speed of light.
        graph::shared_leaf<T> c = graph::constant_node<T>::one()
                                / graph::sqrt(epsion0*mu0);
        graph::shared_leaf<T> c = graph::one<T> ()/graph::sqrt(epsion0*mu0);
    };

//------------------------------------------------------------------------------
@@ -514,7 +513,7 @@ namespace dispersion {
            auto te = eq->get_electron_temperature(x, y, z);
//  2*1.602176634E-19 to convert eV to J.
            
            auto temp = graph::constant_node<T>::two()*physics<T>::q*te;
            auto temp = graph::two<T> ()*physics<T>::q*te;
            auto vterm2 = temp/(physics<T>::me*physics<T>::c*physics<T>::c);

//  Wave numbers should be parallel to B if there is a magnetic field. Otherwise
@@ -523,7 +522,7 @@ namespace dispersion {
            auto k = graph::vector(kx, ky, kz);
            graph::shared_leaf<T> kpara2;
#ifdef USE_REDUCE
            if (b_vec->length()->is_match(graph::constant_node<T>::zero())) {
            if (b_vec->length()->is_match(graph::zero<T> ())) {
#else
            if (b_vec->length()->evaluate()[0] == static_cast<T> (0.0)) {
#endif
@@ -581,7 +580,7 @@ namespace dispersion {
//  Wave numbers should be parallel to B if there is a magnetic field. Otherwise
//  B should be zero.
#ifdef USE_REDUCE
            assert(eq->get_magnetic_field(x, y, z)->length()->is_match(graph::constant_node<T>::zero()) &&
            assert(eq->get_magnetic_field(x, y, z)->length()->is_match(graph::zero<T> ()) &&
                   "Expected equilibrium with no magnetic field.");
#else
            assert(eq->get_magnetic_field(x, y, z)->length()->evaluate()[0] ==
@@ -643,7 +642,7 @@ namespace dispersion {
            auto k = graph::vector(kx, ky, kz);
            graph::shared_leaf<T> kpara2;
#ifdef USE_REDUCE
            if (b_vec->length()->is_match(graph::constant_node<T>::zero())) {
            if (b_vec->length()->is_match(graph::zero<T> ())) {
#else
            if (b_vec->length()->evaluate()[0] ==
                static_cast<T> (0.0)) {
@@ -689,7 +688,7 @@ namespace dispersion {
                                        graph::shared_leaf<T> z,
                                        graph::shared_leaf<T> t,
                                        equilibrium::unique_equilibrium<T> &eq) final {
            auto c = graph::constant_node<T>::one();
            auto c = graph::one<T> ();
            auto well = c - graph::constant(static_cast<T> (0.5))*exp(graph::constant(static_cast<T> (-1.0))*(x*x + y*y)/graph::constant(static_cast<T> (0.1)));
            auto npar2 = kz*kz*c*c/(w*w);
            auto nperp2 = (kx*kx + ky*ky)*c*c/(w*w);
@@ -793,7 +792,7 @@ namespace dispersion {
                                        graph::shared_leaf<T> t,
                                        equilibrium::unique_equilibrium<T> &eq) final {
//  Constants
            auto one = graph::constant_node<T>::one();
            auto one = graph::one<T> ();

//  Equilibrium quantities.
            auto ne = eq->get_electron_density(x, y, z);
@@ -851,7 +850,7 @@ namespace dispersion {
                                        graph::shared_leaf<T> t,
                                        equilibrium::unique_equilibrium<T> &eq) final {
//  Constants
            auto one = graph::constant_node<T>::one();
            auto one = graph::one<T> ();
            auto none = graph::constant(static_cast<T> (-1.0));
            
//  Equilibrium quantities.
@@ -932,7 +931,7 @@ namespace dispersion {
                                        graph::shared_leaf<T> t,
                                        equilibrium::unique_equilibrium<T> &eq) final {
//  Constants
            auto one = graph::constant_node<T>::one();
            auto one = graph::one<T> ();
            auto none = graph::constant(static_cast<T> (-1.0));

//  Dielectric terms.
+31 −16
Original line number Diff line number Diff line
@@ -170,7 +170,7 @@ namespace equilibrium {
                                                           graph::shared_leaf<T> y,
                                                           graph::shared_leaf<T> z) final {
            return graph::constant(static_cast<T> (1.0E19)) *
                   (graph::constant(static_cast<T> (0.1))*x + graph::constant_node<T>::one());
                   (graph::constant(static_cast<T> (0.1))*x + graph::one<T> ());
        }

//------------------------------------------------------------------------------
@@ -184,7 +184,7 @@ namespace equilibrium {
                                                      graph::shared_leaf<T> y,
                                                      graph::shared_leaf<T> z) final {
            return graph::constant(static_cast<T> (1.0E19)) *
                   (graph::constant(static_cast<T> (0.1))*x + graph::constant_node<T>::one());
                   (graph::constant(static_cast<T> (0.1))*x + graph::one<T> ());
        }

//------------------------------------------------------------------------------
@@ -228,12 +228,16 @@ namespace equilibrium {
        get_magnetic_field(graph::shared_leaf<T> x,
                           graph::shared_leaf<T> y,
                           graph::shared_leaf<T> z) final {
            auto zero = graph::constant_node<T>::zero();
            auto zero = graph::zero<T> ();
            return graph::vector(zero, zero, zero);
        }
    };

///  Convenience type alias for unique equilibria.
//------------------------------------------------------------------------------
///  @brief Convenience function to build a no magnetic field equilibrium.
///
///  @returns A constructed no magnetic field equilibrium.
//------------------------------------------------------------------------------
    template<typename T>
    std::unique_ptr<equilibrium<T>> make_no_magnetic_field() {
        return std::make_unique<no_magnetic_field<T>> ();
@@ -323,13 +327,17 @@ namespace equilibrium {
        get_magnetic_field(graph::shared_leaf<T> x,
                           graph::shared_leaf<T> y,
                           graph::shared_leaf<T> z) final {
            auto zero = graph::constant_node<T>::zero();
            auto zero = graph::zero<T> ();
            return graph::vector(zero, zero,
                                 graph::constant(static_cast<T> (0.1))*x + graph::constant_node<T>::one());
                                 graph::constant(static_cast<T> (0.1))*x + graph::one<T> ());
        }
    };

///  Convenience type alias for unique equilibria.
//------------------------------------------------------------------------------
///  @brief Convenience function to build a slab equilibrium.
///
///  @returns A constructed slab equilibrium.
//------------------------------------------------------------------------------
    template<typename T>
    std::unique_ptr<equilibrium<T>> make_slab() {
        return std::make_unique<slab<T>> ();
@@ -363,7 +371,7 @@ namespace equilibrium {
                                                           graph::shared_leaf<T> y,
                                                           graph::shared_leaf<T> z) final {
            return graph::constant(static_cast<T> (1.0E19)) *
                   (graph::constant(static_cast<T> (0.1))*x + graph::constant_node<T>::one());
                   (graph::constant(static_cast<T> (0.1))*x + graph::one<T> ());
        }

//------------------------------------------------------------------------------
@@ -377,7 +385,7 @@ namespace equilibrium {
                                                      graph::shared_leaf<T> y,
                                                      graph::shared_leaf<T> z) final {
            return graph::constant(static_cast<T> (1.0E19)) *
                   (graph::constant(static_cast<T> (0.1))*x + graph::constant_node<T>::one());
                   (graph::constant(static_cast<T> (0.1))*x + graph::one<T> ());
        }

//------------------------------------------------------------------------------
@@ -421,12 +429,16 @@ namespace equilibrium {
        get_magnetic_field(graph::shared_leaf<T> x,
                           graph::shared_leaf<T> y,
                           graph::shared_leaf<T> z) final {
            auto zero = graph::constant_node<T>::zero();
            return graph::vector(zero, zero, graph::constant_node<T>::one());
            auto zero = graph::zero<T> ();
            return graph::vector(zero, zero, graph::one<T> ());
        }
    };

///  Convenience type alias for unique equilibria.
//------------------------------------------------------------------------------
///  @brief Convenience function to build a slab density equilibrium.
///
///  @returns A constructed slab density equilibrium.
//------------------------------------------------------------------------------
    template<typename T>
    std::unique_ptr<equilibrium<T>> make_slab_density() {
        return std::make_unique<slab_density<T>> ();
@@ -516,13 +528,16 @@ namespace equilibrium {
        get_magnetic_field(graph::shared_leaf<T> x,
                           graph::shared_leaf<T> y,
                           graph::shared_leaf<T> z) final {
            auto zero = graph::constant_node<T>::zero();
            return graph::vector(graph::constant_node<T>::one(),
                                 zero, zero);
            auto zero = graph::zero<T> ();
            return graph::vector(graph::one<T> (), zero, zero);
        }
    };

///  Convenience type alias for unique equilibria.
//------------------------------------------------------------------------------
///  @brief Convenience function to build a guassian density equilibrium.
///
///  @returns A constructed guassian density equilibrium.
//------------------------------------------------------------------------------
    template<typename T>
    std::unique_ptr<equilibrium<T>> make_guassian_density() {
        return std::make_unique<guassian_density<T>> ();
+7 −8
Original line number Diff line number Diff line
@@ -110,10 +110,10 @@ namespace graph {
        virtual shared_leaf<typename N::base>
        df(shared_leaf<typename N::base> x) final {
            if (this->is_match(x)) {
                return constant_node<typename N::base>::one();
                return one<typename N::base> ();
            } else {
                return this->arg->df(x) /
                       (constant_node<typename N::base>::two()*this->shared_from_this());
                       (two<typename N::base> ()*this->shared_from_this());
            }
        }

@@ -251,7 +251,7 @@ namespace graph {
        virtual shared_leaf<typename N::base>
        df(shared_leaf<typename N::base> x) final {
            if (this->is_match(x)) {
                return constant_node<typename N::base>::one();
                return one<typename N::base> ();
            }

            return this->shared_from_this()*this->arg->df(x);
@@ -515,7 +515,7 @@ namespace graph {

            if (rc.get()) {
                if (rc->is(0)) {
                    return constant_node<typename LN::base>::one();
                    return one<typename LN::base> ();
                } else if (rc->is(1)) {
                    return this->left;
                } else if (rc->is(0.5)) {
@@ -568,8 +568,8 @@ namespace graph {
//  Reduce sqrt(a)^b
            auto lsq = sqrt_cast(this->left);
            if (lsq.get()) {
                auto two = constant_node<typename LN::base>::two();
                return pow(lsq->get_arg(), this->right/two);
                return pow(lsq->get_arg(),
                           this->right/two<typename LN::base> ());
            }
#endif
            return this->shared_from_this();
@@ -585,8 +585,7 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename LN::base>
        df(shared_leaf<typename LN::base> x) final {
            auto one = constant_node<typename LN::base>::one();
            return pow(this->left, this->right - one) *
            return pow(this->left, this->right - one<typename LN::base> ()) *
                   (this->right*this->left->df(x) +
                    this->left*log(this->left)*this->right->df(x));
        }
+39 −40
Original line number Diff line number Diff line
@@ -311,11 +311,6 @@ namespace graph {
///  @returns A reduced representation of the node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T> reduce() final {
#ifdef USE_REDUCE
            if (data.size() > 1 && data.is_same()) {
                return std::make_shared<constant_node<T>> (data.at(0));
            }
#endif
            return this->shared_from_this();
        }

@@ -326,7 +321,7 @@ namespace graph {
///  @returns The derivative of the node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T> df(shared_leaf<T> x) final {
            return zero();
            return std::make_shared<constant_node<T>> (static_cast<T> (0.0));
        }

//------------------------------------------------------------------------------
@@ -389,65 +384,69 @@ namespace graph {
        virtual void to_latex() const final {
            std::cout << data.at(0);
        }
    };

//  Define some common constants.
//------------------------------------------------------------------------------
///  @brief Create a zero constant.
///  @brief Construct a constant.
///
///  @returns A zero constant.
///  @param[in] d Scalar data to initalize.
///  @returns A reduced constant node.
//------------------------------------------------------------------------------
        static shared_leaf<T> zero() {
            return std::make_shared<constant_node<T>> (static_cast<T> (0.0));
    template<typename T>
    shared_leaf<T> constant(const T d) {
        return (std::make_shared<constant_node<T>> (d))->reduce();
    }

//------------------------------------------------------------------------------
///  @brief Create a one constant.
///  @brief Construct a constant.
///
///  @returns A one constant.
///  @param[in] d Array buffer.
///  @returns A reduced constant node.
//------------------------------------------------------------------------------
        static shared_leaf<T> one() {
            return std::make_shared<constant_node<T>> (static_cast<T> (1.0));
    template<typename T>
    shared_leaf<T> constant(const backend::cpu<T> &d) {
        return (std::make_shared<constant_node<T>> (d))->reduce();
    }

//  Define some common constants.
//------------------------------------------------------------------------------
///  @brief Create a negative one constant.
///  @brief Create a zero constant.
///
///  @returns A negative one constant.
///  @returns A zero constant.
//------------------------------------------------------------------------------
        static shared_leaf<T> none() {
            return std::make_shared<constant_node<T>> (static_cast<T> (-1.0));
    template<typename T>
    shared_leaf<T> zero() {
        return constant(static_cast<T> (0.0));
    }
        
//------------------------------------------------------------------------------
///  @brief Create a two constant.
///  @brief Create a one constant.
///
///  @returns A two constant.
///  @returns A one constant.
//------------------------------------------------------------------------------
        static shared_leaf<T> two() {
            return std::make_shared<constant_node<T>> (static_cast<T> (2.0));
    template<typename T>
    shared_leaf<T> one() {
        return constant(static_cast<T> (1.0));
    }
    };
        
//------------------------------------------------------------------------------
///  @brief Construct a constant.
///  @brief Create a negative one constant.
///
///  @param[in] d Scalar data to initalize.
///  @returns A reduced constant node.
///  @returns A negative one constant.
//------------------------------------------------------------------------------
    template<typename T>
    shared_leaf<T> constant(const T d) {
        return (std::make_shared<constant_node<T>> (d))->reduce();
    shared_leaf<T> none() {
        return constant(static_cast<T> (-1.0));
    }
        
//------------------------------------------------------------------------------
///  @brief Construct a constant.
///  @brief Create a two constant.
///
///  @param[in] d Array buffer.
///  @returns A reduced constant node.
///  @returns A two constant.
//------------------------------------------------------------------------------
    template<typename T>
    shared_leaf<T> constant(const backend::cpu<T> &d) {
        return (std::make_shared<constant_node<T>> (d))->reduce();
    shared_leaf<T> two() {
        return constant(static_cast<T> (2.0));
    }

///  Convenience type alias for shared constant nodes.
@@ -765,7 +764,7 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<T> df(shared_leaf<T> x) final {
            if (this->is_match(x)) {
                return graph::constant_node<T>::one();
                return one<T> ();
            } else {
                return this->arg->df(x)->reduce();
            }
Loading