Commit cb2cdc10 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add locks to the netcdf read and write routines since netcdfc is not thread safe.

parent 6f9ffdad
Loading
Loading
Loading
Loading
+3 −5
Original line number Diff line number Diff line
@@ -24,8 +24,6 @@ const bool print_expressions = false;
int main(int argc, const char * argv[]) {
    START_GPU

    std::mutex sync;

    typedef float base;
    //typedef double base;
    //typedef std::complex<float> base;
@@ -45,7 +43,7 @@ int main(int argc, const char * argv[]) {
                                              static_cast<unsigned int> (1)));

    for (size_t i = 0, ie = threads.size(); i < ie; i++) {
        threads[i] = std::thread([num_times, num_rays, &sync] (const size_t thread_number,
        threads[i] = std::thread([num_times, num_rays] (const size_t thread_number,
                                                        const size_t num_threads) -> void {
            const size_t local_num_rays = num_rays/num_threads
                                        + std::min(thread_number, num_rays%num_threads);
@@ -89,7 +87,7 @@ int main(int argc, const char * argv[]) {
            kz->set(static_cast<base> (0.0));
            //kz->set(static_cast<base> (10.0));

            auto eq = equilibrium::make_efit<base, use_safe_math> (NC_FILE, sync);
            auto eq = equilibrium::make_efit<base, use_safe_math> (NC_FILE);
            //auto eq = equilibrium::make_slab_density<base, use_safe_math> ();
            //auto eq = equilibrium::make_slab_field<base, use_safe_math> ();
            //auto eq = equilibrium::make_no_magnetic_field<base, use_safe_math> ();
+4 −4
Original line number Diff line number Diff line
@@ -19,6 +19,9 @@
#include "arithmetic.hpp"

namespace equilibrium {
///  Lock to syncronize netcdf accross threads.
    static std::mutex sync;

//******************************************************************************
//  Equilibrium interface
//******************************************************************************
@@ -1108,13 +1111,10 @@ namespace equilibrium {
///  @brief Convenience function to build an EFIT equilibrium.
///
///  @params[in] spline_file File name of contains the spline functions.
///  @params[in,out] sync    Mutex to ensure the netcdf file is read only by one
///                          thread.
///  @returns A constructed EFIT equilibrium.
//------------------------------------------------------------------------------
    template<typename T, bool SAFE_MATH=false>
    shared<T, SAFE_MATH> make_efit(const std::string spline_file,
                                   std::mutex &sync) {
    shared<T, SAFE_MATH> make_efit(const std::string spline_file) {
        int ncid;
        sync.lock();
        nc_open(spline_file.c_str(), NC_NOWRITE, &ncid);
+18 −1
Original line number Diff line number Diff line
@@ -6,11 +6,16 @@
#ifndef output_h
#define output_h

#include <mutex>

#include <netcdf.h>

#include "jit.hpp"

namespace output {
///  Lock to syncronize netcdf accross threads.
    static std::mutex sync;

//------------------------------------------------------------------------------
///  @brief Class representing a netcdf based output file.
//------------------------------------------------------------------------------
@@ -50,6 +55,7 @@ namespace output {
        result_file(const std::string &filename="",
                    const size_t num_rays=0) : num_rays(num_rays) {
            
            sync.lock();
            nc_create(filename.c_str(),
                      filename.empty() || num_rays == 0 ? NC_DISKLESS : NC_CLOBBER,
                      &ncid);
@@ -63,6 +69,7 @@ namespace output {
                nc_def_dim(ncid, "ray_dim", 1, &ray_dim);
                nc_def_dim(ncid, "num_rays", num_rays*1, &num_rays_dim);
            }
            sync.unlock();
        }

//------------------------------------------------------------------------------
@@ -85,6 +92,7 @@ namespace output {
                             jit::context<T, SAFE_MATH> &context) {
            variable var;
            const std::array<int, 3> dims = {unlimited_dim, num_rays_dim, ray_dim};
            sync.lock();
            if constexpr (jit::is_float<T> ()) {
                nc_def_var(ncid, name.c_str(), NC_FLOAT, dims.size(),
                           dims.data(), &var.id);
@@ -92,6 +100,7 @@ namespace output {
                nc_def_var(ncid, name.c_str(), NC_DOUBLE, dims.size(),
                           dims.data(), &var.id);
            }
            sync.unlock();

            var.buffer = context.get_buffer(node);

@@ -102,7 +111,9 @@ namespace output {
///  @brief End define mode.
//------------------------------------------------------------------------------
        void end_define_mode() const {
            sync.lock();
            nc_enddef(ncid);
            sync.unlock();
        }

//------------------------------------------------------------------------------
@@ -110,9 +121,12 @@ namespace output {
//------------------------------------------------------------------------------
        void write() {
            size_t size;
            sync.lock();
            nc_inq_dimlen(ncid, unlimited_dim, &size);
            sync.unlock();
            const std::array<size_t, 3> start = {size, 0, 0};
            for (variable &var : variables) {
                sync.lock();
                if constexpr (jit::is_float<T> ()) {
                    if constexpr (jit::is_complex<T> ()) {
                        const std::array<size_t, 3> count = {1, num_rays, 2};
@@ -134,9 +148,12 @@ namespace output {
                                           var.buffer);
                    }
                }
                sync.unlock();
            }

            sync.lock();
            nc_sync(ncid);
            sync.unlock();
        }
    };
}
+1 −3
Original line number Diff line number Diff line
@@ -584,9 +584,7 @@ template<typename T> void test_efit() {
    z->set(static_cast<T> (0.0));
    t->set(static_cast<T> (0.0));
    
    std::mutex sync;
    
    auto eq = equilibrium::make_efit<T> (NC_FILE, sync);
    auto eq = equilibrium::make_efit<T> (NC_FILE);
    solver::rk4<dispersion::ordinary_wave<T>>
        solve(omega, kx, ky, kz, x, y, z, t, 0.0001, eq);
    solve.init(kx);