Commit 5603e27d authored by cianciosa's avatar cianciosa
Browse files

Initial absoption calculation. WIP

parent edd49ce7
Loading
Loading
Loading
Loading
+39 −9
Original line number Diff line number Diff line
@@ -9,6 +9,8 @@

#include "../graph_framework/solver.hpp"
#include "../graph_framework/timing.hpp"
#include "../graph_framework/output.hpp"
#include "../graph_framework/absorption.hpp"

const bool print = false;
const bool write_step = true;
@@ -119,8 +121,8 @@ void trace_ray(const size_t num_times,
            //solver::rk4<dispersion::extra_ordinary_wave<T, SAFE_MATH>>
            //solver::rk4<dispersion::cold_plasma<T, SAFE_MATH>>
            //solver::adaptive_rk4<dispersion::ordinary_wave<T, SAFE_MATH>>
            //solver::rk4<dispersion::hot_plasma<base, dispersion::z_erfi<T, SAFE_MATH>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma_expandion<base, dispersion::z_erfi<T, SAFE_MATH>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma<T, dispersion::z_erfi<T, SAFE_MATH>, use_safe_math>>
            //solver::rk4<dispersion::hot_plasma_expandion<T, dispersion::z_erfi<T, SAFE_MATH>, use_safe_math>>
                solve(omega, kx, ky, kz, x, y, z, t, dt, eq,
                      stream.str(), local_num_rays, thread_number);
            solve.init(kx);
@@ -210,6 +212,8 @@ void calculate_power(const size_t num_times,
    for (size_t i = 0, ie = threads.size(); i < ie; i++) {
        threads[i] = std::thread([num_times, sub_steps, num_rays, batch, extra] (const size_t thread_number,
                                                                                 const size_t num_threads) -> void {
            std::ostringstream stream;
            stream << "result" << thread_number << ".nc";

            const size_t num_steps = num_times/sub_steps;
            const size_t local_num_rays = batch
@@ -223,7 +227,31 @@ void calculate_power(const size_t num_times,
            auto y     = graph::variable<T, SAFE_MATH> (local_num_rays, "y");
            auto z     = graph::variable<T, SAFE_MATH> (local_num_rays, "z");
            auto t     = graph::variable<T, SAFE_MATH> (local_num_rays, "t");
            auto kamp  = graph::variable<T, SAFE_MATH> (local_num_rays, "kamp");

            omega->set(static_cast<T> (0.0));
            graph::shared_variable<T, SAFE_MATH> omega_var = graph::variable_cast(omega);
            graph::shared_variable<T, SAFE_MATH> kx_var = graph::variable_cast(kx);
            graph::shared_variable<T, SAFE_MATH> ky_var = graph::variable_cast(ky);
            graph::shared_variable<T, SAFE_MATH> kz_var = graph::variable_cast(kz);
            graph::shared_variable<T, SAFE_MATH> x_var = graph::variable_cast(x);
            graph::shared_variable<T, SAFE_MATH> y_var = graph::variable_cast(y);
            graph::shared_variable<T, SAFE_MATH> z_var = graph::variable_cast(z);
            graph::shared_variable<T, SAFE_MATH> t_var = graph::variable_cast(t);

            auto eq = equilibrium::make_efit<T, SAFE_MATH> (NC_FILE);
            //auto eq = equilibrium::make_slab_density<T, SAFE_MATH> ();
            //auto eq = equilibrium::make_slab_field<T, SAFE_MATH> ();
            //auto eq = equilibrium::make_no_magnetic_field<T, SAFE_MATH> ();

            absorption::root_finder<dispersion::hot_plasma<T, dispersion::z_erfi<T, SAFE_MATH>, SAFE_MATH>>
                root(kamp, omega, kx, ky, kz, x, y, z, t, eq,
                     stream.str(), local_num_rays, thread_number);
            root.compile();
            
            for (size_t i = 0; i < num_steps; i++) {
                root.run(i);
            }
        }, i, threads.size());
    }

@@ -246,12 +274,14 @@ int main(int argc, const char * argv[]) {

    const size_t num_times = 100000;
    const size_t sub_steps = 100;
    const size_t num_rays = 100000;
    const size_t num_rays = 36; //100000;

    const bool use_safe_math = true;

    trace_ray<double> (num_times, sub_steps, num_rays);
    calculate_power<std::complex<double>, use_safe_math> (num_times,
    typedef double base;

    trace_ray<base> (num_times, sub_steps, num_rays);
    calculate_power<std::complex<base>, use_safe_math> (num_times,
                                                        sub_steps,
                                                        num_rays);

+163 −0
Original line number Diff line number Diff line
@@ -8,6 +8,8 @@
#ifndef absorption_h
#define absorption_h

#include "newton.hpp"

namespace absorption {
//******************************************************************************
//  Root finder.
@@ -19,6 +21,167 @@ namespace absorption {
//------------------------------------------------------------------------------
    template<class DISPERSION_FUNCTION>
    class root_finder {
    private:
///  kamp variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> kamp;

///  w variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                   DISPERSION_FUNCTION::safe_math> w;
///  kx variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> kx;
///  ky variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> ky;
///  kz variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> kz;
///  x variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> x;
///  y variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> y;
///  z variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> z;
///  t variable.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> t;

///  Residule.
        graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                           DISPERSION_FUNCTION::safe_math> residule;

///  Workflow manager.
        workflow::manager<typename DISPERSION_FUNCTION::base,
                          DISPERSION_FUNCTION::safe_math> work;
///  Concurrent index.
        const size_t index;

///  Output file.
        output::result_file file;
///  Output dataset.
        output::data_set<typename DISPERSION_FUNCTION::base> dataset;

    public:
//------------------------------------------------------------------------------
///  @brief Constructor for root finding.
///
///  @params[in] kamp     Inital kamp.
///  @params[in] w        Inital w.
///  @params[in] kx       Inital kx.
///  @params[in] ky       Inital ky.
///  @params[in] kz       Inital kz.
///  @params[in] x        Inital x.
///  @params[in] y        Inital y.
///  @params[in] z        Inital z.
///  @params[in] t        Inital t.
///  @params[in] eq       The plasma equilibrium.
///  @params[in] filename Result filename, empty names will be blank.
///  @params[in] num_rays Number of rays to write.
///  @params[in] index    Concurrent index.
//------------------------------------------------------------------------------
        root_finder(graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> kamp,
                    graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> w,
                    graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> kx,
                    graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> ky,
                    graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> kz,
                    graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> x,
                    graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> y,
                    graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> z,
                    graph::shared_leaf<typename DISPERSION_FUNCTION::base,
                                       DISPERSION_FUNCTION::safe_math> t,
                    equilibrium::shared<typename DISPERSION_FUNCTION::base,
                                        DISPERSION_FUNCTION::safe_math> &eq,
                    const std::string &filename="",
                    const size_t num_rays=0,
                    const size_t index=0) :
        kamp(kamp), w(w), kx(kx), ky(ky), kz(kz), x(x), y(y), z(z), t(t),
        file(filename), dataset(file), index(index), work(index) {
            auto kvec = graph::vector(kx, ky, kz);
            auto kunit = kvec->unit();
            auto klen = kvec->length();

            auto kx_amp = kamp*kvec->get_x();
            auto ky_amp = kamp*kvec->get_y();
            auto kz_amp = kamp*kvec->get_z();

            dispersion::dispersion_interface<DISPERSION_FUNCTION> D(w, kx_amp, ky_amp, kz_amp, x, y, z, t, eq);

            graph::input_nodes<typename DISPERSION_FUNCTION::base,
                               DISPERSION_FUNCTION::safe_math> inputs = {
                graph::variable_cast(this->kamp),
                graph::variable_cast(this->kx),
                graph::variable_cast(this->ky),
                graph::variable_cast(this->kz)
            };

            graph::map_nodes<typename DISPERSION_FUNCTION::base,
                             DISPERSION_FUNCTION::safe_math> setters = {
                {klen, graph::variable_cast(this->kamp)}
            };

            work.add_item(inputs, {}, setters, "root_find_init_kernel");

            inputs.push_back(graph::variable_cast(this->x));
            inputs.push_back(graph::variable_cast(this->y));
            inputs.push_back(graph::variable_cast(this->z));
            inputs.push_back(graph::variable_cast(this->t));
            inputs.push_back(graph::variable_cast(this->w));

            solver::newton(work, {kamp}, inputs, {D.get_d()*D.get_d()});
        }

//------------------------------------------------------------------------------
///  @brief Compile the workitems.
//------------------------------------------------------------------------------
        void compile() {
            work.compile();

            dataset.create_variable(file, "kamp", this->kamp, work.get_context());
            
            dataset.reference_variable(file, "w",    graph::variable_cast(this->w));
            dataset.reference_variable(file, "kx",   graph::variable_cast(this->kx));
            dataset.reference_variable(file, "ky",   graph::variable_cast(this->ky));
            dataset.reference_variable(file, "kz",   graph::variable_cast(this->kz));
            dataset.reference_variable(file, "x",    graph::variable_cast(this->x));
            dataset.reference_variable(file, "y",    graph::variable_cast(this->y));
            dataset.reference_variable(file, "z",    graph::variable_cast(this->z));
            dataset.reference_variable(file, "time", graph::variable_cast(this->t));
            file.end_define_mode();
        }

//------------------------------------------------------------------------------
///  @brief Run the workflow.
///
///  @params[in] time_index The time index to run the case for.
//------------------------------------------------------------------------------
        void run(const size_t time_index) {
            dataset.read(file, time_index);
            work.copy_to_device(w,  graph::variable_cast(this->w)->data());
            work.copy_to_device(kx, graph::variable_cast(this->kx)->data());
            work.copy_to_device(ky, graph::variable_cast(this->ky)->data());
            work.copy_to_device(kz, graph::variable_cast(this->kz)->data());
            work.copy_to_device(x,  graph::variable_cast(this->x)->data());
            work.copy_to_device(y,  graph::variable_cast(this->y)->data());
            work.copy_to_device(z,  graph::variable_cast(this->z)->data());
            work.copy_to_device(t,  graph::variable_cast(this->t)->data());

            work.run();
            work.wait();
            dataset.write(file);
        }
    };
}

+4 −6
Original line number Diff line number Diff line
@@ -18,14 +18,14 @@
namespace gpu {
//------------------------------------------------------------------------------
///  @brief  Check results of realtime compile.
///
///  @params[in] result Result code of the operation.
///  @params[in] name   Name of the operation.
//------------------------------------------------------------------------------
    static void check_nvrtc_error(nvrtcResult result,
                                  const std::string &name) {
#ifndef NDEBUG
        std::cout << name << " " << result << " "
                  << nvrtcGetErrorString(result) << std::endl;
        assert(result == NVRTC_SUCCESS && "NVTRC Error");
        assert(result == NVRTC_SUCCESS && nvrtcGetErrorString(result));
#endif
    }

@@ -40,9 +40,7 @@ namespace gpu {
#ifndef NDEBUG
        const char *error;
        cuGetErrorString(result, &error);
        std::cout << name << " "
                  << result << " " << error << std::endl;
        assert(result == CUDA_SUCCESS && "Cuda Error");
        assert(result == CUDA_SUCCESS && error);
#endif
    }

+105 −66

File changed.

Preview size limit exceeded, changes collapsed.

+8 −12
Original line number Diff line number Diff line
@@ -25,8 +25,8 @@ namespace jit {
//------------------------------------------------------------------------------
///  @brief Test if a type is complex.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///  @tparam BASE Base type.
///  @tparam T    Type to check against.
///
///  @returns A constant expression true or false type.
//------------------------------------------------------------------------------
@@ -38,8 +38,8 @@ namespace jit {
//------------------------------------------------------------------------------
///  @brief Test if the base type is float.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///  @tparam BASE Base type.
///  @tparam T    Type to check against.
///
///  @returns A constant expression true or false type.
//------------------------------------------------------------------------------
@@ -52,7 +52,6 @@ namespace jit {
///  @brief Test if the base type is float.
///
///  @tparam T Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns A constant expression true or false type.
//------------------------------------------------------------------------------
@@ -65,7 +64,6 @@ namespace jit {
///  @brief Test if the base type is double.
///
///  @tparam T Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns A constant expression true or false type.
//------------------------------------------------------------------------------
@@ -78,7 +76,6 @@ namespace jit {
///  @brief Test if a type is complex.
///
///  @tparam T Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns A constant expression true or false type.
//------------------------------------------------------------------------------
@@ -92,7 +89,6 @@ namespace jit {
///  @brief Convert a base type to a string.
///
///  @tparam T Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns A constant string literal of the type.
//------------------------------------------------------------------------------
Loading