Commit 4edde634 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add an index node. This can be used a potential code to index into a variable....

Add an index node. This can be used a potential code to index into a variable. Indexing is read only currently. Add the ability to reduce piece wise nodes with constant arguments. Not this can also be done for 2D but that will require updating the backend operations to extract 1D indexes for cases when only one index is constant.
parent c138bdba
Loading
Loading
Loading
Loading
+329 −3
Original line number Diff line number Diff line
@@ -90,12 +90,12 @@ void compile_index(std::ostringstream &stream,
//------------------------------------------------------------------------------
    template<jit::float_scalar T, bool SAFE_MATH=false>
    class piecewise_1D_node final : public straight_node<T, SAFE_MATH> {
    private:
///  Scale factor for the argument.
        const T scale;
///  Offset factor for the argument.
        const T offset;

    private:
//------------------------------------------------------------------------------
///  @brief Convert node pointer to a string.
///
@@ -190,6 +190,13 @@ void compile_index(std::ostringstream &stream,
///  @returns A reduced representation of the node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> reduce() {
            if (constant_cast(this->arg).get()) {
                const T arg = (this->evaluate().at(0) + offset)/scale;
                const size_t i = std::min(static_cast<size_t> (std::real(arg)),
                                          this->get_size() - 1);
                return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i]);
            }

            if (evaluate().is_same()) {
                return constant<T, SAFE_MATH> (evaluate().at(0));
            }
@@ -291,8 +298,8 @@ void compile_index(std::ostringstream &stream,
//------------------------------------------------------------------------------
///  @brief Compile the node.
///
///  This node first evaluates the value of the argument then chooses the correct
///  piecewise index. This assumes that the argument is
///  This node first evaluates the value of the argument then chooses the
///  correct piecewise index. This assumes that the argument is
///
///    x' = (x - xmin)/dx                                                    (1)
///
@@ -1323,6 +1330,325 @@ void compile_index(std::ostringstream &stream,
    shared_piecewise_2D<T, SAFE_MATH> piecewise_2D_cast(shared_leaf<T, SAFE_MATH> x) {
        return std::dynamic_pointer_cast<piecewise_2D_node<T, SAFE_MATH>> (x);
    }

//******************************************************************************
//  1D Index node.
//******************************************************************************
//------------------------------------------------------------------------------
///  @brief Class representing a 1D index.
///
///  This class is used to implement indexes into an array.
///
///  Indicies are selected by
///
///    x_norm' = (x - xmin)/dx                                               (1)
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use @ref general_concepts_safe_math operations.
//------------------------------------------------------------------------------
    template<jit::float_scalar T, bool SAFE_MATH=false>
    class index_1D_node final : public branch_node<T, SAFE_MATH> {
    private:
///  Scale factor for the argument.
        const T scale;
///  Offset factor for the argument.
        const T offset;

//------------------------------------------------------------------------------
///  @brief Convert node pointer to a string with the argument.
///
///  @param[in] v Value to index.
///  @param[in] x Argument.
///  @return A string rep of the node.
//------------------------------------------------------------------------------
        static std::string to_string(shared_leaf<T, SAFE_MATH> v,
                                     shared_leaf<T, SAFE_MATH> x) {
            return jit::format_to_string(v->get_hash()) + "[" +
                   jit::format_to_string(x->get_hash()) + "]";
        }

    public:
//------------------------------------------------------------------------------
///  @brief Construct a 1D index.
///
///  @param[in] var    Node to index.
///  @param[in] x      Argument.
///  @param[in] scale  Scale factor for the argument.
///  @param[in] offset Offset factor for the argument.
//------------------------------------------------------------------------------
        index_1D_node(shared_leaf<T, SAFE_MATH> var,
                      shared_leaf<T, SAFE_MATH> x,
                      const T scale,
                      const T offset) :
        branch_node<T, SAFE_MATH> (var, x, index_1D_node::to_string(var, x)),
        scale(scale), offset(offset) {}

//------------------------------------------------------------------------------
///  @brief Evaluate the results of the piecewise constant.
///
///  Evaluate functions are only used by the minimization. So this node does not
///  evaluate the argument. Instead this only returns the data as if it were a
///  constant.
///
///  @returns The evaluated value of the node.
//------------------------------------------------------------------------------
        virtual backend::buffer<T> evaluate() {
            return this->right->evaluate();
        }

//------------------------------------------------------------------------------
///  @brief Reduction method.
///
///  If all the values in the data buffer are the same. Reduce to a single
///  constant.
///
///  @returns A reduced representation of the node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> reduce() {
            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
///  @brief Transform node to derivative.
///
///  @param[in] x The variable to take the derivative to.
///  @return The derivative of the node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> df(shared_leaf<T, SAFE_MATH> x) {
            return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
        }

//------------------------------------------------------------------------------
///  @brief the node.
///
///  This node first evaluates the value of the argument then chooses the
///  correct index of the variable.
///
///    x' = (x - xmin)/dx                                                    (1)
///
///  @param[in,out] stream    String buffer stream.
///  @param[in,out] registers List of defined registers.
///  @param[in,out] indices   List of defined indices.
///  @param[in]     usage     List of register usage count.
///  @returns The current node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH>
        compile(std::ostringstream &stream,
                jit::register_map &registers,
                jit::register_map &indices,
                const jit::register_usage &usage) {
            if (registers.find(this) == registers.end()) {
#ifdef USE_INDEX_CACHE
                if (indices.find(this->right.get()) == indices.end()) {
#endif
                    const size_t length = variable_cast(this->left)->size();
                    shared_leaf<T, SAFE_MATH> a = this->right->compile(stream,
                                                                       registers,
                                                                       indices,
                                                                       usage);
#ifdef USE_INDEX_CACHE
                    indices[a.get()] = jit::to_string('i', a.get());
                    stream << "        const "
                           << jit::smallest_int_type<T> (length) << " "
                           << indices[a.get()] << " = ";
                    compile_index<T> (stream, registers[a.get()], length,
                                      scale, offset);
                    a->endline(stream, usage);
                }
#endif

                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<T> (stream);
                auto var = this->left->compile(stream,
                                                registers,
                                                indices,
                                                usage);
                stream << " " << registers[this] << " = "
                       << jit::to_string('v', var.get());
#ifdef USE_INDEX_CACHE
                stream << "[" << indices[this->right.get()] << "]";
#else
                stream << "[";
                compile_index<T> (stream, registers[a.get()], length,
                                  scale, offset);
                stream << "]";
#endif
                this->endline(stream, usage);
            }

            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
///  @brief Convert the node to latex.
//------------------------------------------------------------------------------
        virtual void to_latex() const {
            std::cout << "r\\_" << reinterpret_cast<size_t> (this->left.get())
                      << "\\left[i\\_"
                      << reinterpret_cast<size_t> (this->right.get())
                      << "\\right]";
        }

//------------------------------------------------------------------------------
///  @brief Convert the node to vizgraph.
///
///  @param[in,out] stream    String buffer stream.
///  @param[in,out] registers List of defined registers.
///  @returns The current node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
                                                      jit::register_map &registers) {
            if (registers.find(this) == registers.end()) {
                const std::string name = jit::to_string('r', this);
                registers[this] = name;
                stream << "    " << name
                       << " [label = \"r_" << reinterpret_cast<size_t> (this->left.get())
                       << "\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;

                auto l = this->left->to_vizgraph(stream, registers);
                stream << "    " << name << " -- " << registers[l.get()] << ";" << std::endl;
                auto r = this->right->to_vizgraph(stream, registers);
                stream << "    " << name << " -- " << registers[r.get()] << ";" << std::endl;
            }

            return this->shared_from_this();
        }

//------------------------------------------------------------------------------
///  @brief Test if node is a constant.
///
///  @returns True if the node is a constant.
//------------------------------------------------------------------------------
        virtual bool is_constant() const {
            return false;
        }

//------------------------------------------------------------------------------
///  @brief Test if node acts like a variable.
///
///  @returns True if the node acts like a variable.
//------------------------------------------------------------------------------
        virtual bool is_all_variables() const {
            return false;
        }

//------------------------------------------------------------------------------
///  @brief Test if the node acts like a power of variable.
///
///  @returns True.
//------------------------------------------------------------------------------
        virtual bool is_power_like() const {
            return true;
        }

//------------------------------------------------------------------------------
///  @brief Get the exponent of a power.
///
///  @returns The exponent of a power like node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> get_power_exponent() const {
            return one<T, SAFE_MATH> ();
        }

//------------------------------------------------------------------------------
///  @brief Check if the args match.
///
///  @param[in] x Node to match.
///  @returns True if the arguments match.
//------------------------------------------------------------------------------
        bool is_arg_match(shared_leaf<T, SAFE_MATH> x) {
            auto temp = index_1D_cast(x);
            return temp.get()                               &&
                   this->right->is_match(temp->get_right()) &&
                   (temp->get_size() == this->get_size())   &&
                   (temp->get_scale() == this->scale)       &&
                   (temp->get_offset() == this->offset);
        }

//------------------------------------------------------------------------------
///  @brief Get x argument scale.
///
///  @returns The scale factor for x.
//------------------------------------------------------------------------------
        T get_scale() const {
            return scale;
        }

//------------------------------------------------------------------------------
///  @brief Get x argument offset.
///
///  @returns The offset factor for x.
//------------------------------------------------------------------------------
        T get_offset() const {
            return offset;
        }

//------------------------------------------------------------------------------
///  @brief Get the size of the buffer.
///
///  @returns The size of the buffer.
//------------------------------------------------------------------------------
        size_t get_size() const {
            return this->left->size();
        }
    };

//------------------------------------------------------------------------------
///  @brief Define index_1D convenience function.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use @ref general_concepts_safe_math operations.
///
///  @param[in] v      Variable to index.
///  @param[in] x      Argument.
///  @param[in] scale  Argument scale factor.
///  @param[in] offset Argument offset factor.
///  @returns A reduced piecewise_1D node.
//------------------------------------------------------------------------------
    template<jit::float_scalar T, bool SAFE_MATH=false>
    shared_leaf<T, SAFE_MATH> index_1D(shared_leaf<T, SAFE_MATH> v,
                                       shared_leaf<T, SAFE_MATH> x,
                                       const T scale,
                                       const T offset) {
        assert(variable_cast(v) && "index_1D requires a variable node for first arg.");
        auto temp = std::make_shared<index_1D_node<T, SAFE_MATH>> (v, x,
                                                                   scale,
                                                                   offset)->reduce();
//  Test for hash collisions.
        for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
            if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
                leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
                leaf_node<T, SAFE_MATH>::caches.nodes[i] = temp;
                return temp;
            } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
                return leaf_node<T, SAFE_MATH>::caches.nodes[i];
            }
        }
