Commit f0f88666 authored by cianciosa's avatar cianciosa
Browse files

Add exp reductions WIP.

parent 86c52be9
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -331,13 +331,13 @@ void calculate_power(const size_t num_times,
            //auto eq = equilibrium::make_slab_field<T, SAFE_MATH> ();
            //auto eq = equilibrium::make_no_magnetic_field<T, SAFE_MATH> ();

            //absorption::root_finder<dispersion::hot_plasma<T, dispersion::z_erfi<T, SAFE_MATH>, SAFE_MATH>>
            absorption::weak_damping<T, SAFE_MATH>
            absorption::root_finder<dispersion::hot_plasma<T, dispersion::z_erfi<T, SAFE_MATH>, SAFE_MATH>>
            //absorption::weak_damping<T, SAFE_MATH>
                power(kamp, omega, kx, ky, kz, x, y, z, t, eq,
                      stream.str(), local_num_rays, thread_number);
            power.compile();

            for (size_t j = 0, je = num_steps + 1; j < je; j++) {
            for (size_t j = 120, je = num_steps + 1; j < je; j++) {
                power.run(j);
            }
        }, i, threads.size());
+358 −9
Original line number Diff line number Diff line
@@ -575,6 +575,25 @@ namespace graph {
                }
            }

//  a*v - v = (a - 1)*v
//  v*a - v = (a - 1)*v
            if (lm.get()) {
                if (this->right->is_match(lm->get_right())) {
                    return (lm->get_left() - one<T, SAFE_MATH> ())*this->right;
                } else if (this->right->is_match(lm->get_left())) {
                    return (lm->get_right() - one<T, SAFE_MATH> ())*this->right;
                }
            }
//  v - a*v = (1 - a)*v
//  v - v*a = (1 - a)*v
            if (rm.get()) {
                if (this->left->is_match(rm->get_right())) {
                    return (one<T, SAFE_MATH> () - rm->get_left())*this->left;
                } else if (this->left->is_match(rm->get_left())) {
                    return (one<T, SAFE_MATH> () - rm->get_right())*this->left;
                }
            }

            if (lm.get() && rm.get()) {
                if (lm->get_left()->is_match(rm->get_left())) {
//  a*b - a*c -> a*(b - c)
@@ -1335,38 +1354,166 @@ namespace graph {
                }
            }

//  Exp(a)*Exp(b) -> Exp(a + b)
//  exp(a)*exp(b) -> exp(a + b)
            auto le = exp_cast(this->left);
            auto re = exp_cast(this->right);
            if (le.get() && re.get()) {
                return exp(le->get_arg() + re->get_arg());
            }

//  Exp(a)*(Exp(b)*c) -> (Exp(a)*Exp(b))*c
//  Exp(a)*(c*Exp(b)) -> (Exp(a)*Exp(b))*c
//  exp(a)*(exp(b)*c) -> c*(exp(a)*exp(b))
//  exp(a)*(c*exp(b)) -> c*(exp(a)*exp(b))
            if (le.get() && rm.get()) {
                auto rmle = exp_cast(rm->get_left());
                if (rmle.get()) {
                    return (this->left*rm->get_left())*rm->get_right();
                    return rm->get_right()*(this->left*rm->get_left());
                }
                auto rmre = exp_cast(rm->get_right());
                if (rmre.get()) {
                    return (this->left*rm->get_right())*rm->get_left();
                    return rm->get_left()*(this->left*rm->get_right());
                }
            }
//  (Exp(a)*c)*Exp(b) -> (Exp(a)*Exp(b))*c
//  (c*Exp(a))*Exp(b) -> (Exp(a)*Exp(b))*c
//  (exp(a)*c)*exp(b) -> c*(exp(a)*exp(b))
//  (c*exp(a))*exp(b) -> c*(exp(a)*exp(b))
            if (re.get() && lm.get()) {
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
                    return (this->right*lm->get_left())*lm->get_right();
                    return lm->get_right()*(this->right*lm->get_left());
                }
                auto lmre = exp_cast(lm->get_right());
                if (lmre.get()) {
                    return (this->right*lm->get_right())*lm->get_left();
                    return lm->get_left()*(this->right*lm->get_right());
                }
            }
