Commit 35615635 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Enable multiple GPU support.

parent 1183ac78
Loading
Loading
Loading
Loading
+16 −21
Original line number Diff line number Diff line
@@ -11,8 +11,8 @@
#include "../graph_framework/solver.hpp"
#include "../graph_framework/timing.hpp"

const bool print = true;
const bool write_step = false;
const bool print = false;
const bool write_step = true;
const bool print_expressions = false;

//------------------------------------------------------------------------------
@@ -26,8 +26,8 @@ int main(int argc, const char * argv[]) {

    std::mutex sync;

    //typedef float base;
    typedef double base;
    typedef float base;
    //typedef double base;
    //typedef std::complex<float> base;
    //typedef std::complex<double> base;
    //constexpr bool use_safe_math = true;
@@ -38,16 +38,11 @@ int main(int argc, const char * argv[]) {
    const size_t num_times = 100000;
    const size_t sub_steps = 10;
    const size_t num_steps = num_times/sub_steps;
    const size_t num_rays = 1;
    const size_t num_rays = 100000;

    std::vector<std::thread> threads(0);
    if constexpr (jit::use_gpu<base> ()) {
        threads.resize(1);
    } else {
        threads.resize(std::max(std::min(std::thread::hardware_concurrency(),
    std::vector<std::thread> threads(std::max(std::min(static_cast<unsigned int> (jit::context<base, use_safe_math>::max_concurrency),
                                                       static_cast<unsigned int> (num_rays)),
                                              static_cast<unsigned int> (1)));
    }

    for (size_t i = 0, ie = threads.size(); i < ie; i++) {
        threads[i] = std::thread([num_times, num_rays, &sync] (const size_t thread_number,
+5 −1
Original line number Diff line number Diff line
@@ -1047,6 +1047,7 @@
				GCC_OPTIMIZATION_LEVEL = 0;
				GCC_PREPROCESSOR_DEFINITIONS = (
					USE_METAL,
					"CXX_FLAGS=\\\"-g\\ -fsanitize=undefined\\ -fsanitize=float-divide-by-zero\\ -fsanitize-trap=all\\\"",
					"CXX=\\\"c++\\ -I/Users/m4c/Projects/graph_framework/graph_framework\\ -std=gnu++2a\\\"",
					"DEBUG=1",
					"$(inherited)",
@@ -1154,7 +1155,10 @@
				EXECUTABLE_PREFIX = lib;
				GCC_PREPROCESSOR_DEFINITIONS = (
					USE_METAL,
					"CXX=\\\"c++\\ -I/Users/m4c/Projects/graph_framework/graph_framework\\ -std=gnu++2a\\ -g\" DEBUG=1 $(inherited)",
					"CXX_FLAGS=\\\"-g\\ -fsanitize=undefined\\ -fsanitize=float-divide-by-zero\\ -fsanitize-trap=all\\\"",
					"CXX=\\\"c++\\ -I/Users/m4c/Projects/graph_framework/graph_framework\\ -std=gnu++2a\\\"",
					"DEBUG=1",
					"$(inherited)",
				);
				MACOSX_DEPLOYMENT_TARGET = 13.3;
				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
+13 −1
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@
#include <fstream>
#include <cstdlib>
#include <cstring>
#include <thread>

#include <dlfcn.h>

@@ -34,9 +35,20 @@ namespace gpu {

    public:
//------------------------------------------------------------------------------
///  @brief Get the maximum number of concurrent instances.
///
///  @returns The maximum available concurrency.
//------------------------------------------------------------------------------
        static size_t max_concurrency() {
            // hardware_concurrency() is only a hint and is allowed to return 0
            // when the value is not well defined or not computable. Report at
            // least one instance so callers never see a zero concurrency.
            const size_t hardware_threads = std::thread::hardware_concurrency();
            return hardware_threads > 0 ? hardware_threads : 1;
        }

//------------------------------------------------------------------------------
///  @brief Construct a cpu context.
///
///  @params[in] index Concurrent index. Not used.
//------------------------------------------------------------------------------
        cpu_context() {}
        cpu_context(const size_t index) {}

//------------------------------------------------------------------------------
///  @brief Destruct a cpu context.
+14 −1
Original line number Diff line number Diff line
@@ -85,10 +85,23 @@ namespace gpu {

    public:
//------------------------------------------------------------------------------
///  @brief Get the maximum number of concurrent instances.
///
///  @returns The maximum available concurrency.
//------------------------------------------------------------------------------
        static size_t max_concurrency() {
            // Start from zero so a defined value is returned even if
            // check_error does not abort on a failed query.
            int count = 0;
            check_error(cuDeviceGetCount(&count), "cuDeviceGetCount");
            // Make the int to size_t conversion explicit.
            return static_cast<size_t> (count);
        }

//------------------------------------------------------------------------------
///  @brief Cuda context constructor.
///
///  @params[in] index Concurrent index.
//------------------------------------------------------------------------------
        cuda_context() : result_buffer(0), module(0) {
            check_error(cuDeviceGet(&device, 0), "cuDeviceGet");
            check_error(cuDeviceGet(&device, index), "cuDeviceGet");
            check_error(cuDevicePrimaryCtxRetain(&context, device), "cuDevicePrimaryCtxRetain");
            check_error(cuCtxSetCurrent(context), "cuCtxSetCurrent");
            check_error(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "cuStreamCreate");
+3 −1
Original line number Diff line number Diff line
@@ -135,6 +135,7 @@ namespace dispersion {
///
///  @params[in,out] x              The unknown to solve for.
///  @params[in]     inputs         Inputs for jit compile.
///  @params[in]     index          Concurrent index.
///  @params[in]     tolarance      Tolerance to solve the dispersion function
///                                 to.
///  @params[in]     max_iterations Maximum number of iterations before giving
@@ -147,13 +148,14 @@ namespace dispersion {
                                 DISPERSION_FUNCTION::safe_math> x,
              graph::input_nodes<typename DISPERSION_FUNCTION::base,
                                 DISPERSION_FUNCTION::safe_math> inputs,
              const size_t index=0,
              const typename DISPERSION_FUNCTION::base tolarance = 1.0E-30,
              const size_t max_iterations = 1000) {
            auto loss = D*D;
            auto x_var = graph::variable_cast(x);

            workflow::manager<typename DISPERSION_FUNCTION::base,
                              DISPERSION_FUNCTION::safe_math> work;
                              DISPERSION_FUNCTION::safe_math> work(index);

            solver::newton(work, {x}, inputs, loss, tolarance, max_iterations);

Loading