Commit cfe9453a authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Reorder Exp to allow elimination.

parent 0d7f9814
Loading
Loading
Loading
Loading
+11 −11
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@

const bool print = true;
const bool write_step = false;
const bool print_expressions = false;
const bool print_expressions = true;

//------------------------------------------------------------------------------
///  @brief Main program of the driver.
@@ -27,9 +27,9 @@ int main(int argc, const char * argv[]) {
    std::mutex sync;

    //typedef float base;
    typedef double base;
    //typedef double base;
    //typedef std::complex<float> base;
    //typedef std::complex<double> base;
    typedef std::complex<double> base;

    const timeing::measure_diagnostic total("Total Time");

@@ -90,8 +90,8 @@ int main(int argc, const char * argv[]) {
            kz->set(static_cast<base> (0.0));


            auto eq = equilibrium::make_efit<base> (NC_FILE, sync);
            //auto eq = equilibrium::make_slab_density<base> ();
            //auto eq = equilibrium::make_efit<base> (NC_FILE, sync);
            auto eq = equilibrium::make_slab_density<base> ();
            //auto eq = equilibrium::make_no_magnetic_field<base> ();

            const base endtime = static_cast<base> (1.0);
@@ -109,9 +109,9 @@ int main(int argc, const char * argv[]) {
            //solver::rk4<dispersion::simple<base>>
            //solver::rk4<dispersion::ordinary_wave<base>>
            //solver::rk4<dispersion::extra_ordinary_wave<base>>
            solver::rk4<dispersion::cold_plasma<base>>
            //solver::rk4<dispersion::cold_plasma<base>>
            //solver::adaptive_rk4<dispersion::ordinary_wave<base>>
            //solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<base>>>
            solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<base>>>
                solve(omega, kx, ky, kz, x, y, z, t, dt, eq,
                      stream.str(), local_num_rays);
                //solve(omega, kx, ky, kz, x, y, z, t, dt_var, eq,
@@ -119,11 +119,11 @@ int main(int argc, const char * argv[]) {
            solve.init(kx);
            solve.compile();
            if (thread_number == 0 && print_expressions) {
                solve.print_dispersion();
                std::cout << std::endl;
                //solve.print_dispersion();
                //std::cout << std::endl;
                solve.print_dkxdt();
                std::cout << std::endl;
                solve.print_dkydt();
                /*solve.print_dkydt();
                std::cout << std::endl;
                solve.print_dkzdt();
                std::cout << std::endl;
@@ -146,7 +146,7 @@ int main(int argc, const char * argv[]) {
                solve.print_ky_next();
                std::cout << std::endl;
                solve.print_kz_next();
                std::cout << std::endl;
                std::cout << std::endl;*/
            }

            const size_t sample = int_dist(engine);
+59 −6
Original line number Diff line number Diff line
@@ -1078,6 +1078,38 @@ namespace graph {
                }
            }

//  Exp(a)*Exp(b) -> Exp(a + b)
            auto le = exp_cast(this->left);
            auto re = exp_cast(this->right);
            if (le.get() && re.get()) {
                return exp(le->get_arg() + re->get_arg());
            }

//  Exp(a)*(Exp(b)*c) -> (Exp(a)*Exp(b))*c
//  Exp(a)*(c*Exp(b)) -> (Exp(a)*Exp(b))*c
            if (le.get() && rm.get()) {
                auto rmle = exp_cast(rm->get_left());
                if (rmle.get()) {
                    return (this->left*rm->get_left())*rm->get_right();
                }
                auto rmre = exp_cast(rm->get_right());
                if (rmre.get()) {
                    return (this->left*rm->get_right())*rm->get_left();
                }
            }
//  (Exp(a)*c)*Exp(b) -> (Exp(a)*Exp(b))*c
//  (c*Exp(a))*Exp(b) -> (Exp(a)*Exp(b))*c
            if (re.get() && lm.get()) {
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
                    return (this->right*lm->get_left())*lm->get_right();
                }
                auto lmre = exp_cast(rm->get_right());
                if (lmre.get()) {
                    return (this->right*lm->get_right())*lm->get_left();
                }
            }

            return this->shared_from_this();
        }

@@ -1713,15 +1745,36 @@ namespace graph {
                                              (rm->get_left()/lm->get_left())*rm->get_right());
                }
            }