//  (exp(a)*c)*(exp(b)*d) -> (c*d)*(exp(a)*exp(b))
//  (exp(a)*c)*(d*exp(b)) -> (c*d)*(exp(a)*exp(b))
//  (c*exp(a))*(exp(b)*d) -> (c*d)*(exp(a)*exp(b))
//  (c*exp(a))*(d*exp(b)) -> (c*d)*(exp(a)*exp(b))
            if (lm.get() && rm.get()) {
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
                    auto rmle = exp_cast(rm->get_left());
                    if (rmle.get()) {
                        return (lm->get_right()*rm->get_right()) *
                               (lm->get_left()*rm->get_left());
                    }
                    auto rmre = exp_cast(rm->get_right());
                    if (rmre.get()) {
                        return (lm->get_right()*rm->get_left()) *
                               (lm->get_left()*rm->get_right());
                    }
                }
                auto lmre = exp_cast(lm->get_right());
                if (lmre.get()) {
                    auto rmle = exp_cast(rm->get_left());
                    if (rmle.get()) {
                        return (lm->get_left()*rm->get_right()) *
                               (lm->get_right()*rm->get_left());
                    }
                    auto rmre = exp_cast(rm->get_right());
                    if (rmre.get()) {
                        return (lm->get_left()*rm->get_left()) *
                               (lm->get_right()*rm->get_right());
                    }
                }
            }

            if (ld.get() && re.get()) {
//  (c/exp(a))*exp(b) -> c*(exp(b)/exp(a))
                auto ldre = exp_cast(ld->get_right());
                if (ldre.get()) {
                    return ld->get_left()*(this->right/ld->get_right());
                }
//  (exp(a)/c)*exp(b) -> (exp(a)*exp(b))/c
                auto ldle = exp_cast(ld->get_left());
                if (ldle.get()) {
                    return (ld->get_left()*this->right)/ld->get_right();
                }
            }
            if (rd.get() && le.get()) {
//  exp(a)*(c/exp(a)) -> c*(exp(a)/exp(b))
                auto rdre = exp_cast(rd->get_right());
                if (rdre.get()) {
                    return rd->get_left()*(this->left/rd->get_right());
                }
//  exp(a)*(exp(b)/c) -> (exp(a)*exp(b))/c
                auto rdle = exp_cast(rd->get_left());
                if (rdle.get()) {
                    return (this->left*rd->get_left())/rd->get_right();
                }
            }

            if (ld.get() && rm.get()) {
                auto rmle = exp_cast(rm->get_left());
                if (rmle.get()) {
//  (c/exp(a))*(exp(b)*d) -> (c*d)*(exp(b)/exp(a))
                    auto ldre = exp_cast(ld->get_right());
                    if (ldre.get()) {
                        return (ld->get_left()*rm->get_right()) *
                               (rm->get_left()/ld->get_right());
                    }
//  (exp(a)/c)*(exp(b)*d) -> (d/c)*(exp(a)*exp(b))
                    auto ldle = exp_cast(ld->get_left());
                    if (ldle.get()) {
                        return (rm->get_right()/ld->get_right()) *
                               (ld->get_left()*rm->get_left());
                    }
                }
                auto rmre = exp_cast(rm->get_right());
                if (rmre.get()) {
//  (c/exp(a))*(d*exp(b)) -> (c*d)*(exp(b)/exp(a))
                    auto ldre = exp_cast(ld->get_right());
                    if (ldre.get()) {
                        return (ld->get_left()*rm->get_left()) *
                               (rm->get_right()/ld->get_right());
                    }
//  (exp(a)/c)*(d*exp(b)) -> (d/c)*(exp(a)*exp(b))
                    auto ldle = exp_cast(ld->get_left());
                    if (ldle.get()) {
                        return (rm->get_left()/ld->get_right()) *
                               (ld->get_left()*rm->get_right());
                    }
                }
            } else if (rd.get() && lm.get()) {
                auto lmre = exp_cast(lm->get_right());
                if (lmre.get()) {
//  (c*exp(a))*(exp(b)/d) -> (c/d)*(exp(a)*exp(b))
                    auto rdre = exp_cast(rd->get_left());
                    if (rdre.get()) {
                        return (lm->get_left()/rd->get_right()) *
                               (lm->get_right()*rd->get_left());
                    }
//  (c*exp(a))*(d/exp(b)) -> (c*d)*(exp(a)/exp(b))
                    auto rdle = exp_cast(rd->get_right());
                    if (rdle.get()) {
                        return (lm->get_left()*rd->get_left()) *
                               (lm->get_right()/rd->get_right());
                    }
                }
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
//  (exp(a)*c)*(d/exp(b)) -> (c*d)*(exp(a)/exp(b))
                    auto rdle = exp_cast(rd->get_right());
                    if (rdle.get()) {
                        return (lm->get_right()*rd->get_left()) *
                               (lm->get_left()/rd->get_right());
                    }
//  (exp(a)*c)*(exp(b)/d) -> (c/d)*(exp(a)*exp(b))
                    auto rdre = exp_cast(rd->get_left());
                    if (rdre.get()) {
                        return (lm->get_right()/rd->get_right()) *
                               (lm->get_left()*rd->get_left());
                    }
                }
            }

