Commit 16ff4e20 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Fix speed regression and forward declare constants so they can be used...

Fix speed regression and forward declare constants so they can be used directly within constant_node methods.
parent f9035323
Loading
Loading
Loading
Loading
+37 −37
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@

const bool print = true;
const bool write_step = false;
const bool print_expressions = true;
const bool print_expressions = false;

//------------------------------------------------------------------------------
///  @brief Main program of the driver.
@@ -80,22 +80,22 @@ int main(int argc, const char * argv[]) {
                }
            }

            //x->set(static_cast<base> (2.5));
            x->set(static_cast<base> (0.0));
            x->set(static_cast<base> (2.5));
            //x->set(static_cast<base> (9.0));
            y->set(static_cast<base> (0.0));
            z->set(static_cast<base> (0.0));
            //kx->set(static_cast<base> (-600.0));
            kx->set(static_cast<base> (600.0));
            kx->set(static_cast<base> (-600.0));
            //kx->set(static_cast<base> (600.0));
            ky->set(static_cast<base> (0.0));
            kz->set(static_cast<base> (0.0));


            //auto eq = equilibrium::make_efit<base> (NC_FILE, sync);
            auto eq = equilibrium::make_slab_density<base> ();
            auto eq = equilibrium::make_efit<base> (NC_FILE, sync);
            //auto eq = equilibrium::make_slab_density<base> ();
            //auto eq = equilibrium::make_no_magnetic_field<base> ();

            //const base endtime = static_cast<base> (4.0);
            const base endtime = static_cast<base> (10.0);
            const base endtime = static_cast<base> (.0);
            //const base endtime = static_cast<base> (10.0);
            const base dt = endtime/static_cast<base> (num_times);

            //auto dt_var = graph::variable(num_rays, static_cast<base> (dt), "dt");
