Commit 4bcadc98 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add reductions for piecewise 2D nodes for with constant arguments.

parent 64f41d1c
Loading
Loading
Loading
Loading
+32 −0
Original line number Diff line number Diff line
@@ -292,6 +292,38 @@ namespace backend {
            return true;
        }

//------------------------------------------------------------------------------
///  @brief Index row.
///
///  @param[in] index       The row index.
///  @param[in] num_columns The number of coils.
///  @returns A buffer containing the row.
//------------------------------------------------------------------------------
        buffer<T> index_row(const size_t index, const size_t num_columns) {
            buffer<T> b(num_columns);
            const size_t num_rows = size()/num_columns;
            for (size_t j = 0; j < num_columns; j++) {
                b[j] = memory[index*num_rows + j];
            }
            return b;
        }

//------------------------------------------------------------------------------
///  @brief Index column.
///
///  @param[in] index       The row index.
///  @param[in] num_columns The number of coils.
///  @returns A buffer containing the row.
//------------------------------------------------------------------------------
        buffer<T> index_column(const size_t index, const size_t num_columns) {
            const size_t num_rows = size()/num_columns;
            buffer<T> b(num_rows);
            for (size_t i = 0; i < num_rows; i++) {
                b[i] = memory[i*num_rows + index];
            }
            return b;
        }

//------------------------------------------------------------------------------
///  @brief Add row operation.
///
+26 −1
Original line number Diff line number Diff line
@@ -191,7 +191,7 @@ void compile_index(std::ostringstream &stream,
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> reduce() {
            if (constant_cast(this->arg).get()) {
                const T arg = (this->evaluate().at(0) + offset)/scale;
                const T arg = (this->arg->evaluate().at(0) + offset)/scale;
                const size_t i = std::min(static_cast<size_t> (std::real(arg)),
                                          this->get_size() - 1);
                return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i]);
@@ -831,6 +831,31 @@ void compile_index(std::ostringstream &stream,
///  @returns A reduced representation of the node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T, SAFE_MATH> reduce() {
            if (constant_cast(this->left).get() &&
                constant_cast(this->right).get()) {
                const T l = (this->left->evaluate().at(0) + x_offset)/x_scale;
                const T r = (this->right->evaluate().at(0) + y_offset)/y_scale;
                const size_t i = std::min(static_cast<size_t> (std::real(l)),
                                          this->get_num_rows() - 1);
                const size_t j = std::min(static_cast<size_t> (std::real(r)),
                                          this->get_num_columns() - 1);
                return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i*this->get_num_columns() + j]);
            } else if (constant_cast(this->left).get()) {
                const T l = (this->left->evaluate().at(0) + x_offset)/x_scale;
                const size_t i = std::min(static_cast<size_t> (std::real(l)),
                                          this->get_num_rows() - 1);
                
                return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_row(i, this->get_num_columns()),
                                    this->right, y_scale, y_offset);
            } else if (constant_cast(this->right).get()) {
                const T r = (this->right->evaluate().at(0) + y_offset)/y_scale;
                const size_t j = std::min(static_cast<size_t> (std::real(r)),
                                          this->get_num_columns() - 1);
                
                return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_column(j, this->get_num_columns()),
                                    this->left, x_scale, x_offset);
            }

            if (evaluate().is_same()) {
                return constant<T, SAFE_MATH> (evaluate().at(0));
            }
+34 −0
Original line number Diff line number Diff line
@@ -313,6 +313,40 @@ template<jit::float_scalar T> void piecewise_2D() {
        static_cast<T> (2.0), static_cast<T> (4.0)
    }), ay, 1.0, 0.0);

    auto cx = graph::constant<T> (static_cast<T> (0.5));
    auto cy = graph::constant<T> (static_cast<T> (1.5));
    auto pconst = graph::piecewise_2D<T> (std::vector<T> ({
        static_cast<T> (2.0), static_cast<T> (4.0),
        static_cast<T> (6.0), static_cast<T> (10.0)
    }), 2, cx, 1.0, 0.0, cy, 1.0, 0.0);
    auto pc_cast = constant_cast(pconst);
    assert(pc_cast.get() && "Expected a constant node.");
    assert(pc_cast->is(4.0) && "Expected a value of 6");

    auto p1const = graph::piecewise_2D<T> (std::vector<T> ({
        static_cast<T> (2.0), static_cast<T> (4.0),
        static_cast<T> (6.0), static_cast<T> (10.0)
    }), 2, cx, 1.0, 0.0, ay, 1.0, 0.0);
    auto p1c_cast = piecewise_1D_cast(p1const);
    assert(p1c_cast.get() && "Expected a piecewise constant.");
    backend::buffer<T> buffer = p1c_cast->evaluate();
    assert(buffer[0] == static_cast<T> (2.0) &&
           "Expected a 2 in the first index.");
    assert(buffer[1] == static_cast<T> (4.0) &&
           "Expected a 4 in the second index.");

    auto p2const = graph::piecewise_2D<T> (std::vector<T> ({
        static_cast<T> (2.0), static_cast<T> (4.0),
        static_cast<T> (6.0), static_cast<T> (10.0)
    }), 2, ax, 1.0, 0.0, cy, 1.0, 0.0);
    auto p2c_cast = piecewise_1D_cast(p2const);
    assert(p2c_cast.get() && "Expected a piecewise constant.");
    buffer = p2c_cast->evaluate();
    assert(buffer[0] == static_cast<T> (4.0) &&
           "Expected a 4 in the first index.");
    assert(buffer[1] == static_cast<T> (10.0) &&
           "Expected a 10 in the second index.");

    assert(graph::constant_cast(p1*0.0).get() &&
           "Expected a constant node.");