#if defined(__clang__) || defined(__GNUC__)
        __builtin_unreachable();
#else
        assert(false && "Should never reach.");
#endif
    }

///  Convenience type alias for shared piecewise 1D nodes.
    template<jit::float_scalar T, bool SAFE_MATH=false>
    using shared_index_1D = std::shared_ptr<index_1D_node<T, SAFE_MATH>>;

//------------------------------------------------------------------------------
///  @brief Cast to a piecewise 1D node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use @ref general_concepts_safe_math operations.
///
///  @param[in] x Leaf node to attempt cast.
///  @returns An attempted dynamic cast.
//------------------------------------------------------------------------------
    template<jit::float_scalar T, bool SAFE_MATH=false>
    shared_index_1D<T, SAFE_MATH> index_1D_cast(shared_leaf<T, SAFE_MATH> x) {
        return std::dynamic_pointer_cast<index_1D_node<T, SAFE_MATH>> (x);
    }
}

#endif /* piecewise_h */
+30 −0
Original line number Diff line number Diff line
@@ -93,6 +93,15 @@ template<jit::float_scalar T> void piecewise_1D() {
                                                       static_cast<T> (6.0)}),
                                      a, 1.0, 0.0);

    auto c = graph::constant<T> (static_cast<T> (2.5));
    auto pconst = graph::piecewise_1D<T> (std::vector<T> ({static_cast<T> (2.0),
                                                           static_cast<T> (4.0),
                                                           static_cast<T> (6.0)}),
                                          c, 1.0, 0.0);
    auto pc_cast = constant_cast(pconst);
    assert(pc_cast.get() && "Expected a constant node.");
    assert(pc_cast->is(6.0) && "Expected a value of 6");

    assert(graph::constant_cast(p1*0.0).get() &&
           "Expected a constant node.");

@@ -687,6 +696,26 @@ template<jit::float_scalar T> void piecewise_2D() {
           "Expected p1 - p3 on the left.");
}

//------------------------------------------------------------------------------
///  @brief Tests for 1D index nodes.
///
///  @tparam T Base type of the calculation.
//------------------------------------------------------------------------------
template<jit::float_scalar T> void index_1D() {
    auto variable = graph::variable<T> (11, "");
    auto arg = graph::variable<T> (1, "");

    auto index = graph::index_1D<T> (variable, arg, 1.0, 0.0);

    variable->set({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0});
    arg->set(static_cast<T> (3.5));

    compile<T> ({graph::variable_cast(variable),
                 graph::variable_cast(arg)},
                {index}, {},
                static_cast<T> (3.0), 0.0);
}

//------------------------------------------------------------------------------
///  @brief Run tests with a specified backend.
///
@@ -695,6 +724,7 @@ template<jit::float_scalar T> void piecewise_2D() {
template<jit::float_scalar T> void run_tests() {
    piecewise_1D<T> ();
    piecewise_2D<T> ();
    index_1D<T> ();
}

//------------------------------------------------------------------------------