Commit 980a1a86 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Reduce piecewise nodes inside fma nodes by making it an explicit multiply.

parent a8d259ca
Loading
Loading
Loading
Loading
+10 −8
Original line number Diff line number Diff line
@@ -27,11 +27,11 @@ int main(int argc, const char * argv[]) {
    std::mutex sync;

    //typedef float base;
    //typedef double base;
    typedef double base;
    //typedef std::complex<float> base;
    typedef std::complex<double> base;
    constexpr bool use_safe_math = true;
    //constexpr bool use_safe_math = false;
    //typedef std::complex<double> base;
    //constexpr bool use_safe_math = true;
    constexpr bool use_safe_math = false;

    const timeing::measure_diagnostic total("Total Time");

@@ -91,15 +91,17 @@ int main(int argc, const char * argv[]) {
            kx->set(static_cast<base> (-600));
            //kx->set(static_cast<base> (600.0));
            ky->set(static_cast<base> (0.0));
            kz->set(static_cast<base> (10.0));
            kz->set(static_cast<base> (0.0));
            //kz->set(static_cast<base> (10.0));

            auto eq = equilibrium::make_efit<base, use_safe_math> (NC_FILE, sync);
            //auto eq = equilibrium::make_slab_density<base, use_safe_math> ();
            //auto eq = equilibrium::make_slab_field<base, use_safe_math> ();
            //auto eq = equilibrium::make_no_magnetic_field<base, use_safe_math> ();

            const base endtime = static_cast<base> (1.0);
            //const base endtime = static_cast<base> (10.0);
            const base endtime = static_cast<base> (0.25);
            //const base endtime = static_cast<base> (0.25);
            const base dt = endtime/static_cast<base> (num_times);

            //auto dt_var = graph::variable(num_rays, static_cast<base> (dt), "dt");
@@ -113,9 +115,9 @@ int main(int argc, const char * argv[]) {
            //solver::rk4<dispersion::simple<base, use_safe_math>>
            //solver::rk4<dispersion::ordinary_wave<base, use_safe_math>>
            //solver::rk4<dispersion::extra_ordinary_wave<base, use_safe_math>>
            //solver::rk4<dispersion::cold_plasma<base, use_safe_math>>
            solver::rk4<dispersion::cold_plasma<base, use_safe_math>>
            //solver::adaptive_rk4<dispersion::ordinary_wave<base, use_safe_math>>
            solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<base, use_safe_math>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<base, use_safe_math>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma_expandion<base, dispersion::z_erfi<base, use_safe_math>, use_safe_math>>
                solve(omega, kx, ky, kz, x, y, z, t, dt, eq,
                      stream.str(), local_num_rays);
+17 −25
Original line number Diff line number Diff line
@@ -1889,26 +1889,14 @@ namespace graph {

            auto pl1 = piecewise_1D_cast(this->left);
            auto pm1 = piecewise_1D_cast(this->middle);

            if (pl1.get() && (m.get() || pl1->is_arg_match(this->middle))) {
                return piecewise_1D(this->evaluate(), pl1->get_arg()) + this->right;
            } else if (pm1.get() && (m.get() || pm1->is_arg_match(this->left))) {
                return piecewise_1D(this->evaluate(), pm1->get_arg()) + this->right;
            }

            auto pl2 = piecewise_2D_cast(this->left);
            auto pm2 = piecewise_2D_cast(this->middle);

            if (pl2.get() && (m.get() || pl2->is_arg_match(this->right))) {
                return piecewise_2D(this->evaluate(),
                                    pl2->get_num_columns(),
                                    pl2->get_left(),
                                    pl2->get_right()) + this->right;
            } else if (pm2.get() && (l.get() || pm2->is_arg_match(this->left))) {
                return piecewise_2D(this->evaluate(),
                                    pm2->get_num_columns(),
                                    pm2->get_left(),
                                    pm2->get_right()) + this->right;
            if ((pl1.get() && (m.get() || pl1->is_arg_match(this->middle))) ||
                (pm1.get() && (l.get() || pm1->is_arg_match(this->left)))   ||
                (pl2.get() && (m.get() || pl2->is_arg_match(this->middle))) ||
                (pm2.get() && (l.get() || pm2->is_arg_match(this->left)))) {
                return (this->left*this->middle) + this->right;
            }

//  Common factor reduction. If the left and right are both multiply nodes check
@@ -1966,14 +1954,18 @@ namespace graph {
                }
            } else if (mm.get()) {
                auto mmc = constant_cast(mm->get_left());
                if (mmc.get() && !l.get()) {
                    return fma(mm->get_left(),
                               this->left*mm->get_right(),
                               this->right);
                } else if (mmc.get() && l.get()) {
                auto mmpw1c = piecewise_1D_cast(mm->get_left());
                auto mmpw2c = piecewise_2D_cast(mm->get_left());
                if (mmc.get() || mmpw1c.get() || mmpw2c.get()) {
                    if (l.get() || pl1.get() || pl2.get()) {
                        return fma(this->left*mm->get_left(),
                                   mm->get_right(),
                                   this->right);
                    } else {
                        return fma(mm->get_left(),
                                   this->left*mm->get_right(),
                                   this->right);
                    }
                }
            }

+20 −0
Original line number Diff line number Diff line
@@ -1770,6 +1770,26 @@ template<typename T> void test_fma() {
    assert(constant_move2_cast.get() && "Expected an fma cast");
    assert(graph::constant_cast(constant_move2_cast->get_left()) &&
           "Expected a constant on the left.");
    
//  fma(c, pwc*v, d) -> fma(pwc, v, d)
    auto piecewise1 = graph::fma<T> (two,
                                     graph::piecewise_1D<T> (std::vector<T> ({static_cast<T> (1.0),
                                                                              static_cast<T> (2.0)}),
                                                             var_a)*var_a,
                                     var_b);
    auto piecewise1_cast = graph::fma_cast(piecewise1);
    assert(piecewise1_cast.get() && "Expected a fma node.");
    assert(graph::piecewise_1D_cast(piecewise1_cast->get_left()) &&
           "Expected a piecewise_1D node.");
    auto piecewise2 = graph::fma<T> (two,
                                     graph::piecewise_2D<T> (std::vector<T> ({static_cast<T> (1.0),
                                                                              static_cast<T> (2.0)}),
                                                             1, var_a, var_b)*var_a,
                                     var_b);
    auto piecewise2_cast = graph::fma_cast(piecewise2);
    assert(piecewise2_cast.get() && "Expected a fma node.");
    assert(graph::piecewise_2D_cast(piecewise2_cast->get_left()) &&
           "Expected a piecewise_2D node.");
}

//------------------------------------------------------------------------------