//  Cases like
//  (c/exp(a))*(exp(b)/d) -> (c/d)*(exp(b)/exp(a))
//  (c/exp(a))*(d/exp(b)) -> (c*e)/(exp(b)*exp(a))
//  (exp(a)/c)*(d/exp(b)) -> (d/c)*(exp(a)/exp(b))
//  (exp(a)/c)*(exp(b)/d) -> (exp(a)*exp(b))/(c*d)
//  Are taken care of by (a/b)*(c/d) -> (a*c)/(b*d) conversion above.

            return this->shared_from_this();
        }

@@ -1828,6 +1975,94 @@ namespace graph {
                }
            }

//  exp(a)/exp(b) -> exp(a - b)
            auto lexp = exp_cast(this->left);
            auto rexp = exp_cast(this->right);
            if (lexp.get() && rexp.get()) {
                return exp(lexp->get_arg() - rexp->get_arg());
            }

//  (c*exp(a))/exp(b) -> c*(exp(a)/exp(b))
//  (exp(a)*c)/exp(b) -> c*(exp(a)/exp(b))
            if (rexp.get() && lm.get()) {
                auto lmre = exp_cast(lm->get_right());
                if (lmre.get()) {
                    return lm->get_left()*(lm->get_right()/this->right);
                }
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
                    return lm->get_right()*(lm->get_left()/this->right);
                }
            }
//  exp(a)/(c*exp(b)) -> (exp(a)/exp(b))/c
//  exp(a)/(exp(b)*c) -> (exp(a)/exp(b))/c
            if (lexp.get() && rm.get()) {
                auto rmre = exp_cast(rm->get_right());
                if (rmre.get()) {
                    return (this->left/rm->get_right())/rm->get_left();
                }
                auto rmle = exp_cast(rm->get_left());
                if (rmle.get()) {
                    return (this->left/rm->get_left())/rm->get_right();
                }
            }

//  (c*exp(a))/(d*exp(b)) -> (c/d)*(exp(a)/exp(b))
//  (c*exp(a))/(exp(b)*d) -> (c/d)*(exp(a)/exp(b))
//  (exp(a)*c)/(d*exp(b)) -> (c/d)*(exp(a)/exp(b))
//  (exp(a)*c)/(exp(b)*d) -> (c/d)*(exp(a)/exp(b))
            if (lm.get() && rm.get()) {
                auto lmre = exp_cast(lm->get_right());
                if (lmre.get()) {
                    auto rmre = exp_cast(rm->get_right());
                    if (rmre.get()) {
                        return (lm->get_left()/rm->get_left()) *
                               (lm->get_right()/rm->get_right());
                    }
                    auto rmle = exp_cast(rm->get_left());
                    if (rmle.get()) {
                        return (lm->get_left()/rm->get_right()) *
                               (lm->get_right()/rm->get_left());
                    }
                }
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
                    auto rmre = exp_cast(rm->get_right());
                    if (rmre.get()) {
                        return (lm->get_right()/rm->get_left()) *
                               (lm->get_left()/rm->get_right());
                    }
                    auto rmle = exp_cast(rm->get_left());
                    if (rmle.get()) {
                        return (lm->get_right()/rm->get_right()) *
                               (lm->get_left()/rm->get_left());
                    }
                }
            }

