diff --git a/graph_c_binding/graph_c_binding.cpp b/graph_c_binding/graph_c_binding.cpp index 54b4d13a12d52e5065d8ad2805976ec6739430f8..8da196f499df550fd8a2e1a2d6fad3555c0aff47 100644 --- a/graph_c_binding/graph_c_binding.cpp +++ b/graph_c_binding/graph_c_binding.cpp @@ -1676,6 +1676,98 @@ extern "C" { } } +//------------------------------------------------------------------------------ +/// @brief Create a 1D index. +/// +/// @param[in] c The graph C context. +/// @param[in] variable The variable to index. +/// @param[in] x_arg The function x argument. +/// @param[in] x_scale Scale factor x argument. +/// @param[in] x_offset Offset factor x argument. +/// @returns A 1D index node. +//------------------------------------------------------------------------------ + graph_node graph_index_1D(STRUCT_TAG graph_c_context *c, + graph_node variable, + graph_node x_arg, + const double x_scale, + const double x_offset) { + switch (c->type) { + case FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::index_1D(d->nodes[variable], + d->nodes[x_arg], + static_cast (x_scale), + static_cast (x_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::index_1D(d->nodes[variable], + d->nodes[x_arg], + static_cast (x_scale), + static_cast (x_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast *> (c); + auto temp = graph::index_1D(d->nodes[variable], + d->nodes[x_arg], + x_scale, x_offset); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast *> (c); + auto temp = graph::index_1D(d->nodes[variable], + d->nodes[x_arg], + x_scale, x_offset); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_FLOAT: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::index_1D(d->nodes[variable], + d->nodes[x_arg], + std::complex (x_scale), + std::complex (x_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::index_1D(d->nodes[variable], + d->nodes[x_arg], + std::complex (x_scale), + std::complex (x_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + + case COMPLEX_DOUBLE: + if (c->safe_math) { + auto d = reinterpret_cast, true> *> (c); + auto temp = graph::index_1D(d->nodes[variable], + d->nodes[x_arg], + std::complex (x_scale), + std::complex (x_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } else { + auto d = reinterpret_cast> *> (c); + auto temp = graph::index_1D(d->nodes[variable], + d->nodes[x_arg], + std::complex (x_scale), + std::complex (x_offset)); + d->nodes[temp.get()] = temp; + return temp.get(); + } + } + } + //****************************************************************************** // JIT //****************************************************************************** diff --git a/graph_c_binding/graph_c_binding.h b/graph_c_binding/graph_c_binding.h index 8de6ddefea7872ab03210660f82d33647f808b31..39c56588da6122e98e63ee0b74059271143e9cff 100644 --- a/graph_c_binding/graph_c_binding.h +++ b/graph_c_binding/graph_c_binding.h @@ -445,6 +445,22 @@ extern "C" { const void *source, const size_t source_size); +//------------------------------------------------------------------------------ +/// @brief Create a 1D index. +/// +/// @param[in] c The graph C context. +/// @param[in] variable The variable to index. +/// @param[in] x_arg The function x argument. +/// @param[in] x_scale Scale factor x argument. +/// @param[in] x_offset Offset factor x argument. +/// @returns A 1D index node. +//------------------------------------------------------------------------------ + graph_node graph_index_1D(STRUCT_TAG graph_c_context *c, + graph_node variable, + graph_node x_arg, + const double x_scale, + const double x_offset); + //------------------------------------------------------------------------------ /// @brief Create 2D piecewise node with complex arguments. /// diff --git a/graph_fortran_binding/graph_fortran_binding.f90 b/graph_fortran_binding/graph_fortran_binding.f90 index c34de7cee3bf2bb8ec64b1f048b44f28d5683e10..de9c3f459ff9ef2ba42b8381e0243e8753f7054c 100644 --- a/graph_fortran_binding/graph_fortran_binding.f90 +++ b/graph_fortran_binding/graph_fortran_binding.f90 @@ -145,6 +145,7 @@ piecewise_2D_double, & piecewise_2D_cfloat, & piecewise_2D_cdouble + PROCEDURE :: index_1D => graph_context_index_1D PROCEDURE :: get_max_concurrency => graph_context_get_max_concurrency PROCEDURE :: set_device_number => graph_context_set_device_number PROCEDURE :: add_pre_item => graph_context_add_pre_item @@ -593,7 +594,7 @@ END FUNCTION !------------------------------------------------------------------------------- -!> @brief Create 1D piecewise node with complex double buffer. +!> @brief Create 1D piecewise node. !> !> @param[in] c The graph C context. !> @param[in] arg The left operand. @@ -650,6 +651,27 @@ INTEGER(C_LONG), VALUE :: source_size END FUNCTION +!------------------------------------------------------------------------------- +!> @brief Create 1D index node. +!> +!> @param[in] c The graph C context. +!> @param[in] variable The variable to index. +!> @param[in] arg The left operand. +!> @param[in] scale Scale factor argument. +!> @param[in] offset Offset factor argument. +!> @returns A 1D piecewise node. +!------------------------------------------------------------------------------- + TYPE(C_PTR) FUNCTION graph_index_1D(c, variable, arg, scale, offset) & + BIND(C, NAME='graph_index_1D') + USE, INTRINSIC :: ISO_C_BINDING + IMPLICIT NONE + TYPE(C_PTR), VALUE :: c + TYPE(C_PTR), VALUE :: variable + TYPE(C_PTR), VALUE :: arg + REAL(C_DOUBLE), VALUE :: scale + REAL(C_DOUBLE), VALUE :: offset + END FUNCTION + !------------------------------------------------------------------------------- !> @brief Get the maximum number of concurrent devices. !> @@ -1583,7 +1605,7 @@ !> @param[in] scale Scale factor argument. !> @param[in] offset Offset factor argument. !> @param[in] source Source buffer to fill elements. -!> @returns random(state) +!> @returns piecewise_1D node. !------------------------------------------------------------------------------- FUNCTION graph_context_piecewise_1D_float(this, arg, scale, offset, & source) @@ -1613,7 +1635,7 @@ !> @param[in] scale Scale factor argument. !> @param[in] offset Offset factor argument. !> @param[in] source Source buffer to fill elements. -!> @returns random(state) +!> @returns piecewise_1D node. !------------------------------------------------------------------------------- FUNCTION graph_context_piecewise_1D_double(this, arg, scale, offset, & source) @@ -1643,7 +1665,7 @@ !> @param[in] scale Scale factor argument. !> @param[in] offset Offset factor argument. !> @param[in] source Source buffer to fill elements. -!> @returns random(state) +!> @returns piecewise_1D node. !------------------------------------------------------------------------------- FUNCTION graph_context_piecewise_1D_cfloat(this, arg, scale, offset, & source) @@ -1673,7 +1695,7 @@ !> @param[in] scale Scale factor argument. !> @param[in] offset Offset factor argument. !> @param[in] source Source buffer to fill elements. -!> @returns random(state) +!> @returns piecewise_1D node. !------------------------------------------------------------------------------- FUNCTION graph_context_piecewise_1D_cdouble(this, arg, scale, offset, & source) @@ -1706,7 +1728,7 @@ !> @param[in] y_scale Scale factor for y argument. !> @param[in] y_offset Offset factor for y argument. !> @param[in] source Source buffer to fill elements. -!> @returns random(state) +!> @returns piecewise_2D node. !------------------------------------------------------------------------------- FUNCTION graph_context_piecewise_2D_float(this, & x_arg, x_scale, x_offset, & @@ -1747,7 +1769,7 @@ !> @param[in] y_scale Scale factor for y argument. !> @param[in] y_offset Offset factor for y argument. !> @param[in] source Source buffer to fill elements. -!> @returns random(state) +!> @returns piecewise_2D node. !------------------------------------------------------------------------------- FUNCTION graph_context_piecewise_2D_double(this, & x_arg, x_scale, x_offset, & @@ -1788,7 +1810,7 @@ !> @param[in] y_scale Scale factor for y argument. !> @param[in] y_offset Offset factor for y argument. !> @param[in] source Source buffer to fill elements. -!> @returns random(state) +!> @returns piecewise_2D node. !------------------------------------------------------------------------------- FUNCTION graph_context_piecewise_2D_cfloat(this, & x_arg, x_scale, x_offset, & @@ -1829,7 +1851,7 @@ !> @param[in] y_scale Scale factor for y argument. !> @param[in] y_offset Offset factor for y argument. !> @param[in] source Source buffer to fill elements. -!> @returns random(state) +!> @returns piecewise_2D node. !------------------------------------------------------------------------------- FUNCTION graph_context_piecewise_2D_cdouble(this, & x_arg, x_scale, x_offset, & @@ -1859,6 +1881,34 @@ END FUNCTION +!------------------------------------------------------------------------------- +!> @brief Create 1D index node with float buffer. +!> +!> @param[in,out] this @ref graph_context instance. +!> @param[in] variable The variable +!> @param[in] arg The function argument. +!> @param[in] scale Scale factor argument. +!> @param[in] offset Offset factor argument. +!> @returns index_1D node. +!------------------------------------------------------------------------------- + FUNCTION graph_context_index_1D(this, variable, arg, scale, offset) + + IMPLICIT NONE + +! Declare Arguments + TYPE(C_PTR) :: graph_context_index_1D + CLASS(graph_context), INTENT(INOUT) :: this + TYPE(C_PTR), INTENT(IN) :: variable + TYPE(C_PTR), INTENT(IN) :: arg + REAL(C_DOUBLE) :: scale + REAL(C_DOUBLE) :: offset + +! Start of executable. + graph_context_index_1D = & + graph_index_1D(this%c_context, variable, arg, scale, offset) + + END FUNCTION + !******************************************************************************* ! JIT !******************************************************************************* diff --git a/graph_framework/backend.hpp b/graph_framework/backend.hpp index 7ceb09a4cbb4bb2720e245d8d1d95b585804726f..39170d0c7cfc33a054c1733c040ada6d379ac412 100644 --- a/graph_framework/backend.hpp +++ b/graph_framework/backend.hpp @@ -292,6 +292,38 @@ namespace backend { return true; } +//------------------------------------------------------------------------------ +/// @brief Index row. +/// +/// @param[in] index The row index. +/// @param[in] num_columns The number of coils. +/// @returns A buffer containing the row. +//------------------------------------------------------------------------------ + buffer index_row(const size_t index, const size_t num_columns) { + buffer b(num_columns); + const size_t num_rows = size()/num_columns; + for (size_t j = 0; j < num_columns; j++) { + b[j] = memory[index*num_rows + j]; + } + return b; + } + +//------------------------------------------------------------------------------ +/// @brief Index column. +/// +/// @param[in] index The row index. +/// @param[in] num_columns The number of coils. +/// @returns A buffer containing the row. +//------------------------------------------------------------------------------ + buffer index_column(const size_t index, const size_t num_columns) { + const size_t num_rows = size()/num_columns; + buffer b(num_rows); + for (size_t i = 0; i < num_rows; i++) { + b[i] = memory[i*num_rows + index]; + } + return b; + } + //------------------------------------------------------------------------------ /// @brief Add row operation. /// diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 7b1b9f9870625158cf54939269e56d2fe232dcc4..f5960a22a3902eaa4fe38ddb077de293511dbc41 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -90,12 +90,12 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ template class piecewise_1D_node final : public straight_node { + private: /// Scale factor for the argument. const T scale; /// Offset factor for the argument. const T offset; - private: //------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. /// @@ -190,6 +190,13 @@ void compile_index(std::ostringstream &stream, /// @returns A reduced representation of the node. //------------------------------------------------------------------------------ virtual shared_leaf reduce() { + if (constant_cast(this->arg).get()) { + const T arg = (this->arg->evaluate().at(0) + offset)/scale; + const size_t i = std::min(static_cast (std::real(arg)), + this->get_size() - 1); + return constant (leaf_node::caches.backends[data_hash][i]); + } + if (evaluate().is_same()) { return constant (evaluate().at(0)); } @@ -291,8 +298,8 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ /// @brief Compile the node. /// -/// This node first evaluates the value of the argument then chooses the correct -/// piecewise index. This assumes that the argument is +/// This node first evaluates the value of the argument then chooses the +/// correct piecewise index. This assumes that the argument is /// /// x' = (x - xmin)/dx (1) /// @@ -824,6 +831,31 @@ void compile_index(std::ostringstream &stream, /// @returns A reduced representation of the node. //------------------------------------------------------------------------------ virtual shared_leaf reduce() { + if (constant_cast(this->left).get() && + constant_cast(this->right).get()) { + const T l = (this->left->evaluate().at(0) + x_offset)/x_scale; + const T r = (this->right->evaluate().at(0) + y_offset)/y_scale; + const size_t i = std::min(static_cast (std::real(l)), + this->get_num_rows() - 1); + const size_t j = std::min(static_cast (std::real(r)), + this->get_num_columns() - 1); + return constant (leaf_node::caches.backends[data_hash][i*this->get_num_columns() + j]); + } else if (constant_cast(this->left).get()) { + const T l = (this->left->evaluate().at(0) + x_offset)/x_scale; + const size_t i = std::min(static_cast (std::real(l)), + this->get_num_rows() - 1); + + return piecewise_1D(leaf_node::caches.backends[data_hash].index_row(i, this->get_num_columns()), + this->right, y_scale, y_offset); + } else if (constant_cast(this->right).get()) { + const T r = (this->right->evaluate().at(0) + y_offset)/y_scale; + const size_t j = std::min(static_cast (std::real(r)), + this->get_num_columns() - 1); + + return piecewise_1D(leaf_node::caches.backends[data_hash].index_column(j, this->get_num_columns()), + this->left, x_scale, x_offset); + } + if (evaluate().is_same()) { return constant (evaluate().at(0)); } @@ -1323,6 +1355,325 @@ void compile_index(std::ostringstream &stream, shared_piecewise_2D piecewise_2D_cast(shared_leaf x) { return std::dynamic_pointer_cast> (x); } + +//****************************************************************************** +// 1D Index node. +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Class representing a 1D index. +/// +/// This class is used to implement indexes into an array. +/// +/// Indicies are selected by +/// +/// x_norm' = (x - xmin)/dx (1) +/// +/// @tparam T Base type of the calculation. +/// @tparam SAFE_MATH Use @ref general_concepts_safe_math operations. +//------------------------------------------------------------------------------ + template + class index_1D_node final : public branch_node { + private: +/// Scale factor for the argument. + const T scale; +/// Offset factor for the argument. + const T offset; + +//------------------------------------------------------------------------------ +/// @brief Convert node pointer to a string with the argument. +/// +/// @param[in] v Value to index. +/// @param[in] x Argument. +/// @return A string rep of the node. +//------------------------------------------------------------------------------ + static std::string to_string(shared_leaf v, + shared_leaf x) { + return jit::format_to_string(v->get_hash()) + "[" + + jit::format_to_string(x->get_hash()) + "]"; + } + + public: +//------------------------------------------------------------------------------ +/// @brief Construct a 1D index. +/// +/// @param[in] var Node to index. +/// @param[in] x Argument. +/// @param[in] scale Scale factor for the argument. +/// @param[in] offset Offset factor for the argument. +//------------------------------------------------------------------------------ + index_1D_node(shared_leaf var, + shared_leaf x, + const T scale, + const T offset) : + branch_node (var, x, index_1D_node::to_string(var, x)), + scale(scale), offset(offset) {} + +//------------------------------------------------------------------------------ +/// @brief Evaluate the results of the piecewise constant. +/// +/// Evaluate functions are only used by the minimization. So this node does not +/// evaluate the argument. Instead this only returns the data as if it were a +/// constant. +/// +/// @returns The evaluated value of the node. +//------------------------------------------------------------------------------ + virtual backend::buffer evaluate() { + return this->right->evaluate(); + } + +//------------------------------------------------------------------------------ +/// @brief Reduction method. +/// +/// If all the values in the data buffer are the same. Reduce to a single +/// constant. +/// +/// @returns A reduced representation of the node. +//------------------------------------------------------------------------------ + virtual shared_leaf reduce() { + return this->shared_from_this(); + } + +//------------------------------------------------------------------------------ +/// @brief Transform node to derivative. +/// +/// @param[in] x The variable to take the derivative to. +/// @return The derivative of the node. +//------------------------------------------------------------------------------ + virtual shared_leaf df(shared_leaf x) { + return constant (static_cast (this->is_match(x))); + } + +//------------------------------------------------------------------------------ +/// @brief the node. +/// +/// This node first evaluates the value of the argument then chooses the +/// correct index of the variable. +/// +/// x' = (x - xmin)/dx (1) +/// +/// @param[in,out] stream String buffer stream. +/// @param[in,out] registers List of defined registers. +/// @param[in,out] indices List of defined indices. +/// @param[in] usage List of register usage count. +/// @returns The current node. +//------------------------------------------------------------------------------ + virtual shared_leaf + compile(std::ostringstream &stream, + jit::register_map ®isters, + jit::register_map &indices, + const jit::register_usage &usage) { + if (registers.find(this) == registers.end()) { +#ifdef USE_INDEX_CACHE + if (indices.find(this->right.get()) == indices.end()) { +#endif + const size_t length = variable_cast(this->left)->size(); + shared_leaf a = this->right->compile(stream, + registers, + indices, + usage); +#ifdef USE_INDEX_CACHE + indices[a.get()] = jit::to_string('i', a.get()); + stream << " const " + << jit::smallest_int_type (length) << " " + << indices[a.get()] << " = "; + compile_index (stream, registers[a.get()], length, + scale, offset); + a->endline(stream, usage); + } +#endif + + registers[this] = jit::to_string('r', this); + stream << " const "; + jit::add_type (stream); + auto var = this->left->compile(stream, + registers, + indices, + usage); + stream << " " << registers[this] << " = " + << jit::to_string('v', var.get()); +#ifdef USE_INDEX_CACHE + stream << "[" << indices[this->right.get()] << "]"; +#else + stream << "["; + compile_index (stream, registers[a.get()], length, + scale, offset); + stream << "]"; +#endif + this->endline(stream, usage); + } + + return this->shared_from_this(); + } + +//------------------------------------------------------------------------------ +/// @brief Convert the node to latex. +//------------------------------------------------------------------------------ + virtual void to_latex() const { + std::cout << "r\\_" << reinterpret_cast (this->left.get()) + << "\\left[i\\_" + << reinterpret_cast (this->right.get()) + << "\\right]"; + } + +//------------------------------------------------------------------------------ +/// @brief Convert the node to vizgraph. +/// +/// @param[in,out] stream String buffer stream. +/// @param[in,out] registers List of defined registers. +/// @returns The current node. +//------------------------------------------------------------------------------ + virtual shared_leaf to_vizgraph(std::stringstream &stream, + jit::register_map ®isters) { + if (registers.find(this) == registers.end()) { + const std::string name = jit::to_string('r', this); + registers[this] = name; + stream << " " << name + << " [label = \"r_" << reinterpret_cast (this->left.get()) + << "\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl; + + auto l = this->left->to_vizgraph(stream, registers); + stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl; + auto r = this->right->to_vizgraph(stream, registers); + stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl; + } + + return this->shared_from_this(); + } + +//------------------------------------------------------------------------------ +/// @brief Test if node is a constant. +/// +/// @returns True if the node is a constant. +//------------------------------------------------------------------------------ + virtual bool is_constant() const { + return false; + } + +//------------------------------------------------------------------------------ +/// @brief Test if node acts like a variable. +/// +/// @returns True if the node acts like a variable. +//------------------------------------------------------------------------------ + virtual bool is_all_variables() const { + return false; + } + +//------------------------------------------------------------------------------ +/// @brief Test if the node acts like a power of variable. +/// +/// @returns True. +//------------------------------------------------------------------------------ + virtual bool is_power_like() const { + return true; + } + +//------------------------------------------------------------------------------ +/// @brief Get the exponent of a power. +/// +/// @returns The exponent of a power like node. +//------------------------------------------------------------------------------ + virtual shared_leaf get_power_exponent() const { + return one (); + } + +//------------------------------------------------------------------------------ +/// @brief Check if the args match. +/// +/// @param[in] x Node to match. +/// @returns True if the arguments match. +//------------------------------------------------------------------------------ + bool is_arg_match(shared_leaf x) { + auto temp = index_1D_cast(x); + return temp.get() && + this->right->is_match(temp->get_right()) && + (temp->get_size() == this->get_size()) && + (temp->get_scale() == this->scale) && + (temp->get_offset() == this->offset); + } + +//------------------------------------------------------------------------------ +/// @brief Get x argument scale. +/// +/// @returns The scale factor for x. +//------------------------------------------------------------------------------ + T get_scale() const { + return scale; + } + +//------------------------------------------------------------------------------ +/// @brief Get x argument offset. +/// +/// @returns The offset factor for x. +//------------------------------------------------------------------------------ + T get_offset() const { + return offset; + } + +//------------------------------------------------------------------------------ +/// @brief Get the size of the buffer. +/// +/// @returns The size of the buffer. +//------------------------------------------------------------------------------ + size_t get_size() const { + return this->left->size(); + } + }; + +//------------------------------------------------------------------------------ +/// @brief Define index_1D convenience function. +/// +/// @tparam T Base type of the calculation. +/// @tparam SAFE_MATH Use @ref general_concepts_safe_math operations. +/// +/// @param[in] v Variable to index. +/// @param[in] x Argument. +/// @param[in] scale Argument scale factor. +/// @param[in] offset Argument offset factor. +/// @returns A reduced piecewise_1D node. +//------------------------------------------------------------------------------ + template + shared_leaf index_1D(shared_leaf v, + shared_leaf x, + const T scale, + const T offset) { + assert(variable_cast(v) && "index_1D requires a variable node for first arg."); + auto temp = std::make_shared> (v, x, + scale, + offset)->reduce(); +// Test for hash collisions. + for (size_t i = temp->get_hash(); i < std::numeric_limits::max(); i++) { + if (leaf_node::caches.nodes.find(i) == + leaf_node::caches.nodes.end()) { + leaf_node::caches.nodes[i] = temp; + return temp; + } else if (temp->is_match(leaf_node::caches.nodes[i])) { + return leaf_node::caches.nodes[i]; + } + } +#if defined(__clang__) || defined(__GNUC__) + __builtin_unreachable(); +#else + assert(false && "Should never reach."); +#endif + } + +/// Convenience type alias for shared piecewise 1D nodes. + template + using shared_index_1D = std::shared_ptr>; + +//------------------------------------------------------------------------------ +/// @brief Cast to a piecewise 1D node. +/// +/// @tparam T Base type of the calculation. +/// @tparam SAFE_MATH Use @ref general_concepts_safe_math operations. +/// +/// @param[in] x Leaf node to attempt cast. +/// @returns An attempted dynamic cast. +//------------------------------------------------------------------------------ + template + shared_index_1D index_1D_cast(shared_leaf x) { + return std::dynamic_pointer_cast> (x); + } } #endif /* piecewise_h */ diff --git a/graph_tests/c_binding_test.c b/graph_tests/c_binding_test.c index 7d9da5b2e40614109188a7d9a01d5b1f96d35003..a77ac69cac79a76d34f8637922e87e1076ac5b85 100644 --- a/graph_tests/c_binding_test.c +++ b/graph_tests/c_binding_test.c @@ -124,6 +124,7 @@ void run_tests(const enum graph_type type, graph_node p2; graph_node i = graph_variable(c_context, 1, "i"); graph_node j = graph_variable(c_context, 1, "j"); + graph_node variable = graph_variable(c_context, 3, "var"); switch (c_context->type) { case FLOAT: { float value1[3] = {2.0, 4.0, 6.0}; @@ -132,6 +133,7 @@ void run_tests(const enum graph_type type, float value4 = 2.5; graph_set_variable(c_context, i, &value3); graph_set_variable(c_context, j, &value4); + graph_set_variable(c_context, variable, &value1); p1 = graph_piecewise_1D(c_context, i, 1.0, 0.0, value1, 3); p2 = graph_piecewise_2D(c_context, 3, j, 1.0, 0.0, i, 1.0, 0.0, value2, 9); break; @@ -144,6 +146,7 @@ void run_tests(const enum graph_type type, double value4 = 2.5; graph_set_variable(c_context, i, &value3); graph_set_variable(c_context, j, &value4); + graph_set_variable(c_context, variable, &value1); p1 = graph_piecewise_1D(c_context, i, 1.0, 0.0, value1, 3); p2 = graph_piecewise_2D(c_context, 3, j, 1.0, 0.0, i, 1.0, 0.0, value2, 9); break; @@ -158,6 +161,7 @@ void run_tests(const enum graph_type type, float complex value4 = CMPLXF(2.5, 0.0); graph_set_variable(c_context, i, &value3); graph_set_variable(c_context, j, &value4); + graph_set_variable(c_context, variable, &value1); p1 = graph_piecewise_1D(c_context, i, 1.0, 0.0, value1, 3); p2 = graph_piecewise_2D(c_context, 3, j, 1.0, 0.0, i, 1.0, 0.0, value2, 9); break; @@ -172,14 +176,16 @@ void run_tests(const enum graph_type type, double complex value4 = CMPLXF(2.5, 0.0); graph_set_variable(c_context, i, &value3); graph_set_variable(c_context, j, &value4); + graph_set_variable(c_context, variable, &value1); p1 = graph_piecewise_1D(c_context, i, 1.0, 0.0, value1, 3); p2 = graph_piecewise_2D(c_context, 3, j, 1.0, 0.0, i, 1.0, 0.0, value2, 9); break; } } + graph_node i1 = graph_index_1D(c_context, variable, i, 1.0, 0.0); - graph_node inputs2[2] = {i, j}; - graph_node outputs2[2] = {p1, p2}; + graph_node inputs2[3] = {i, j, variable}; + graph_node outputs2[3] = {p1, p2, i1}; graph_node *map_inputs2 = NULL; graph_node *map_outputs2 = NULL; @@ -195,8 +201,8 @@ void run_tests(const enum graph_type type, map_inputs, map_outputs, 0, NULL, "c_binding", 1); graph_add_item(c_context, - inputs2, 2, - outputs2, 2, + inputs2, 3, + outputs2, 3, map_inputs2, map_outputs2, 0, NULL, "c_binding_piecewise", 1); graph_add_converge_item(c_context, &z, 1, @@ -239,7 +245,7 @@ void run_tests(const enum graph_type type, switch (c_context->type) { case FLOAT: { - float value[9]; + float value[10]; graph_copy_to_host(c_context, y, value); graph_copy_to_host(c_context, dydx, value + 1); graph_copy_to_host(c_context, dydm, value + 2); @@ -249,6 +255,7 @@ void run_tests(const enum graph_type type, graph_copy_to_host(c_context, z, value + 6); graph_copy_to_host(c_context, p1, value + 7); graph_copy_to_host(c_context, p2, value + 8); + graph_copy_to_host(c_context, i1, value + 9); assert(value[0] == 0.5f*2.0f + 0.2f && "Value of y does not match."); assert(value[1] == 0.5f && "Value of dydx does not match."); assert(value[2] == 2.0f && "Value of dydm does not match."); @@ -262,11 +269,12 @@ void run_tests(const enum graph_type type, assert(value[6] == 1.0f && "Value of root does not match."); assert(value[7] == 4.0f && "Value of p1 does not match."); assert(value[8] == 8.0f && "Value of p2 does not match."); + assert(value[9] == 4.0f && "Value of i1 does not match."); break; } case DOUBLE: { - double value[9]; + double value[10]; graph_copy_to_host(c_context, y, value); graph_copy_to_host(c_context, dydx, value + 1); graph_copy_to_host(c_context, dydm, value + 2); @@ -276,6 +284,7 @@ void run_tests(const enum graph_type type, graph_copy_to_host(c_context, z, value + 6); graph_copy_to_host(c_context, p1, value + 7); graph_copy_to_host(c_context, p2, value + 8); + graph_copy_to_host(c_context, i1, value + 9); assert(value[0] == 0.5*2.0 + 0.2 && "Value of y does not match."); assert(value[1] == 0.5 && "Value of dydx does not match."); assert(value[2] == 2.0 && "Value of dydm does not match."); @@ -289,11 +298,12 @@ void run_tests(const enum graph_type type, assert(value[6] == 1.0 && "Value of root does not match."); assert(value[7] == 4.0 && "Value of p1 does not match."); assert(value[8] == 8.0 && "Value of p2 does not match."); + assert(value[9] == 4.0 && "Value of i1 does not match."); break; } case COMPLEX_FLOAT: { - float complex value[9]; + float complex value[10]; graph_copy_to_host(c_context, y, value); graph_copy_to_host(c_context, dydx, value + 1); graph_copy_to_host(c_context, dydm, value + 2); @@ -303,6 +313,7 @@ void run_tests(const enum graph_type type, graph_copy_to_host(c_context, z, value + 6); graph_copy_to_host(c_context, p1, value + 7); graph_copy_to_host(c_context, p2, value + 8); + graph_copy_to_host(c_context, i1, value + 9); assert(crealf(value[0]) == 0.5f*2.0f + 0.2f && "Value of y does not match."); assert(crealf(value[1]) == 0.5f && "Value of dydx does not match."); assert(crealf(value[2]) == 2.0f && "Value of dydm does not match."); @@ -316,11 +327,12 @@ void run_tests(const enum graph_type type, assert(crealf(value[6]) == 1.0f && "Value of root does not match."); assert(crealf(value[7]) == 4.0f && "Value of p1 does not match."); assert(crealf(value[8]) == 8.0f && "Value of p2 does not match."); + assert(crealf(value[9]) == 4.0f && "Value of p1 does not match."); break; } case COMPLEX_DOUBLE: { - double complex value[9]; + double complex value[10]; graph_copy_to_host(c_context, y, value); graph_copy_to_host(c_context, dydx, value + 1); graph_copy_to_host(c_context, dydm, value + 2); @@ -330,6 +342,7 @@ void run_tests(const enum graph_type type, graph_copy_to_host(c_context, z, value + 6); graph_copy_to_host(c_context, p1, value + 7); graph_copy_to_host(c_context, p2, value + 8); + graph_copy_to_host(c_context, i1, value + 9); assert(creal(value[0]) == 0.5*2.0 + 0.2 && "Value of y does not match."); assert(creal(value[1]) == 0.5 && "Value of dydx does not match."); assert(creal(value[2]) == 2.0 && "Value of dydm does not match."); @@ -343,6 +356,7 @@ void run_tests(const enum graph_type type, assert(creal(value[6]) == 1.0 && "Value of root does not match."); assert(creal(value[7]) == 4.0 && "Value of p1 does not match."); assert(creal(value[8]) == 8.0 && "Value of p2 does not match."); + assert(creal(value[9]) == 4.0 && "Value of p1 does not match."); break; } } diff --git a/graph_tests/f_binding_test.f90 b/graph_tests/f_binding_test.f90 index 1c3d962155e5bf106f897a63261cf82d9d9463f3..ae850f05c559b5122a704a6c44cf5d2c6bb84641 100644 --- a/graph_tests/f_binding_test.f90 +++ b/graph_tests/f_binding_test.f90 @@ -89,6 +89,8 @@ REAL(C_FLOAT), DIMENSION(3,3) :: buffer2D TYPE(C_PTR) :: p2 TYPE(C_PTR) :: j + TYPE(C_PTR) :: variable + TYPE(C_PTR) :: i1 TYPE(C_PTR) :: z TYPE(C_PTR) :: root TYPE(C_PTR) :: root2 @@ -156,6 +158,10 @@ p2 = graph%piecewise_2D(j, 1.0_C_DOUBLE, 0.0_C_DOUBLE, & i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer2D) + variable = graph%variable(3_C_LONG, 'var' // C_NULL_CHAR) + CALL graph%set_variable(variable, buffer1D) + i1 = graph%index_1D(variable, i, 1.0_C_DOUBLE, 0.0_C_DOUBLE) + z = graph%variable(1_C_LONG, 'z' // C_NULL_CHAR) root = graph%sub(graph%pow(z, graph%constant(3.0_C_DOUBLE)), & graph%pow(z, graph%constant(2.0_C_DOUBLE))) @@ -176,8 +182,9 @@ graph_ptr(dydy) & /), graph_null_array, graph_null_array, C_NULL_PTR, & 'f_binding' // C_NULL_CHAR, 1_C_LONG) - CALL graph%add_item((/ graph_ptr(i), graph_ptr(j) /), & - (/ graph_ptr(p1), graph_ptr(p2) /), & + CALL graph%add_item((/ graph_ptr(i), graph_ptr(j), & + graph_ptr(variable) /), & + (/ graph_ptr(p1), graph_ptr(p2), graph_ptr(i1) /), & graph_null_array, graph_null_array, C_NULL_PTR, & 'c_binding_piecewise' // C_NULL_CHAR, 1_C_LONG) CALL graph%add_converge_item((/ graph_ptr(z) /), (/ graph_ptr(root2) /), & @@ -218,6 +225,8 @@ CALL assert(value(1) .eq. 4.0_C_FLOAT, 'Value of p1 does not match.') CALL graph%copy_to_host(p2, value) CALL assert(value(1) .eq. 8.0_C_FLOAT, 'Value of p2 does not match.') + CALL graph%copy_to_host(i1, value) + CALL assert(value(1) .eq. 4.0_C_FLOAT, 'Value of i1 does not match.') DEALLOCATE(graph) @@ -261,6 +270,8 @@ REAL(C_DOUBLE), DIMENSION(3,3) :: buffer2D TYPE(C_PTR) :: p2 TYPE(C_PTR) :: j + TYPE(C_PTR) :: variable + TYPE(C_PTR) :: i1 TYPE(C_PTR) :: z TYPE(C_PTR) :: root TYPE(C_PTR) :: root2 @@ -328,6 +339,10 @@ p2 = graph%piecewise_2D(j, 1.0_C_DOUBLE, 0.0_C_DOUBLE, & i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer2D) + variable = graph%variable(3_C_LONG, 'var' // C_NULL_CHAR) + CALL graph%set_variable(variable, buffer1D) + i1 = graph%index_1D(variable, i, 1.0_C_DOUBLE, 0.0_C_DOUBLE) + z = graph%variable(1_C_LONG, 'z' // C_NULL_CHAR) root = graph%sub(graph%pow(z, graph%constant(3.0_C_DOUBLE)), & graph%pow(z, graph%constant(2.0_C_DOUBLE))) @@ -348,8 +363,9 @@ graph_ptr(dydy) & /), graph_null_array, graph_null_array, C_NULL_PTR, & 'f_binding' // C_NULL_CHAR, 1_C_LONG) - CALL graph%add_item((/ graph_ptr(i), graph_ptr(j) /), & - (/ graph_ptr(p1), graph_ptr(p2) /), & + CALL graph%add_item((/ graph_ptr(i), graph_ptr(j), & + graph_ptr(variable) /), & + (/ graph_ptr(p1), graph_ptr(p2), graph_ptr(i1) /), & graph_null_array, graph_null_array, C_NULL_PTR, & 'c_binding_piecewise' // C_NULL_CHAR, 1_C_LONG) CALL graph%add_converge_item((/ graph_ptr(z) /), (/ graph_ptr(root2) /), & @@ -390,6 +406,8 @@ CALL assert(value(1) .eq. 4.0_C_DOUBLE, 'Value of p1 does not match.') CALL graph%copy_to_host(p2, value) CALL assert(value(1) .eq. 8.0_C_DOUBLE, 'Value of p2 does not match.') + CALL graph%copy_to_host(i1, value) + CALL assert(value(1) .eq. 4.0_C_DOUBLE, 'Value of i1 does not match.') DEALLOCATE(graph) @@ -433,6 +451,8 @@ COMPLEX(C_FLOAT_COMPLEX), DIMENSION(3,3) :: buffer2D TYPE(C_PTR) :: p2 TYPE(C_PTR) :: j + TYPE(C_PTR) :: variable + TYPE(C_PTR) :: i1 TYPE(C_PTR) :: z TYPE(C_PTR) :: root TYPE(C_PTR) :: root2 @@ -504,6 +524,10 @@ p2 = graph%piecewise_2D(j, 1.0_C_DOUBLE, 0.0_C_DOUBLE, & i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer2D) + variable = graph%variable(3_C_LONG, 'var' // C_NULL_CHAR) + CALL graph%set_variable(variable, buffer1D) + i1 = graph%index_1D(variable, i, 1.0_C_DOUBLE, 0.0_C_DOUBLE) + z = graph%variable(1_C_LONG, 'z' // C_NULL_CHAR) root = graph%sub(graph%pow(z, graph%constant(3.0_C_DOUBLE)), & graph%pow(z, graph%constant(2.0_C_DOUBLE))) @@ -524,8 +548,9 @@ graph_ptr(dydy) & /), graph_null_array, graph_null_array, C_NULL_PTR, & 'f_binding' // C_NULL_CHAR, 1_C_LONG) - CALL graph%add_item((/ graph_ptr(i), graph_ptr(j) /), & - (/ graph_ptr(p1), graph_ptr(p2) /), & + CALL graph%add_item((/ graph_ptr(i), graph_ptr(j), & + graph_ptr(variable) /), & + (/ graph_ptr(p1), graph_ptr(p2), graph_ptr(i1) /), & graph_null_array, graph_null_array, C_NULL_PTR, & 'c_binding_piecewise' // C_NULL_CHAR, 1_C_LONG) CALL graph%add_converge_item((/ graph_ptr(z) /), (/ graph_ptr(root2) /), & @@ -573,6 +598,9 @@ CALL graph%copy_to_host(p2, value) CALL assert(REAL(value(1)) .eq. 8.0_C_FLOAT, & 'Value of p2 does not match.') + CALL graph%copy_to_host(i1, value) + CALL assert(REAL(value(1)) .eq. 4.0_C_FLOAT, & + 'Value of i1 does not match.') DEALLOCATE(graph) @@ -616,6 +644,8 @@ COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(3,3) :: buffer2D TYPE(C_PTR) :: p2 TYPE(C_PTR) :: j + TYPE(C_PTR) :: variable + TYPE(C_PTR) :: i1 TYPE(C_PTR) :: z TYPE(C_PTR) :: root TYPE(C_PTR) :: root2 @@ -679,6 +709,7 @@ CMPLX(4.0, 0.0, KIND=C_DOUBLE), & CMPLX(6.0, 0.0, KIND=C_DOUBLE) & /) + p1 = graph%piecewise_1D(i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer1D) j = graph%variable(1_C_LONG, 'j' // C_NULL_CHAR) @@ -698,6 +729,10 @@ p2 = graph%piecewise_2D(j, 1.0_C_DOUBLE, 0.0_C_DOUBLE, & i, 1.0_C_DOUBLE, 0.0_C_DOUBLE, buffer2D) + variable = graph%variable(3_C_LONG, 'var' // C_NULL_CHAR) + CALL graph%set_variable(variable, buffer1D) + i1 = graph%index_1D(variable, i, 1.0_C_DOUBLE, 0.0_C_DOUBLE) + z = graph%variable(1_C_LONG, 'z' // C_NULL_CHAR) root = graph%sub(graph%pow(z, graph%constant(3.0_C_DOUBLE)), & graph%pow(z, graph%constant(2.0_C_DOUBLE))) @@ -718,8 +753,9 @@ graph_ptr(dydy) & /), graph_null_array, graph_null_array, C_NULL_PTR, & 'f_binding' // C_NULL_CHAR, 1_C_LONG) - CALL graph%add_item((/ graph_ptr(i), graph_ptr(j) /), & - (/ graph_ptr(p1), graph_ptr(p2) /), & + CALL graph%add_item((/ graph_ptr(i), graph_ptr(j), & + graph_ptr(variable) /), & + (/ graph_ptr(p1), graph_ptr(p2), graph_ptr(i1) /), & graph_null_array, graph_null_array, C_NULL_PTR, & 'c_binding_piecewise' // C_NULL_CHAR, 1_C_LONG) CALL graph%add_converge_item((/ graph_ptr(z) /), (/ graph_ptr(root2) /), & @@ -768,6 +804,9 @@ CALL graph%copy_to_host(p2, value) CALL assert(DBLE(value(1)) .eq. 8.0_C_DOUBLE, & 'Value of p2 does not match.') + CALL graph%copy_to_host(i1, value) + CALL assert(DBLE(value(1)) .eq. 4.0_C_DOUBLE, & + 'Value of i1 does not match.') DEALLOCATE(graph) diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index 904187381099071a5093e76294bbd53b1a574430..5cff52b132c9ddfd101541b58131fa2f4e2191a0 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -93,6 +93,15 @@ template void piecewise_1D() { static_cast (6.0)}), a, 1.0, 0.0); + auto c = graph::constant (static_cast (2.5)); + auto pconst = graph::piecewise_1D (std::vector ({static_cast (2.0), + static_cast (4.0), + static_cast (6.0)}), + c, 1.0, 0.0); + auto pc_cast = constant_cast(pconst); + assert(pc_cast.get() && "Expected a constant node."); + assert(pc_cast->is(6.0) && "Expected a value of 6"); + assert(graph::constant_cast(p1*0.0).get() && "Expected a constant node."); @@ -304,6 +313,40 @@ template void piecewise_2D() { static_cast (2.0), static_cast (4.0) }), ay, 1.0, 0.0); + auto cx = graph::constant (static_cast (0.5)); + auto cy = graph::constant (static_cast (1.5)); + auto pconst = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, cx, 1.0, 0.0, cy, 1.0, 0.0); + auto pc_cast = constant_cast(pconst); + assert(pc_cast.get() && "Expected a constant node."); + assert(pc_cast->is(4.0) && "Expected a value of 6"); + + auto p1const = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, cx, 1.0, 0.0, ay, 1.0, 0.0); + auto p1c_cast = piecewise_1D_cast(p1const); + assert(p1c_cast.get() && "Expected a piecewise constant."); + backend::buffer buffer = p1c_cast->evaluate(); + assert(buffer[0] == static_cast (2.0) && + "Expected a 2 in the first index."); + assert(buffer[1] == static_cast (4.0) && + "Expected a 4 in the second index."); + + auto p2const = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, ax, 1.0, 0.0, cy, 1.0, 0.0); + auto p2c_cast = piecewise_1D_cast(p2const); + assert(p2c_cast.get() && "Expected a piecewise constant."); + buffer = p2c_cast->evaluate(); + assert(buffer[0] == static_cast (4.0) && + "Expected a 4 in the first index."); + assert(buffer[1] == static_cast (10.0) && + "Expected a 10 in the second index."); + assert(graph::constant_cast(p1*0.0).get() && "Expected a constant node."); @@ -687,6 +730,26 @@ template void piecewise_2D() { "Expected p1 - p3 on the left."); } +//------------------------------------------------------------------------------ +/// @brief Tests for 1D index nodes. +/// +/// @tparam T Base type of the calculation. +//------------------------------------------------------------------------------ +template void index_1D() { + auto variable = graph::variable (11, ""); + auto arg = graph::variable (1, ""); + + auto index = graph::index_1D (variable, arg, 1.0, 0.0); + + variable->set({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}); + arg->set(static_cast (3.5)); + + compile ({graph::variable_cast(variable), + graph::variable_cast(arg)}, + {index}, {}, + static_cast (3.0), 0.0); +} + //------------------------------------------------------------------------------ /// @brief Run tests with a specified backend. /// @@ -695,6 +758,7 @@ template void piecewise_2D() { template void run_tests() { piecewise_1D (); piecewise_2D (); + index_1D (); } //------------------------------------------------------------------------------