//  fma(c1*a,b,c2/d) -> c1*(a*b + c1/(c2*d))
//  fma(c1*a,b,d/c2) -> c1*(a*b + d/(c1*c2))

//  Move constant multiplies to the left.
            if (lm.get()) {
                auto lmc = constant_cast(lm->get_left());
                if (lmc.get()) {
                    return fma(lm->get_left(),
                               lm->get_right()*this->middle,
                               this->right);
                }
            } else if (mm.get()) {
                auto mmc = constant_cast(mm->get_left());
                if (mmc.get() && !l.get()) {
                    return fma(mm->get_left(),
                               this->left*mm->get_right(),
                               this->right);
                } else if (mmc.get() && l.get()) {
                    return fma(this->left*mm->get_left(),
                               mm->get_right(),
                               this->right);
                }
            }

//  fma(c1,a,c2/b) -> c1*(a + c1/(c2*b))
//  fma(c1,a,b/c2) -> c1*(a + b/(c1*c2))
            auto rd = divide_cast(this->right);
            if (lm.get() && rd.get()) {
            if (l.get() && rd.get()) {
                if (constant_cast(rd->get_left()).get() ||
                    constant_cast(rd->get_right()).get()) {
                    return lm->get_left()*fma(lm->get_right(),
                                              this->middle,
                                              rd->get_left()/(lm->get_left()*rd->get_right()));
                    return this->left*(this->middle +
                                       rd->get_left()/(this->left*rd->get_right()));
                }
            }

+2 −6
Original line number Diff line number Diff line
@@ -1031,8 +1031,6 @@ namespace dispersion {
            auto p_func = one - P;

            auto gamma5 = P*(n2*npara2 - (one - q)*n_func + q_func);
            //auto gamma4 = P*(two*q_func - n_func);
            //auto gamma3 = P*w*w/(four*ec*ec)*nperp2/npara2*(n_func - two*q_func);
            auto gamma2 = P*w/ec*(n2nperp2 - q_func*nperp2)
                        + P*P*w*w/(four*ec*ec)*(n_func - two*q_func)*nperp2/npara2;
            auto gamma1 = (one - q)*n2nperp2 + p_func*n2*npara2
@@ -1041,9 +1039,7 @@ namespace dispersion {

            auto zeta_func = one + zeta*Z_func;

            return isigma*gamma0 + gamma1 + gamma2*zeta_func +
//                   gamma3*zeta*Z_func*zeta_func + gamma4*isigma*F +
                   gamma5*F;
            return isigma*gamma0 + gamma1 + gamma2*zeta_func + gamma5*F;
        }
    };

+1 −1
Original line number Diff line number Diff line
@@ -1003,7 +1003,7 @@ namespace graph {
        virtual void to_latex() const {
            std::cout << "erfi\\left(";
            this->arg->to_latex();
            std::cout << "\\right)}";
            std::cout << "\\right)";
        }

//------------------------------------------------------------------------------
+3 −3
Original line number Diff line number Diff line
@@ -1202,9 +1202,9 @@ namespace graph {
///  @brief Convert the node to latex.
//------------------------------------------------------------------------------
        virtual void to_latex() const {
            std::cout << "(";
            std::cout << "\\left(";
            this->arg->to_latex();
            std::cout << ")";
            std::cout << "\\right)";
        }

//------------------------------------------------------------------------------
@@ -1284,7 +1284,7 @@ namespace graph {
///  @returns An attemped dynamic case.
//------------------------------------------------------------------------------
    template<typename T>
    shared_pseudo_variable<T> pseudo_variable_cast(shared_leaf<T> x) {
    shared_pseudo_variable<T> pseudo_variable_cast(shared_leaf<T> &x) {
        return std::dynamic_pointer_cast<pseudo_variable_node<T>> (x);
    }
}
Loading