//  exp(a)/(c/exp(b)) -> (exp(a)*exp(b))/c
//  exp(a)/(exp(b)/c) -> c*(exp(a)/exp(b))
            auto rd = divide_cast(this->right);
            if (rd.get() && lexp.get()) {
                auto rdre = exp_cast(rd->get_right());
                if (rdre.get()) {
                    return (this->left*rd->get_right())/rd->get_left();
                }
                auto rdle = exp_cast(rd->get_left());
                if (rdle.get()) {
                    return rd->get_right()*(this->left/rd->get_left());
                }
            }

//  (c/exp(a))/exp(b) -> c/(exp(a)*exp(b))
//  (exp(a)/c)/exp(b) -> exp(a)/(c*exp(b))
//  (c/exp(a))/(d/exp(b)) -> (c*exp(b))/(d*exp(a))
//  (c/exp(a))/(exp(b)/d) -> (c*d)/(exp(b)*exp(a))
//  (exp(a)/c)/(d/exp(b)) -> (exp(a)*exp(b))/(d*c)
//  (exp(a)/c)/(exp(b)/d) -> (exp(a)*d)/(exp(b)*c)
//  Note cases like this are already transformed by the (a/b)/c -> a/(b*c)
//  above.

            return this->shared_from_this();
        }

@@ -2462,6 +2697,120 @@ namespace graph {
                }
            }

//  fma(exp(a), exp(b), c) -> exp(a + b) + c
            auto le = exp_cast(this->left);
            auto me = exp_cast(this->middle);
            if (le.get() && me.get()) {
                return exp(le->get_arg() + me->get_arg()) + this->right;
            }

//  fma(exp(a), exp(b)*c, d) -> fma(exp(a)*exp(b), c, d)
//  fma(exp(a), c*exp(b), d) -> fma(exp(a)*exp(b), c, d)
            if (mm.get() && le.get()) {
                auto mmle = exp_cast(mm->get_left());
                if (mmle.get()) {
                    return fma(this->left*mm->get_left(), 
                               mm->get_right(),
                               this->right);
                }
                auto mmre = exp_cast(mm->get_right());
                if (mmre.get()) {
                    return fma(this->left*mm->get_right(), 
                               mm->get_left(),
                               this->right);
                }
            }
//  fma(exp(a)*c, exp(b), d) -> fma(exp(a)*exp(b), c, d)
//  fma(c*exp(a), exp(b), d) -> fma(exp(a)*exp(b), c, d)
            if (lm.get() && me.get()) {
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
                    return fma(lm->get_left()*this->middle, 
                               lm->get_right(),
                               this->right);
                }
                auto lmre = exp_cast(lm->get_right());
                if (lmre.get()) {
                    return fma(lm->get_right()*this->middle, 
                               lm->get_left(),
                               this->right);
                }
            }

//  fma(exp(a)*c, exp(b)*d, e) -> fma(exp(a)*exp(b), c*d, e)
//  fma(exp(a)*c, d*exp(b), e) -> fma(exp(a)*exp(b), c*d, e)
//  fma(c*exp(a), exp(b)*d, e) -> fma(exp(a)*exp(b), c*d, e)
//  fma(c*exp(a), d*exp(b), e) -> fma(exp(a)*exp(b), c*d, e)
            if (lm.get() && mm.get()) {
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
                    auto mmle = exp_cast(mm->get_left());
                    if (mmle.get()) {
                        return fma(lm->get_left()*mm->get_left(),
                                   lm->get_right()*mm->get_right(),
                                   this->right);
                    }
                    auto mmre = exp_cast(mm->get_right());
                    if (mmre.get()) {
                        return fma(lm->get_left()*mm->get_right(),
                                   lm->get_right()*mm->get_left(),
                                   this->right);
                    }
                }
                auto lmre = exp_cast(lm->get_right());
                if (lmre.get()) {
                    auto mmle = exp_cast(mm->get_left());
                    if (mmle.get()) {
                        return fma(lm->get_right()*mm->get_left(),
                                   lm->get_left()*mm->get_right(),
                                   this->right);
                    }
                    auto mmre = exp_cast(mm->get_right());
                    if (mmre.get()) {
                        return fma(lm->get_right()*mm->get_right(),
                                   lm->get_left()*mm->get_left(),
                                   this->right);
                    }
                }
            }

