Commit 7a0e8078 authored by cianciosa's avatar cianciosa
Browse files

Add template parameter documentation.

parent 1fbe71c8
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -11,6 +11,11 @@

//------------------------------------------------------------------------------
///  @brief Bench runner.
///
///  @tparam T         Base type of the calculation.
///  @tparam NUM_TIMES Total number of times steps.
///  @tparam SUB_STEPS Number of substeps.
///  @tparam NUM_RAYS  Number of rays.
//------------------------------------------------------------------------------
template<typename T, size_t NUM_TIMES, size_t SUB_STEPS, size_t NUM_RAYS>
void bench_runner() {
+89 −75
Original line number Diff line number Diff line
@@ -16,31 +16,21 @@ const bool print_expressions = false;
const bool verbose = true;

//------------------------------------------------------------------------------
///  @brief Main program of the driver.
///  @brief Trace the rays.
///
///  @params[in] argc Number of commandline arguments.
///  @params[in] argv Array of commandline arguments.
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] num_times Total number of time steps.
///  @params[in] sub_steps Number of substeps to push the rays.
///  @params[in] num_rays  Number of rays to trace.
//------------------------------------------------------------------------------
int main(int argc, const char * argv[]) {
    START_GPU

    jit::verbose = verbose;

    //typedef float base;
    typedef double base;
    //typedef std::complex<float> base;
    //typedef std::complex<double> base;
    //constexpr bool use_safe_math = true;
    constexpr bool use_safe_math = false;
template<typename T, bool SAFE_MATH=false>
void trace_ray(const size_t num_times,
               const size_t sub_steps,
               const size_t num_rays) {

    const timeing::measure_diagnostic total("Total Time");

    const size_t num_times = 200000;
    const size_t sub_steps = 100;
    const size_t num_steps = num_times/sub_steps;
    const size_t num_rays = 100000;

    std::vector<std::thread> threads(std::max(std::min(static_cast<unsigned int> (jit::context<base, use_safe_math>::max_concurrency()),
    std::vector<std::thread> threads(std::max(std::min(static_cast<unsigned int> (jit::context<T, SAFE_MATH>::max_concurrency()),
                                                       static_cast<unsigned int> (num_rays)),
                                              static_cast<unsigned int> (1)));

@@ -48,86 +38,91 @@ int main(int argc, const char * argv[]) {
    const size_t extra = num_rays%threads.size();

    for (size_t i = 0, ie = threads.size(); i < ie; i++) {
        threads[i] = std::thread([num_times, num_rays, batch, extra] (const size_t thread_number,
        threads[i] = std::thread([num_times, sub_steps, num_rays, batch, extra] (const size_t thread_number,
                                                                                 const size_t num_threads) -> void {

            const size_t num_steps = num_times/sub_steps;
            const size_t local_num_rays = batch
                                        + (extra > thread_number ? 1 : 0);

            std::mt19937_64 engine((thread_number + 1)*static_cast<uint64_t> (std::chrono::system_clock::to_time_t(std::chrono::system_clock::now())));
            std::uniform_int_distribution<size_t> int_dist(0, local_num_rays - 1);

            auto omega = graph::variable<base, use_safe_math> (local_num_rays, "\\omega");
            auto kx    = graph::variable<base, use_safe_math> (local_num_rays, "k_{x}");
            auto ky    = graph::variable<base, use_safe_math> (local_num_rays, "k_{y}");
            auto kz    = graph::variable<base, use_safe_math> (local_num_rays, "k_{z}");
            auto x     = graph::variable<base, use_safe_math> (local_num_rays, "x");
            auto y     = graph::variable<base, use_safe_math> (local_num_rays, "y");
            auto z     = graph::variable<base, use_safe_math> (local_num_rays, "z");
            auto t     = graph::variable<base, use_safe_math> (local_num_rays, "t");
            auto omega = graph::variable<T, SAFE_MATH> (local_num_rays, "\\omega");
            auto kx    = graph::variable<T, SAFE_MATH> (local_num_rays, "k_{x}");
            auto ky    = graph::variable<T, SAFE_MATH> (local_num_rays, "k_{y}");
            auto kz    = graph::variable<T, SAFE_MATH> (local_num_rays, "k_{z}");
            auto x     = graph::variable<T, SAFE_MATH> (local_num_rays, "x");
            auto y     = graph::variable<T, SAFE_MATH> (local_num_rays, "y");
            auto z     = graph::variable<T, SAFE_MATH> (local_num_rays, "z");
            auto t     = graph::variable<T, SAFE_MATH> (local_num_rays, "t");

            t->set(static_cast<base> (0.0));
            t->set(static_cast<T> (0.0));

//  Inital conditions.
            if constexpr (jit::is_float<base> ()) {
                std::normal_distribution<float> norm_dist(static_cast<float> (700.0), static_cast<float> (10.0));
                std::normal_distribution<float> norm_dist2(static_cast<float> (0.0), static_cast<float> (0.05));
                std::normal_distribution<float> norm_dist3(static_cast<float> (-100.0), static_cast<float> (10.0));
                std::normal_distribution<float> norm_dist4(static_cast<float> (0.0), static_cast<float> (10.0));
            if constexpr (jit::is_float<T> ()) {
                std::normal_distribution<float> norm_dist1(static_cast<float> (700.0),
                                                           static_cast<float> (10.0));
                std::normal_distribution<float> norm_dist2(static_cast<float> (0.0),
                                                           static_cast<float> (0.05));
                std::normal_distribution<float> norm_dist3(static_cast<float> (-100.0),
                                                           static_cast<float> (10.0));
                std::normal_distribution<float> norm_dist4(static_cast<float> (0.0),
                                                           static_cast<float> (10.0));

                for (size_t j = 0; j < local_num_rays; j++) {
                    omega->set(j, static_cast<base> (norm_dist(engine)));
                    y->set(j, static_cast<base> (norm_dist2(engine)));
                    z->set(j, static_cast<base> (norm_dist2(engine)));
                    ky->set(j, static_cast<base> (norm_dist3(engine)));
                    kz->set(j, static_cast<base> (norm_dist4(engine)));
                    omega->set(j, static_cast<T> (norm_dist1(engine)));
                    y->set(j, static_cast<T> (norm_dist2(engine)));
                    z->set(j, static_cast<T> (norm_dist2(engine)));
                    ky->set(j, static_cast<T> (norm_dist3(engine)));
                    kz->set(j, static_cast<T> (norm_dist4(engine)));
                }
            } else {
                std::normal_distribution<float> norm_dist(static_cast<double> (700.0), static_cast<double> (10.0));
                std::normal_distribution<float> norm_dist2(static_cast<double> (0.0), static_cast<double> (0.05));
                std::normal_distribution<float> norm_dist3(static_cast<float> (-100.0), static_cast<float> (10.0));
                std::normal_distribution<float> norm_dist4(static_cast<float> (0.0), static_cast<float> (10.0));
                std::normal_distribution<double> norm_dist1(static_cast<double> (700.0),
                                                            static_cast<double> (10.0));
                std::normal_distribution<double> norm_dist2(static_cast<double> (0.0),
                                                            static_cast<double> (0.05));
                std::normal_distribution<double> norm_dist3(static_cast<double> (-100.0),
                                                            static_cast<double> (10.0));
                std::normal_distribution<double> norm_dist4(static_cast<double> (0.0),
                                                            static_cast<double> (10.0));

                for (size_t j = 0; j < local_num_rays; j++) {
                    omega->set(j, static_cast<base> (norm_dist(engine)));
                    y->set(j, static_cast<base> (norm_dist2(engine)));
                    z->set(j, static_cast<base> (norm_dist2(engine)));
                    ky->set(j, static_cast<base> (norm_dist3(engine)));
                    kz->set(j, static_cast<base> (norm_dist4(engine)));
                    omega->set(j, static_cast<T> (norm_dist1(engine)));
                    y->set(j, static_cast<T> (norm_dist2(engine)));
                    z->set(j, static_cast<T> (norm_dist2(engine)));
                    ky->set(j, static_cast<T> (norm_dist3(engine)));
                    kz->set(j, static_cast<T> (norm_dist4(engine)));
                }
            }
            x->set(static_cast<base> (2.5));
            kx->set(static_cast<base> (-700));
            x->set(static_cast<T> (2.5));
            kx->set(static_cast<T> (-700));

            auto eq = equilibrium::make_efit<base, use_safe_math> (NC_FILE);
            //auto eq = equilibrium::make_slab_density<base, use_safe_math> ();
            //auto eq = equilibrium::make_slab_field<base, use_safe_math> ();
            //auto eq = equilibrium::make_no_magnetic_field<base, use_safe_math> ();
            auto eq = equilibrium::make_efit<T, SAFE_MATH> (NC_FILE);
            //auto eq = equilibrium::make_slab_density<T, SAFE_MATH> ();
            //auto eq = equilibrium::make_slab_field<T, SAFE_MATH> ();
            //auto eq = equilibrium::make_no_magnetic_field<T, SAFE_MATH> ();

            const base endtime = static_cast<base> (1.5);
            //const base endtime = static_cast<base> (10.0);
            //const base endtime = static_cast<base> (0.25);
            const base dt = endtime/static_cast<base> (num_times);
            const T endtime = static_cast<T> (1.5);
            const T dt = endtime/static_cast<T> (num_times);

            //auto dt_var = graph::variable(num_rays, static_cast<base> (dt), "dt");
            //auto dt_var = graph::variable(num_rays, static_cast<T> (dt), "dt");

            std::ostringstream stream;
            stream << "result" << thread_number << ".nc";

            //solver::split_simplextic<dispersion::bohm_gross<base, use_safe_math>>
            //solver::rk4<dispersion::bohm_gross<base, use_safe_math>>
            //solver::adaptive_rk4<dispersion::bohm_gross<base, use_safe_math>>
            //solver::rk4<dispersion::simple<base, use_safe_math>>
            solver::rk4<dispersion::ordinary_wave<base, use_safe_math>>
            //solver::rk4<dispersion::extra_ordinary_wave<base, use_safe_math>>
            //solver::rk4<dispersion::cold_plasma<base, use_safe_math>>
            //solver::adaptive_rk4<dispersion::ordinary_wave<base, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<base, use_safe_math>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma_expandion<base, dispersion::z_erfi<base, use_safe_math>, use_safe_math>>
            //solver::split_simplextic<dispersion::bohm_gross<T, SAFE_MATH>>
            //solver::rk4<dispersion::bohm_gross<T, SAFE_MATH>>
            //solver::adaptive_rk4<dispersion::bohm_gross<T, SAFE_MATH>>
            //solver::rk4<dispersion::simple<T, SAFE_MATH>>
            solver::rk4<dispersion::ordinary_wave<T, SAFE_MATH>>
            //solver::rk4<dispersion::extra_ordinary_wave<T, SAFE_MATH>>
            //solver::rk4<dispersion::cold_plasma<T, SAFE_MATH>>
            //solver::adaptive_rk4<dispersion::ordinary_wave<T, SAFE_MATH>>
            //solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<T, SAFE_MATH>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma_expandion<base, dispersion::z_erfi<T, SAFE_MATH>, use_safe_math>>
                solve(omega, kx, ky, kz, x, y, z, t, dt, eq,
                      stream.str(), local_num_rays, thread_number);
                //solve(omega, kx, ky, kz, x, y, z, t, dt_var, eq,
                //      stream.str(), local_num_rays, thread_number);
            solve.init(kx);
            solve.compile();
            if (thread_number == 0 && print_expressions) {
@@ -193,6 +188,25 @@ int main(int argc, const char * argv[]) {
    for (std::thread &t : threads) {
        t.join();
    }
}

//------------------------------------------------------------------------------
///  @brief Main program of the driver.
///
///  @params[in] argc Number of commandline arguments.
///  @params[in] argv Array of commandline arguments.
//------------------------------------------------------------------------------
int main(int argc, const char * argv[]) {
    START_GPU
    const timeing::measure_diagnostic total("Total Time");

    jit::verbose = verbose;

    const size_t num_times = 100000;
    const size_t sub_steps = 100;
    const size_t num_rays = 100000;

    trace_ray<double> (num_times, sub_steps, num_rays);

    std::cout << std::endl << "Timing:" << std::endl;
    total.print();
+57 −0
Original line number Diff line number Diff line
@@ -18,6 +18,9 @@ namespace graph {
///  @brief An addition node.
///
///  Note use templates here to defer this so it can use the operator functions.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    class add_node final : public branch_node<T, SAFE_MATH> {
@@ -358,6 +361,9 @@ namespace graph {
///  Note use templates here to defer this so it can be used in the above
///  classes.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] r Right branch.
//------------------------------------------------------------------------------
@@ -385,6 +391,9 @@ namespace graph {
///  Note use templates here to defer this so it can be used in the above
///  classes.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] r Right branch.
//------------------------------------------------------------------------------
@@ -401,6 +410,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Cast to a add node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] x Leaf node to attempt cast.
///  @returns An attemped dynamic case.
//------------------------------------------------------------------------------
@@ -416,6 +428,9 @@ namespace graph {
///  @brief A subtraction node.
///
///  Note use templates here to defer this so it can use the operator functions.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    class subtract_node final : public branch_node<T, SAFE_MATH> {
@@ -804,6 +819,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Build subtract node from two leaves.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] r Right branch.
//------------------------------------------------------------------------------
@@ -828,6 +846,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Build subtract operator from two leaves.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] r Right branch.
//------------------------------------------------------------------------------
@@ -844,6 +865,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Cast to a subtract node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] x Leaf node to attempt cast.
///  @returns An attemped dynamic case.
//------------------------------------------------------------------------------
@@ -857,6 +881,9 @@ namespace graph {
//******************************************************************************
//------------------------------------------------------------------------------
///  @brief A multiplcation node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    class multiply_node final : public branch_node<T, SAFE_MATH> {
@@ -1359,6 +1386,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Build multiply node from two leaves.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] r Right branch.
//------------------------------------------------------------------------------
@@ -1383,6 +1413,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Build multiply operator from two leaves.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] r Right branch.
//------------------------------------------------------------------------------
@@ -1399,6 +1432,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Cast to a multiply node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] x Leaf node to attempt cast.
///  @returns An attemped dynamic case.
//------------------------------------------------------------------------------
@@ -1412,6 +1448,9 @@ namespace graph {
//******************************************************************************
//------------------------------------------------------------------------------
///  @brief A division node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    class divide_node final : public branch_node<T, SAFE_MATH> {
@@ -1751,6 +1790,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Build divide node from two leaves.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] r Right branch.
//------------------------------------------------------------------------------
@@ -1775,6 +1817,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Build divide operator from two leaves.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] r Right branch.
//------------------------------------------------------------------------------
@@ -1791,6 +1836,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Cast to a divide node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] x Leaf node to attempt cast.
///  @returns An attemped dynamic case.
//------------------------------------------------------------------------------
@@ -1806,6 +1854,9 @@ namespace graph {
///  @brief A fused multiply add node.
///
///  Note use templates here to defer this so it can use the operator functions.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    class fma_node final : public triple_node<T, SAFE_MATH> {
@@ -2262,6 +2313,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Build fused multiply add node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] l Left branch.
///  @params[in] m Middle branch.
///  @params[in] r Right branch.
@@ -2292,6 +2346,9 @@ namespace graph {
//------------------------------------------------------------------------------
///  @brief Cast to a fma node.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] x Leaf node to attempt cast.
///  @returns An attemped dynamic case.
//------------------------------------------------------------------------------
+18 −0
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ namespace backend {
//******************************************************************************
//------------------------------------------------------------------------------
///  @brief Class representing a generic buffer.
///
///  @tparam T Base type of the calculation.
//------------------------------------------------------------------------------
    template<typename T>
    class buffer {
@@ -261,6 +263,8 @@ namespace backend {
//------------------------------------------------------------------------------
///  @brief Add operation.
///
///  @tparam T Base type of the calculation.
///
///  @params[in] a Left operand.
///  @params[in] b Right operand.
///  @returns a + b.
@@ -293,6 +297,8 @@ namespace backend {
//------------------------------------------------------------------------------
///  @brief Equal operation.
///
///  @tparam T Base type of the calculation.
///
///  @params[in] a Left operand.
///  @params[in] b Right operand.
///  @returns a == b.
@@ -315,6 +321,8 @@ namespace backend {
//------------------------------------------------------------------------------
///  @brief Subtract operation.
///
///  @tparam T Base type of the calculation.
///
///  @params[in] a Left operand.
///  @params[in] b Right operand.
///  @returns a - b.
@@ -347,6 +355,8 @@ namespace backend {
//------------------------------------------------------------------------------
///  @brief Multiply operation.
///
///  @tparam T Base type of the calculation.
///
///  @params[in] a Left operand.
///  @params[in] b Right operand.
///  @returns a * b.
@@ -379,6 +389,8 @@ namespace backend {
//------------------------------------------------------------------------------
///  @brief Divide operation.
///
///  @tparam T Base type of the calculation.
///
///  @params[in] a Numerator.
///  @params[in] b Denominator.
///  @returns a / b.
@@ -411,6 +423,8 @@ namespace backend {
//------------------------------------------------------------------------------
///  @brief Fused multiply add operation.
///
///  @tparam T Base type of the calculation.
///
///  @params[in] a Left operand.
///  @params[in] b Middle operand.
///  @params[in] c Right operand.
@@ -517,6 +531,8 @@ namespace backend {
//------------------------------------------------------------------------------
///  @brief Take the power.
///
///  @tparam T Base type of the calculation.
///
///  @params[in] base     Base to raise to the power of.
///  @params[in] exponent Power to apply to the base.
///  @returns base^exponent.
@@ -588,6 +604,8 @@ namespace backend {
//------------------------------------------------------------------------------
///  @brief Take the inverse tangent.
///
///  @tparam T Base type of the calculation.
///
///  @params[in] x X argument.
///  @params[in] y Y argument.
///  @returns atan2(y, x)
+3 −0
Original line number Diff line number Diff line
@@ -20,6 +20,9 @@
namespace gpu {
//------------------------------------------------------------------------------
///  @brief Class representing a cpu context.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    class cpu_context {
Loading