@@ -119,34 +119,34 @@ int main(int argc, const char * argv[]) {
            solve.init(kx);
            solve.compile();
            if (thread_number == 0 && print_expressions) {
                //solve.print_dispersion();
                //std::cout << std::endl;
                //solve.print_dkxdt();
                //std::cout << std::endl;
                //solve.print_dkydt();
                //std::cout << std::endl;
                //solve.print_dkzdt();
                //std::cout << std::endl;
                //solve.print_dxdt();
                //std::cout << std::endl;
                //solve.print_dydt();
                //std::cout << std::endl;
                //solve.print_dzdt();
                //std::cout << std::endl;
                //solve.print_residule();
                //std::cout << std::endl;
                //solve.print_x_next();
                //std::cout << std::endl;
                //solve.print_y_next();
                //std::cout << std::endl;
                //solve.print_z_next();
                //std::cout << std::endl;
                //solve.print_kx_next();
                //std::cout << std::endl;
                //solve.print_ky_next();
                //std::cout << std::endl;
                //solve.print_kz_next();
                //std::cout << std::endl;
                solve.print_dispersion();
                std::cout << std::endl;
                solve.print_dkxdt();
                std::cout << std::endl;
                solve.print_dkydt();
                std::cout << std::endl;
                solve.print_dkzdt();
                std::cout << std::endl;
                solve.print_dxdt();
                std::cout << std::endl;
                solve.print_dydt();
                std::cout << std::endl;
                solve.print_dzdt();
                std::cout << std::endl;
                solve.print_residule();
                std::cout << std::endl;
                solve.print_x_next();
                std::cout << std::endl;
                solve.print_y_next();
                std::cout << std::endl;
                solve.print_z_next();
                std::cout << std::endl;
                solve.print_kx_next();
                std::cout << std::endl;
                solve.print_ky_next();
                std::cout << std::endl;
                solve.print_kz_next();
                std::cout << std::endl;
            }

            const size_t sample = int_dist(engine);
+14 −37
Original line number Diff line number Diff line
@@ -240,6 +240,13 @@ namespace graph {
    template<typename T>
    using output_nodes = std::vector<shared_leaf<T>>;

///  Forward declare for zero.
    template<typename T>
    constexpr shared_leaf<T> zero();
///  Forward declare for one.
    template<typename T>
    constexpr shared_leaf<T> one();

//******************************************************************************
//  Constant node.
//******************************************************************************
@@ -264,14 +271,6 @@ namespace graph {

    public:
//------------------------------------------------------------------------------
///  @brief Construct a constant node from a scalar.
///
///  @params[in] d Scalar data to initalize.
//------------------------------------------------------------------------------
        constant_node(const T d) :
        leaf_node<T> (constant_node<T>::to_string(d), 1), data(1, d) {}

//------------------------------------------------------------------------------
///  @brief Construct a constant node from a vector.
///
///  @params[in] d Array buffer.
@@ -308,18 +307,7 @@ namespace graph {
///  @returns The derivative of the node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T> df(shared_leaf<T> x) {
            auto zero = std::make_shared<constant_node<T>> (static_cast<T> (0.0));
//  Test for hash collisions.
            for (size_t i = zero->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
                if (leaf_node<T>::cache.find(i) ==
                    leaf_node<T>::cache.end()) {
                    leaf_node<T>::cache[i] = zero;
                    return zero;
                } else if (zero->is_match(leaf_node<T>::cache[i])) {
                    return leaf_node<T>::cache[i];
                }
            }
            assert(false && "Should never reach.");
            return zero<T> ();
        }

//------------------------------------------------------------------------------
@@ -422,18 +410,18 @@ namespace graph {
///  @returns The exponent of a power like node.
//------------------------------------------------------------------------------
        virtual shared_leaf<T> get_power_exponent() const {
            return std::make_shared<constant_node<T>> (static_cast<T> (1.0));
            return one<T> ();
        }
    };

//------------------------------------------------------------------------------
///  @brief Construct a constant.
///
///  @params[in] d Scalar data to initalize.
///  @params[in] d Array buffer.
///  @returns A reduced constant node.
//------------------------------------------------------------------------------
    template<typename T>
    shared_leaf<T> constant(const T d) {
    shared_leaf<T> constant(const backend::buffer<T> &d) {
        auto temp = std::make_shared<constant_node<T>> (d);
//  Test for hash collisions.
        for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
@@ -451,23 +439,12 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Construct a constant.
///
///  @params[in] d Array buffer.
///  @params[in] d Scalar data to initalize.
///  @returns A reduced constant node.
//------------------------------------------------------------------------------
    template<typename T>
    shared_leaf<T> constant(const backend::buffer<T> &d) {
        auto temp = std::make_shared<constant_node<T>> (d);
//  Test for hash collisions.
        for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
            if (leaf_node<T>::cache.find(i) ==
                leaf_node<T>::cache.end()) {
                leaf_node<T>::cache[i] = temp;
                return temp;
            } else if (temp->is_match(leaf_node<T>::cache[i])) {
                return leaf_node<T>::cache[i];
            }
        }
        assert(false && "Should never reach.");
    shared_leaf<T> constant(const T d) {
        return constant(backend::buffer<T> (1, d));
    }

//  Define some common constants.
+7 −58
Original line number Diff line number Diff line
@@ -238,14 +238,10 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual bool is_match(shared_leaf<T> x) {
            auto x_cast = piecewise_1D_cast(x);

            if (x_cast.get()) {
                if (this->arg.get() && x_cast->get_arg().get()) {
                    return this->evaluate() == x->evaluate() &&
                return this->data_hash == x_cast->data_hash &&
                       this->arg->is_match(x_cast->get_arg());
                } else {
                    return this->arg.get() == x_cast->get_arg().get() &&
                           this->evaluate() == x->evaluate();
                }
            }

            return false;
@@ -342,24 +338,6 @@ namespace graph {
        return std::dynamic_pointer_cast<piecewise_1D_node<T>> (x);
    }

//------------------------------------------------------------------------------
///  @brief Set the argument of a piecewise\_1D node.
///
///  Piecewise functions could be reduced to a single constant so we need to
///  check if it can be cast.
///
///  @params[in] p Existing piecewise constant.
///  @params[in] x Argument.
///  @returns The 1D piecewise constant with the argument set.
//------------------------------------------------------------------------------
    template<typename T> shared_leaf<T> piecewise_1D(shared_leaf<T> p,
                                                     shared_leaf<T> x) {
        if (piecewise_1D_cast(p).get()) {
            return piecewise_1D(p->evaluate(), x);
        }
        return p;
    }

//******************************************************************************
//  2D Piecewise node.
//******************************************************************************
@@ -639,17 +617,11 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual bool is_match(shared_leaf<T> x) {
            auto x_cast = piecewise_2D_cast(x);

            if (x_cast.get()) {
                if (this->left.get() &&
                    x_cast->get_left().get()) {
                    return this->evaluate() == x->evaluate()        &&
                return this->data_hash == x_cast->data_hash     &&
                       this->left->is_match(x_cast->get_left()) &&
                       this->right->is_match(x_cast->get_right());
                } else {
                    return this->left.get() == x_cast->get_left().get()   &&
                           this->right.get() == x_cast->get_right().get() &&
                           this->evaluate() == x->evaluate();
                }
            }

            return false;
@@ -751,29 +723,6 @@ namespace graph {
    shared_piecewise_2D<T> piecewise_2D_cast(shared_leaf<T> x) {
        return std::dynamic_pointer_cast<piecewise_2D_node<T>> (x);
    }

//------------------------------------------------------------------------------
///  @brief Set the argument of a piecewise\_1D node.
///
///  Piecewise functions could be reduced to a single constant so we need to
///  check if it can be cast.
///
///  @params[in] p Existing piecewise constant.
///  @params[in] x X Argument.
///  @params[in] y Y Argument.
///  @returns The 1D piecewise constant with the argument set.
//------------------------------------------------------------------------------
    template<typename T> shared_leaf<T> piecewise_2D(shared_leaf<T> p,
                                                     shared_leaf<T> x,
                                                     shared_leaf<T> y) {
        auto temp = piecewise_2D_cast(p);
        if (temp.get()) {
            return piecewise_2D(temp->evaluate(),
                                temp->get_num_columns(),
                                x, y);
        }
        return p;
    }
}

#endif /* piecewise_h */
+111 −54
Original line number Diff line number Diff line
@@ -140,20 +140,10 @@ template<typename T> void test_add() {
                         graph::pow(var_d/var_b,var_c);
    assert(graph::divide_cast(common_power3) && "Expected Divide node.");

//  v1 + -c*v2 -> v1 - c*v2
//    auto negate = var_a + graph::constant(static_cast<T> (-2.0))*var_b;
//    assert(graph::add_cast(negate).get() && "Expected add node.");

//  v1 + -1*v2 -> v1 - v2
    auto add_neg = var_a + graph::none<T> ()*var_b;
    assert(graph::subtract_cast(add_neg).get() && "Expected subtract node.");

//  -c1*v1 + v2 -> v2 - c*v1
//    auto negate2 = graph::constant(static_cast<T> (-2.0))*var_a + var_b;
//    auto negate2_cast = graph::subtract_cast(negate2);
//    assert(negate2_cast.get() && "Expected subtract node.");
//    assert(negate2_cast->get_left()->is_match(var_b) && "Expected var_b.");

//  (c1*v1 + c2) + (c3*v1 + c4) -> c5*v1 + c6
    auto var_e = graph::variable<T> (1, "");
    auto addfma1 = graph::fma(var_b, var_a, var_d)
@@ -255,6 +245,24 @@ template<typename T> void test_add() {
    auto muliply_divide_factor4 = var_a/(var_c*var_b) + var_d/(var_c*var_e);
    auto muliply_divide_factor_cast4 = divide_cast(muliply_divide_factor4);
    assert(muliply_divide_factor_cast4.get() && "Expected divide node.");

//  Test node properties.
    assert(three->is_constant_like() && "Expected a constant.");
    assert(!three->is_all_variables() && "Did not expect a variable.");
    assert(three->is_power_like() && "Expected a power like.");
    auto constant_add = three + graph::piecewise_1D<T> (std::vector<T> ({static_cast<T> (1.0),
                                                                         static_cast<T> (2.0)}), var_a);
    assert(constant_add->is_constant_like() && "Expected a constant.");
    assert(!constant_add->is_all_variables() && "Did not expect a variable.");
    assert(!constant_add->is_power_like() && "Expected a power like.");
    auto constant_var_add = three + var_a;
    assert(!constant_var_add->is_constant_like() && "Did not expect a constant.");
    assert(!constant_var_add->is_all_variables() && "Did not expect a variable.");
    assert(!constant_var_add->is_power_like() && "Did not expect a power like.");
    auto var_var_add = var_a + variable;
    assert(!var_var_add->is_constant_like() && "Did not expect a constant.");
    assert(var_var_add->is_all_variables() && "Expected a variable.");
    assert(!var_var_add->is_power_like() && "Did not expect a power like.");
}

//------------------------------------------------------------------------------
@@ -538,6 +546,24 @@ template<typename T> void test_subtract() {
    assert(chained_subtract_divide_cast2.get() && "Expected subtract node.");
    assert(graph::fma_cast(chained_subtract_divide_cast2->get_right()).get() &&
           "Expected a fused multiply add node on the left.");

//  Test node properties.
    assert(zero->is_constant_like() && "Expected a constant.");
    assert(!zero->is_all_variables() && "Did not expect a variable.");
    assert(zero->is_power_like() && "Expected a power like.");
    auto constant_sub = one - graph::piecewise_1D<T> (std::vector<T> ({static_cast<T> (1.0),
                                                                       static_cast<T> (2.0)}), var_a);
    assert(constant_sub->is_constant_like() && "Expected a constant.");
    assert(!constant_sub->is_all_variables() && "Did not expect a variable.");
    assert(!constant_sub->is_power_like() && "Expected a power like.");
    auto constant_var_sub = one - var_a;
    assert(!constant_var_sub->is_constant_like() && "Did not expect a constant.");
    assert(!constant_var_sub->is_all_variables() && "Did not expect a variable.");
    assert(!constant_var_sub->is_power_like() && "Did not expect a power like.");
    auto var_var_sub = var_a - var_b;
    assert(!var_var_sub->is_constant_like() && "Did not expect a constant.");
    assert(var_var_sub->is_all_variables() && "Expected a variable.");
    assert(!var_var_sub->is_power_like() && "Did not expect a power like.");
}

//------------------------------------------------------------------------------
@@ -948,6 +974,24 @@ template<typename T> void test_multiply() {
           "Expected add cast on the left.");
    assert(graph::pow_cast(common_base_cast4->get_right()).get() &&
           "Expected power cast on the right.");

//  Test node properties.
    assert(two_times_three->is_constant_like() && "Expected a constant.");
    assert(!two_times_three->is_all_variables() && "Did not expect a variable.");
    assert(two_times_three->is_power_like() && "Expected a power like.");
    auto constant_mul = three*graph::piecewise_1D<T> (std::vector<T> ({static_cast<T> (1.0),
                                                                       static_cast<T> (2.0)}), variable);
    assert(constant_mul->is_constant_like() && "Expected a constant.");
    assert(!constant_mul->is_all_variables() && "Did not expect a variable.");
    assert(!constant_mul->is_power_like() && "Expected a power like.");
    auto constant_var_mul = three*variable;
    assert(!constant_var_mul->is_constant_like() && "Did not expect a constant.");
    assert(!constant_var_mul->is_all_variables() && "Did not expect a variable.");
    assert(!constant_var_mul->is_power_like() && "Did not expect a power like.");
    auto var_var_mul = variable*a;
    assert(!var_var_mul->is_constant_like() && "Did not expect a constant.");
    assert(var_var_mul->is_all_variables() && "Expected a variable.");
    assert(!var_var_mul->is_power_like() && "Did not expect a power like.");
}

//------------------------------------------------------------------------------
@@ -1347,6 +1391,24 @@ template<typename T> void test_divide() {
    auto common_power2 = (graph::pow(a, three)*variable)/graph::pow(a, two);
    assert(graph::multiply_cast(common_power2).get() &&
           "Expected a multiply node.");

//  Test node properties.
    assert(two_divided_three->is_constant_like() && "Expected a constant.");
    assert(!two_divided_three->is_all_variables() && "Did not expect a variable.");
    assert(two_divided_three->is_power_like() && "Expected a power like.");
    auto constant_div = two_divided_three/graph::piecewise_1D<T> (std::vector<T> ({static_cast<T> (1.0),
                                                                                   static_cast<T> (2.0)}), variable);
    assert(constant_div->is_constant_like() && "Expected a constant.");
    assert(!constant_div->is_all_variables() && "Did not expect a variable.");
    assert(!constant_div->is_power_like() && "Expected a power like.");
    auto constant_var_div = two_divided_three/variable;
    assert(!constant_var_div->is_constant_like() && "Did not expect a constant.");
    assert(!constant_var_div->is_all_variables() && "Did not expect a variable.");
    assert(!constant_var_div->is_power_like() && "Did not expect a power like.");
    auto var_var_div = variable/a;
    assert(!var_var_div->is_constant_like() && "Did not expect a constant.");
    assert(var_var_div->is_all_variables() && "Expected a variable.");
    assert(!var_var_div->is_power_like() && "Did not expect a power like.");
}

//------------------------------------------------------------------------------
@@ -1584,34 +1646,29 @@ template<typename T> void test_fma() {
    auto divide_factor4 = graph::fma(var_c/var_b, var_a, var_d/var_b);
    assert(graph::divide_cast(divide_factor4).get() &&
           "Expetced a divide node.");
}

//------------------------------------------------------------------------------
///  @brief Tests function for variable like expressions.
//------------------------------------------------------------------------------
template<typename T> void test_variable_like() {
    auto a = graph::variable<T> (1, "");
    auto c = graph::one<T> ();
    
    assert(a->is_all_variables() && "Expected a to be variable like.");
    assert(graph::sqrt(a)->is_all_variables() &&
           "Expected sqrt(a) to be variable like.");
    assert(graph::pow(a, c)->is_all_variables() &&
           "Expected a^c to be variable like.");
    
    assert(!c->is_all_variables() &&
           "Expected c to not be variable like.");
    assert(!graph::sqrt(c)->is_all_variables() &&
           "Expected sqrt(c) to not be variable like.");
    assert(!graph::pow(c, a)->is_all_variables() &&
           "Expected c^a to not be variable like.");
//  Test node properties.
    assert(one_two_three->is_constant_like() && "Expected a constant.");
    assert(!one_two_three->is_all_variables() && "Did not expect a variable.");
    assert(one_two_three->is_power_like() && "Expected a power like.");
    auto constant_fma = graph::fma(one_two_three, graph::piecewise_1D<T> (std::vector<T> ({static_cast<T> (1.0),
                                                                                           static_cast<T> (2.0)}), var_a), one);
    assert(!constant_fma->is_all_variables() && "Did not expect a variable.");
    assert(!constant_fma->is_power_like() && "Expected a power like.");
    auto constant_var_fma = graph::fma(var_a, var_b, one);
    assert(!constant_var_fma->is_constant_like() && "Did not expect a constant.");
    assert(!constant_var_fma->is_all_variables() && "Did not expect a variable.");
    assert(!constant_var_fma->is_power_like() && "Did not expect a power like.");
    auto var_var_fma = graph::fma(var_a, var_b, var_c);
    assert(!var_var_fma->is_constant_like() && "Did not expect a constant.");
    assert(var_var_fma->is_all_variables() && "Expected a variable.");
    assert(!var_var_fma->is_power_like() && "Did not expect a power like.");
}

//------------------------------------------------------------------------------
///  @brief Run tests with a specified backend.
//------------------------------------------------------------------------------
template<typename T> void run_tests() {
    test_variable_like<T> ();
    test_add<T> ();
    test_subtract<T> ();
    test_multiply<T> ();
+63 −0

File changed.

Preview size limit exceeded, changes collapsed.

Loading