Commit cd5cfe90 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Check if the exponent of pow is a multiply and replace it with explicit multiply code.

parent 0a077ed7
Loading
Loading
Loading
Loading
+14 −12
Original line number Diff line number Diff line
@@ -11,8 +11,8 @@
#include "../graph_framework/solver.hpp"
#include "../graph_framework/timing.hpp"

const bool print = true;
const bool write_step = false;
const bool print = false;
const bool write_step = true;
const bool print_expressions = false;

//------------------------------------------------------------------------------
@@ -26,18 +26,19 @@ int main(int argc, const char * argv[]) {

    std::mutex sync;

    //typedef float base;
    typedef float base;
    //typedef double base;
    //typedef std::complex<float> base;
    typedef std::complex<double> base;
    constexpr bool use_safe_math = true;
    //typedef std::complex<double> base;
    //constexpr bool use_safe_math = true;
    constexpr bool use_safe_math = false;

    const timeing::measure_diagnostic total("Total Time");

    const size_t num_times = 10000;
    const size_t sub_steps = 1;
    const size_t num_steps = num_times/sub_steps;
    const size_t num_rays = 1;//000;//0000;
    const size_t num_rays = 1000000;

    std::vector<std::thread> threads(0);
    if constexpr (jit::use_gpu<base> ()) {
@@ -82,7 +83,7 @@ int main(int argc, const char * argv[]) {
            }

            x->set(static_cast<base> (2.5));
            //x->set(static_cast<base> (9.0));
            //x->set(static_cast<base> (0.0));
            y->set(static_cast<base> (0.0));
            z->set(static_cast<base> (0.0));
            kx->set(static_cast<base> (-600.0));
@@ -91,11 +92,12 @@ int main(int argc, const char * argv[]) {
            kz->set(static_cast<base> (0.0));

            auto eq = equilibrium::make_efit<base, use_safe_math> (NC_FILE, sync);
            //auto eq = equilibrium::make_slab_density<base> ();
            //auto eq = equilibrium::make_no_magnetic_field<base> ();
            //auto eq = equilibrium::make_slab_density<base, use_safe_math> ();
            //auto eq = equilibrium::make_slab_field<base, use_safe_math> ();
            //auto eq = equilibrium::make_no_magnetic_field<base, use_safe_math> ();

            const base endtime = static_cast<base> (1.0);
            //const base endtime = static_cast<base> (10.0);
            //const base endtime = static_cast<base> (0.25);
            const base dt = endtime/static_cast<base> (num_times);

            //auto dt_var = graph::variable(num_rays, static_cast<base> (dt), "dt");
@@ -109,9 +111,9 @@ int main(int argc, const char * argv[]) {
            //solver::rk4<dispersion::simple<base, use_safe_math>>
            //solver::rk4<dispersion::ordinary_wave<base, use_safe_math>>
            //solver::rk4<dispersion::extra_ordinary_wave<base, use_safe_math>>
            //solver::rk4<dispersion::cold_plasma<base, use_safe_math>>
            solver::rk4<dispersion::cold_plasma<base, use_safe_math>>
            //solver::adaptive_rk4<dispersion::ordinary_wave<base, use_safe_math>>
            solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<base, use_safe_math>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<base, use_safe_math>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma_expandion<base, dispersion::z_erfi<base, use_safe_math>, use_safe_math>>
                solve(omega, kx, ky, kz, x, y, z, t, dt, eq,
                      stream.str(), local_num_rays);
+105 −0
Original line number Diff line number Diff line
@@ -460,6 +460,111 @@ namespace equilibrium {
    }

//******************************************************************************
//  Slab field gradient equilibrium.
//******************************************************************************
//------------------------------------------------------------------------------
///  @brief Vary density with uniform magnetic field equilibrium.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    class slab_field : public generic<T, SAFE_MATH> {
    public:
//------------------------------------------------------------------------------
///  @brief Construct a guassian density with uniform magnetic field.
//------------------------------------------------------------------------------
        slab_field() :
        generic<T, SAFE_MATH> ({3.34449469E-27}, {1}) {}

//------------------------------------------------------------------------------
///  @brief Get the electron density.
///
///  @params[in] x X position.
///  @params[in] y Y position.
///  @params[in] z Z position.
///  @returns The electron expression.
//------------------------------------------------------------------------------
        virtual graph::shared_leaf<T, SAFE_MATH>
        get_electron_density(graph::shared_leaf<T, SAFE_MATH> x,
                             graph::shared_leaf<T, SAFE_MATH> y,
                             graph::shared_leaf<T, SAFE_MATH> z) final {
            return graph::constant<T, SAFE_MATH> (static_cast<T> (1.0E19)) *
                   (graph::constant<T, SAFE_MATH> (static_cast<T> (0.1))*x +
                    graph::one<T, SAFE_MATH> ());
        }

//------------------------------------------------------------------------------
///  @brief Get the ion density.
///
///  @params[in] index The species index.
///  @returns The electron expression.
//------------------------------------------------------------------------------
        virtual graph::shared_leaf<T, SAFE_MATH>
        get_ion_density(const size_t index,
                        graph::shared_leaf<T, SAFE_MATH> x,
                        graph::shared_leaf<T, SAFE_MATH> y,
                        graph::shared_leaf<T, SAFE_MATH> z) final {
            return get_electron_density(x, y, z);
        }

//------------------------------------------------------------------------------
///  @brief Get the electron temperature.
///
///  @params[in] x X position.
///  @params[in] y Y position.
///  @params[in] z Z position.
///  @returns The electron expression.
//------------------------------------------------------------------------------
        virtual graph::shared_leaf<T, SAFE_MATH>
        get_electron_temperature(graph::shared_leaf<T, SAFE_MATH> x,
                                 graph::shared_leaf<T, SAFE_MATH> y,
                                 graph::shared_leaf<T, SAFE_MATH> z) final {
            return graph::constant<T, SAFE_MATH> (static_cast<T> (1000.0)) *
                   (graph::constant<T, SAFE_MATH> (static_cast<T> (0.1))*x +
                    graph::one<T, SAFE_MATH> ());
        }

//------------------------------------------------------------------------------
///  @brief Get the ion temperature.
///
///  @params[in] index The species index.
///  @returns The electron expression.
//------------------------------------------------------------------------------
        virtual graph::shared_leaf<T, SAFE_MATH>
        get_ion_temperature(const size_t index,
                            graph::shared_leaf<T, SAFE_MATH> x,
                            graph::shared_leaf<T, SAFE_MATH> y,
                            graph::shared_leaf<T, SAFE_MATH> z) final {
            return get_electron_temperature(x, y, z);
        }
        
//------------------------------------------------------------------------------
///  @brief Get the magnetic field.
///
///  @params[in] x X position.
///  @params[in] y Y position.
///  @params[in] z Z position.
///  @returns Magnetic field expression.
//------------------------------------------------------------------------------
        virtual graph::shared_vector<T, SAFE_MATH>
        get_magnetic_field(graph::shared_leaf<T, SAFE_MATH> x,
                           graph::shared_leaf<T, SAFE_MATH> y,
                           graph::shared_leaf<T, SAFE_MATH> z) final {
            auto zero = graph::zero<T, SAFE_MATH> ();
            return graph::vector(zero, zero,
                                 graph::constant<T, SAFE_MATH> (static_cast<T> (0.1))*x +
                                 graph::one<T, SAFE_MATH> ());
        }
    };

//------------------------------------------------------------------------------
///  @brief Convenience function to build a slab density equilibrium.
///
///  @returns A constructed slab density equilibrium.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    shared<T, SAFE_MATH> make_slab_field() {
        return std::make_shared<slab_field<T, SAFE_MATH>> ();
    }
//******************************************************************************
//  Guassian density with a uniform magnetic field.
//******************************************************************************
//------------------------------------------------------------------------------
+19 −5
Original line number Diff line number Diff line
@@ -750,15 +750,29 @@ namespace graph {
                jit::register_map &registers) {
            if (registers.find(this) == registers.end()) {
                shared_leaf<T, SAFE_MATH> l = this->left->compile(stream, registers);
                shared_leaf<T, SAFE_MATH> r = this->right->compile(stream, registers);
                shared_leaf<T, SAFE_MATH> r;
                auto temp = constant_cast(this->right);
                if (!temp.get() || !temp->is_integer()) {
                    r = this->right->compile(stream, registers);
                }

                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<T> (stream);
                stream << " " << registers[this] << " = pow("
                stream << " " << registers[this] << " = ";
                if (temp.get() && temp->is_integer()) {
                    stream << registers[l.get()];
                    const size_t end = static_cast<size_t> (std::real(this->right->evaluate().at(0)));
                    for (size_t i = 1; i < end; i++) {
                        stream << "*" << registers[l.get()];
                    }
                    stream << ";";
                } else {
                    stream << "pow("
                           << registers[l.get()] << ", "
                       << registers[r.get()] << ");"
                       << std::endl;
                           << registers[r.get()] << ");";
                }
                stream << std::endl;
            }

            return this->shared_from_this();
+9 −0
Original line number Diff line number Diff line
@@ -364,6 +364,15 @@ namespace graph {
            return data.size() == 1 && data.at(0) == d;
        }

//------------------------------------------------------------------------------
///  @brief Check if the value is an integer.
//------------------------------------------------------------------------------
        bool is_integer() {
            const auto temp = this->evaluate().at(0);
            return std::imag(temp) == 0 &&
                   static_cast<size_t> (std::real(temp))%2 == 0;
        }

//------------------------------------------------------------------------------
///  @brief Convert the node to latex.
//------------------------------------------------------------------------------