Commit 4cbf417d authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Refactor to simplify remove_pseudo.

parent 44f32a03
Loading
Loading
Loading
Loading
+21 −36
Original line number Diff line number Diff line
@@ -354,13 +354,11 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto l = this->left->remove_pseudo();
            auto r = this->right->remove_pseudo();
            if (l->is_match(this->left) &&
                r->is_match(this->right)) {
                return this->shared_from_this();
            if (this->has_pseudo()) {
                return this->left->remove_pseudo() +
                       this->right->remove_pseudo();
            }
            return l + r;
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -875,14 +873,11 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto l = this->left->remove_pseudo();
            auto r = this->right->remove_pseudo();
            if (l->is_match(this->left) &&
                r->is_match(this->right)) {
                return this->shared_from_this();
            } else {
                return l - r;
            if (this->has_pseudo()) {
                return this->left->remove_pseudo() -
                       this->right->remove_pseudo();
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -1502,14 +1497,11 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto l = this->left->remove_pseudo();
            auto r = this->right->remove_pseudo();
            if (l->is_match(this->left) &&
                r->is_match(this->right)) {
                return this->shared_from_this();
            } else {
                return l*r;
            if (this->has_pseudo()) {
                return this->left->remove_pseudo() *
                       this->right->remove_pseudo();
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -1940,14 +1932,11 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto l = this->left->remove_pseudo();
            auto r = this->right->remove_pseudo();
            if (l->is_match(this->left) &&
                r->is_match(this->right)) {
                return this->shared_from_this();
            } else {
                return l/r;
            if (this->has_pseudo()) {
                return this->left->remove_pseudo() /
                       this->right->remove_pseudo();
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -2610,16 +2599,12 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto l = this->left->remove_pseudo();
            auto m = this->middle->remove_pseudo();
            auto r = this->right->remove_pseudo();
            if (l->is_match(this->left)   &&
                m->is_match(this->middle) &&
                r->is_match(this->right)) {
                return this->shared_from_this();
            } else {
                return fma(l, m, r);
            if (this->has_pseudo()) {
                return fma(this->left->remove_pseudo(),
                           this->middle->remove_pseudo(),
                           this->right->remove_pseudo());
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
+16 −27
Original line number Diff line number Diff line
@@ -239,12 +239,10 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto temp = this->arg->remove_pseudo();
            if (temp->is_match(this->arg)) {
                return this->shared_from_this();
            } else {
                return sqrt(temp);
            if (this->has_pseudo()) {
                return sqrt(this->arg->remove_pseudo());
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -471,12 +469,10 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto temp = this->arg->remove_pseudo();
            if (temp->is_match(this->arg)) {
                return this->shared_from_this();
            } else {
                return exp(temp);
            if (this->has_pseudo()) {
                return exp(this->arg->remove_pseudo());
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -703,12 +699,10 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto temp = this->arg->remove_pseudo();
            if (temp->is_match(this->arg)) {
                return this->shared_from_this();
            } else {
                return log(temp);
            if (this->has_pseudo()) {
                return log(this->arg->remove_pseudo());
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -1092,14 +1086,11 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto l = this->left->remove_pseudo();
            auto r = this->right->remove_pseudo();
            if (l->is_match(this->left) &&
                r->is_match(this->right)) {
                return this->shared_from_this();
            } else {
                return pow(l, r);
            if (this->has_pseudo()) {
                return pow(this->left->remove_pseudo(),
                           this->right->remove_pseudo());
            }
            return this->shared_from_this();
        }
    };

@@ -1297,12 +1288,10 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto temp = this->arg->remove_pseudo();
            if (temp->is_match(this->arg)) {
                return this->shared_from_this();
            } else {
                return erfi(temp);
            if (this->has_pseudo()) {
                return erfi(this->arg->remove_pseudo());
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
+47 −26
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ namespace graph {
        const size_t complexity;
///  Cache derivative terms.
        std::map<size_t, std::shared_ptr<leaf_node<T, SAFE_MATH>>> df_cache;
///  Node contains pseudo variables.
        const bool contains_pseudo;

    public:
//------------------------------------------------------------------------------
@@ -42,11 +44,13 @@ namespace graph {
///
///  @params[in] s      Node string to hash.
///  @params[in] count  Number of nodes in the subgraph.
///  @params[in] pseudo Node contains pseudo variable.
//------------------------------------------------------------------------------
        leaf_node(const std::string s,
                  const size_t count) :
                  const size_t count,
                  const bool pseudo) :
        hash(std::hash<std::string>{} (s)),
        complexity(count) {}
        complexity(count), contains_pseudo(pseudo) {}

//------------------------------------------------------------------------------
///  @brief Destructor
@@ -230,6 +234,15 @@ namespace graph {
            return complexity;
        }

//------------------------------------------------------------------------------
///  @brief Query if the node contains pseudo variables.
///
///  @return True if the node contains pseudo variables.
//------------------------------------------------------------------------------
        virtual bool has_pseudo() const {
            return contains_pseudo;
        }

//------------------------------------------------------------------------------
///  @brief Remove pseudo variable nodes.
///
@@ -315,7 +328,7 @@ namespace graph {
///  @params[in] d Array buffer.
//------------------------------------------------------------------------------
        constant_node(const backend::buffer<T> &d) :
        leaf_node<T, SAFE_MATH> (constant_node::to_string(d.at(0)), 1), data(d) {
        leaf_node<T, SAFE_MATH> (constant_node::to_string(d.at(0)), 1, false), data(d) {
            assert(d.size() == 1 && "Constants need to be scalar functions.");
        }

@@ -672,15 +685,8 @@ namespace graph {
//------------------------------------------------------------------------------
        straight_node(shared_leaf<T, SAFE_MATH> a,
                      const std::string s) :
        leaf_node<T, SAFE_MATH> (s, a->get_complexity() + 1), arg(a) {}

//------------------------------------------------------------------------------
///  @brief Construct a straight node with defered argument.
///
///  @params[in] s Node string to hash.
//------------------------------------------------------------------------------
        straight_node(const std::string s) :
        leaf_node<T, SAFE_MATH> (s) {}
        leaf_node<T, SAFE_MATH> (s, a->get_complexity() + 1, a->has_pseudo()),
        arg(a) {}

//------------------------------------------------------------------------------
///  @brief Evaluate method.
@@ -786,7 +792,8 @@ namespace graph {
        branch_node(shared_leaf<T, SAFE_MATH> l,
                    shared_leaf<T, SAFE_MATH> r,
                    const std::string s) :
        leaf_node<T, SAFE_MATH> (s, l->get_complexity() + r->get_complexity() + 1),
        leaf_node<T, SAFE_MATH> (s, l->get_complexity() + r->get_complexity() + 1,
                                 l->has_pseudo() || r->has_pseudo()),
        left(l), right(r) {}

//------------------------------------------------------------------------------
@@ -796,12 +803,14 @@ namespace graph {
///  @params[in] r     Right branch.
///  @params[in] s     Node string to hash.
///  @params[in] count Number of nodes in the subgraph.
///  @params[in] pseudo Node contains pseudo variable.
//------------------------------------------------------------------------------
        branch_node(shared_leaf<T, SAFE_MATH> l,
                    shared_leaf<T, SAFE_MATH> r,
                    const std::string s,
                            const size_t count) :
                leaf_node<T, SAFE_MATH> (s, count),
                    const size_t count,
                    const bool pseudo) :
        leaf_node<T, SAFE_MATH> (s, count, pseudo),
        left(l), right(r) {}

//------------------------------------------------------------------------------
@@ -901,7 +910,10 @@ namespace graph {
        branch_node<T, SAFE_MATH> (l, r, s,
                                   l->get_complexity() +
                                   m->get_complexity() +
                                   r->get_complexity()),
                                   r->get_complexity(),
                                   l->has_pseudo() ||
                                   m->has_pseudo() ||
                                   r->has_pseudo()),
        middle(m) {}

//------------------------------------------------------------------------------
@@ -988,7 +1000,7 @@ namespace graph {
//------------------------------------------------------------------------------
        variable_node(const size_t s,
                      const std::string &symbol) :
        leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1),
        leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
        buffer(s), symbol(symbol) {}

//------------------------------------------------------------------------------
@@ -1000,7 +1012,7 @@ namespace graph {
//------------------------------------------------------------------------------
        variable_node(const size_t s, const T d,
                      const std::string &symbol) :
        leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1),
        leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
        buffer(s, d), symbol(symbol) {}

//------------------------------------------------------------------------------
@@ -1011,7 +1023,7 @@ namespace graph {
//------------------------------------------------------------------------------
        variable_node(const std::vector<T> &d,
                      const std::string &symbol) :
        leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1),
        leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
        buffer(d), symbol(symbol) {}

//------------------------------------------------------------------------------
@@ -1022,7 +1034,7 @@ namespace graph {
//------------------------------------------------------------------------------
        variable_node(const backend::buffer<T> &d,
                      const std::string &symbol) :
        leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1),
        leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
        buffer(d), symbol(symbol) {}

//------------------------------------------------------------------------------
@@ -1417,6 +1429,15 @@ namespace graph {
            return this->arg->get_power_exponent();
        }

//------------------------------------------------------------------------------
///  @brief Query if the node contains pseudo variables.
///
///  @return True if the node contains pseudo variables.
//------------------------------------------------------------------------------
        virtual bool has_pseudo() const {
            return true;
        }

//------------------------------------------------------------------------------
///  @brief Remove pseudo variable nodes.
///
+10 −17
Original line number Diff line number Diff line
@@ -178,12 +178,10 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto temp = this->arg->remove_pseudo();
            if (temp->is_match(this->arg)) {
                return this->shared_from_this();
            } else {
                return sin(temp);
            if (this->has_pseudo()) {
                return sin(this->arg->remove_pseudo());
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -419,12 +417,10 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto temp = this->arg->remove_pseudo();
            if (temp->is_match(this->arg)) {
                return this->shared_from_this();
            } else {
                return cos(temp);
            if (this->has_pseudo()) {
                return cos(this->arg->remove_pseudo());
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
@@ -668,14 +664,11 @@ namespace graph {
///  @returns A tree without variable nodes.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
            auto l = this->left->remove_pseudo();
            auto r = this->right->remove_pseudo();
            if (l->is_match(this->left) &&
                r->is_match(this->right)) {
                return this->shared_from_this();
            } else {
                return atan(l, r);
            if (this->has_pseudo()) {
                return atan(this->left->remove_pseudo(),
                            this->right->remove_pseudo());
            }
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------