//  fma(exp(a)*c, exp(b)/d, e) -> fma(exp(a)*exp(b), c/d, e)
//  fma(exp(a)*c, d/exp(b), e) -> fma(exp(a)/exp(b), c*d, e)
//  fma(c*exp(a), exp(b)/d, e) -> fma(exp(a)*exp(b), c/d, e)
//  fma(c*exp(a), d/exp(b), e) -> fma(exp(a)/exp(b), c*d, e)
            if (lm.get() && md.get()) {
                auto lmle = exp_cast(lm->get_left());
                if (lmle.get()) {
                    auto mdle = exp_cast(md->get_left());
                    if (mdle.get()) {
                        return fma(lm->get_left()*md->get_left(),
                                   lm->get_right()/md->get_right(),
                                   this->right);
                    }
                    auto mdre = exp_cast(md->get_right());
                    if (mdre.get()) {
                        return fma(lm->get_left()/md->get_right(),
                                   lm->get_right()*md->get_left(),
                                   this->right);
                    }
                }
                auto lmre = exp_cast(lm->get_right());
                if (lmre.get()) {
                    auto mdle = exp_cast(md->get_left());
                    if (mdle.get()) {
                        return fma(lm->get_right()*md->get_left(),
                                   lm->get_left()/md->get_right(),
                                   this->right);
                    }
                    auto mdre = exp_cast(md->get_right());
                    if (mdre.get()) {
                        return fma(lm->get_right()/md->get_right(),
                                   lm->get_left()*md->get_left(),
                                   this->right);
                    }
                }
            }

            return this->shared_from_this();
        }

+6 −0
Original line number Diff line number Diff line
@@ -938,6 +938,12 @@ namespace graph {
                           this->right/two<T, SAFE_MATH> ());
            }

//  Reduce exp(x)^n -> exp(n*x) when x is an integer.
            auto temp = exp_cast(this->left);
            if (temp.get() && rc.get() && rc->is_integer()) {
                return exp(this->right*temp->get_arg());
            }

            return this->shared_from_this();
        }

+368 −16

File changed.

Preview size limit exceeded, changes collapsed.

+20 −3
Original line number Diff line number Diff line
@@ -291,18 +291,35 @@ void test_pow() {
    auto powpow_int = graph::pow(graph::pow(var_a, var_b),
                                 graph::constant<T> (static_cast<T> (3.0)));
    auto powpow_int_cast = graph::pow_cast(powpow_int);
    assert(graph::multiply_cast(powpow_int_cast->get_right()) &&
    assert((powpow_int_cast.get() &&
            graph::multiply_cast(powpow_int_cast->get_right())) &&
           "Expected multiply node.");
    auto powpow_float =  graph::pow(graph::pow(var_a, var_b),
                                    graph::constant<T> (static_cast<T> (1.5)));
    auto powpow_float_cast = graph::pow_cast(powpow_float);
    assert(!graph::multiply_cast(powpow_float_cast->get_right()) &&
    assert((powpow_int_cast.get() &&
            !graph::multiply_cast(powpow_float_cast->get_right())) &&
           "Did not expect multiply node.");
    auto powpow_var =  graph::pow(graph::pow(var_a, var_b),
                                  ten);
    auto powpow_var_cast = graph::pow_cast(powpow_var);
    assert(!graph::multiply_cast(powpow_var_cast->get_right()) &&
    assert((powpow_int_cast.get() &&
            !graph::multiply_cast(powpow_var_cast->get_right())) &&
           "Did not expect multiply node.");

//  Test pow of exp
//  Exp[x]^n -> Exp[n*x] when n is an integer.
    auto powexp_int = graph::pow(graph::exp(var_a), 
                                 graph::constant<T> (static_cast<T> (3.0)));
    auto powexp_int_cast = graph::exp_cast(powexp_int);
    assert((powexp_int_cast.get() &&
            graph::multiply_cast(powexp_int_cast->get_arg())) &&
           "Expected multiply node in exp argument.");
    auto powexp_float = graph::pow(graph::exp(var_a),
                                   graph::constant<T> (static_cast<T> (1.5)));
    auto powexp_float_cast = graph::pow_cast(powexp_float);
    assert(powexp_float_cast.get() &&
           "Expected power cast.");
}

//------------------------------------------------------------------------------