Commit 25b1ac85 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add cuda support for float and double. Complex support is WIP.

parent 3e08ad13
Loading
Loading
Loading
Loading
+32 −24
Original line number Diff line number Diff line
@@ -60,7 +60,6 @@ namespace graph {
        return is_variable_like(a) &&
               is_variable_like(b) &&
               get_argument(a)->is_match(get_argument(b));
        
    }

//******************************************************************************
@@ -289,7 +288,7 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<LN> (stream);
                jit::add_type<typename LN::backend> (stream);
                stream << " " << registers[this] << " = "
                       << registers[l.get()] << " + "
                       << registers[r.get()] << ";"
@@ -600,7 +599,7 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<LN> (stream);
                jit::add_type<typename LN::backend> (stream);
                stream << " " << registers[this] << " = "
                       << registers[l.get()] << " - "
                       << registers[r.get()] << ";"
@@ -976,7 +975,7 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<LN> (stream);
                jit::add_type<typename LN::backend> (stream);
                stream << " " << registers[this] << " = "
                       << registers[l.get()] << "*"
                       << registers[r.get()] << ";"
@@ -1293,7 +1292,7 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<LN> (stream);
                jit::add_type<typename LN::backend> (stream);
                //std::cout << ((registers.find(r.get()) == registers.end()) ? "True" : registers[r.get()])
                //          << std::endl;
                stream << " " << registers[this] << " = "
@@ -1544,13 +1543,22 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<LN> (stream);
                stream << " " << registers[this] << " = fma("
                jit::add_type<typename LN::backend> (stream);
                stream << " " << registers[this] << " = ";
                if constexpr (std::is_same<std::complex<float>, typename LN::backend::base>::value ||
                              std::is_same<std::complex<double>, typename LN::backend::base>::value) {
                    stream << registers[l.get()] << "*"
                           << registers[m.get()] << " + "
                           << registers[r.get()] << ";"
                           << std::endl;
                } else {
                    stream << "fma("
                           << registers[l.get()] << ", "
                           << registers[m.get()] << ", "
                           << registers[r.get()] << ");"
                           << std::endl;
                }
            }

            return this->shared_from_this();
        }
+132 −42
Original line number Diff line number Diff line
@@ -31,29 +31,42 @@ namespace gpu {
        CUcontext context;
///  The cuda code library.
        CUmodule module;
///  The cuda kernel;
///  The cuda kernel.
        CUfunction function;
///  The cuda max reduction kernel.
        CUfunction max_function;
///  The cuda library.
        nvrtcProgram kernel_program;
///  Buffer objects.
        std::vector<CUdeviceptr> buffers;
///  Result buffer.
        CUdeviceptr result_buffer;
///  Cuda stream.
        CUstream stream;
///  Number of thread groups.
        unsigned int thread_groups;
///  Number of threads in a group.
        unsigned int threads_per_group;
///  Index offset.
        size_t buffer_offset;
///  Buffer element size.
        size_t buffer_element_size;
///  Time offset.
        size_t time_offset;
///  Result buffer size;
        size_t result_size;
///  Kernel arguments.
        std::vector<void *> kernel_arguments;
///  Max kernel arguments.
        std::vector<void *> max_kernel_arguments;

//------------------------------------------------------------------------------
///  @brief  Check results of realtime compile.
///  @param[in] name   Name of the operation.
//------------------------------------------------------------------------------
        void check_nvrtc_error(nvrtcResult result,
                               const std::string &name) {
#ifndef NDEBUG
            std::cout << name << " " << result << " "
                      << nvrtcGetErrorString(result) << std::endl;
            assert(result == NVRTC_SUCCESS && "NVTRC Error");
#endif
        }

//------------------------------------------------------------------------------
///  @brief  Check Results of cuda functions.
///  @brief  Check results of cuda functions.
///
///  @param[in] result Result code of the operation.
///  @param[in] name   Name of the operation.
@@ -65,6 +78,7 @@ namespace gpu {
            cuGetErrorString(result, &error);
            std::cout << name << " "
                      << result << " " << error << std::endl;
            assert(result == CUDA_SUCCESS && "Cuda Error");
#endif
        }

@@ -87,7 +101,7 @@ namespace gpu {
//------------------------------------------------------------------------------
///  @brief Cuda context constructor.
//------------------------------------------------------------------------------
        cuda_context() {
        cuda_context() : result_buffer(0) {
            check_error(cuDeviceGet(&device, 0), "cuDeviceGet");
            check_error(cuDevicePrimaryCtxRetain(&context, device), "cuDevicePrimaryCtxRetain");
            check_error(cuCtxSetCurrent(context), "cuCtxSetCurrent");
@@ -104,6 +118,10 @@ namespace gpu {
                check_error(cuMemFree(ptr), "cuMemFree");
            }

            check_nvrtc_error(nvrtcDestroyProgram(&kernel_program),
                              "nvrtcDestroyProgram");
            check_error(cuMemFree(result_buffer), "cuMemFree");

            check_error(cuStreamDestroy(stream), "cuStreamDestroy");
            check_error(cuDevicePrimaryCtxRelease(device), "cuDevicePrimaryCtxRelease");
        }
@@ -116,8 +134,8 @@ namespace gpu {
///  @param[in] inputs        Input nodes of the kernel.
///  @param[in] outputs       Output nodes of the kernel.
///  @param[in] num_rays      Number of rays to trace.
///  @param[in] num_times     Number of times to record.
///  @param[in] ray_index     Index of the ray to save.
///  @param[in] add_reduction Optional argument to generate the reduction
///                           kernel.
//------------------------------------------------------------------------------
        template<class BACKEND>
        void create_pipeline(const std::string kernel_source,
@@ -125,14 +143,23 @@ namespace gpu {
                             graph::input_nodes<BACKEND> inputs,
                             graph::output_nodes<BACKEND> outputs,
                             const size_t num_rays,
                             const size_t num_times,
                             const size_t ray_index) {
            nvrtcProgram kernel_program;
            nvrtcCreateProgram(&kernel_program,
                             const bool add_reduction=false) {
//            std::vector<const char *> headers_path({"/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/include/cuda/std/"});
//            std::vector<const char *> headers({"complex"});
            check_nvrtc_error(nvrtcCreateProgram(&kernel_program,
                                                 kernel_source.c_str(),
                               NULL, 0, NULL, NULL);
                                                 NULL, 0, NULL, NULL),
                              "nvrtcCreateProgram");

            check_nvrtc_error(nvrtcAddNameExpression(kernel_program,
                                                     kernel_name.c_str()),
                              "nvrtcAddNameExpression");

            nvrtcAddNameExpression(kernel_program, kernel_name.c_str());
            if (add_reduction) {
                check_nvrtc_error(nvrtcAddNameExpression(kernel_program,
                                                         "max_reduction"),
                                  "nvrtcAddNameExpression");
            }

            int compute_version;
            check_error(cuDeviceGetAttribute(&compute_version,
@@ -152,25 +179,29 @@ namespace gpu {
            std::cout << "  Device name              : " << device_name << std::endl;

//  FIXME: Hardcoded for ada gpus for now.
            std::array<const char *, 2> options({
            std::array<const char *, 3> options({
                "--gpu-architecture=compute_80",
                "--std=c++17"
                "--std=c++17",
                "--include-path=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/include"
            });

            if (nvrtcCompileProgram(kernel_program, 2, options.data())) {
            if (nvrtcCompileProgram(kernel_program, options.size(), options.data())) {
                size_t log_size;
                nvrtcGetProgramLogSize(kernel_program, &log_size);
                check_nvrtc_error(nvrtcGetProgramLogSize(kernel_program, &log_size),
                                  "nvrtcGetProgramLogSize");

                char *log = static_cast<char *> (malloc(log_size));
                nvrtcGetProgramLog(kernel_program, log);
                check_nvrtc_error(nvrtcGetProgramLog(kernel_program, log),
                                  "nvrtcGetProgramLog");
                std::cout << log << std::endl;
                free(log);
            }

            const char *mangled_kernel_name;
            nvrtcGetLoweredName(kernel_program,
            check_nvrtc_error(nvrtcGetLoweredName(kernel_program,
                                                  kernel_name.c_str(),
                                &mangled_kernel_name);
                                                  &mangled_kernel_name),
                              "nvrtcGetLoweredName");

            std::cout << "  Mangled Kernel Name      : " << mangled_kernel_name << std::endl;

@@ -180,10 +211,11 @@ namespace gpu {
            std::cout << "  Managed Memory           : " << compute_version << std::endl;

            size_t ptx_size;
            nvrtcGetPTXSize(kernel_program, &ptx_size);
            check_nvrtc_error(nvrtcGetPTXSize(kernel_program, &ptx_size),
                              "nvrtcGetPTXSize");

            char *ptx = static_cast<char *> (malloc(ptx_size));
            nvrtcGetPTX(kernel_program, ptx);
            check_nvrtc_error(nvrtcGetPTX(kernel_program, ptx), "nvrtcGetPTX");

            check_error(cuModuleLoadDataEx(&module, ptx, 0, NULL, NULL), "cuModuleLoadDataEx");
            check_error(cuModuleGetFunction(&function, module, mangled_kernel_name), "cuModuleGetFunction");
@@ -192,22 +224,25 @@ namespace gpu {

            buffers.resize(inputs.size() + outputs.size());

            buffer_element_size = sizeof(typename BACKEND::base);
            buffer_offset = ray_index;
            time_offset = 0;
            result_size = num_times*buffer_element_size;
            const size_t buffer_element_size = sizeof(typename BACKEND::base);
            for (size_t i = 0, ie = inputs.size(); i < ie; i++) {
                const BACKEND backend = inputs[i]->evaluate();

                check_error(cuMemAllocManaged(&buffers[i], backend.size()*buffer_element_size, CU_MEM_ATTACH_GLOBAL), "cuMemAllocManaged");
                check_error(cuMemcpyHtoD(buffers[i], &backend[0], backend.size()*buffer_element_size), "cuMemcpyHtoD");
                check_error(cuMemAllocManaged(&buffers[i], backend.size()*buffer_element_size,
                                              CU_MEM_ATTACH_GLOBAL),
                            "cuMemAllocManaged");
                check_error(cuMemcpyHtoD(buffers[i], &backend[0], backend.size()*buffer_element_size), 
                            "cuMemcpyHtoD");
                kernel_arguments.push_back(reinterpret_cast<void *> (&buffers[i]));
            }
            for (size_t i = inputs.size(), ie = buffers.size(), j = 0; i < ie; i++, j++) {
                const BACKEND backend = outputs[j]->evaluate();

                check_error(cuMemAllocManaged(&buffers[i], backend.size()*buffer_element_size, CU_MEM_ATTACH_GLOBAL), "cuMemAllocManaged");
                check_error(cuMemcpyHtoD(buffers[i], &backend[0], backend.size()*buffer_element_size), "cuMemcpyHtoD");
                check_error(cuMemAllocManaged(&buffers[i], backend.size()*buffer_element_size,
                                              CU_MEM_ATTACH_GLOBAL), 
                            "cuMemAllocManaged");
                check_error(cuMemcpyHtoD(buffers[i], &backend[0], backend.size()*buffer_element_size), 
                                         "cuMemcpyHtoD");
                kernel_arguments.push_back(reinterpret_cast<void *> (&buffers[i]));
            }

@@ -221,19 +256,60 @@ namespace gpu {
            std::cout << "  Total problem size       : " << threads_per_group*thread_groups << std::endl;
        }

//------------------------------------------------------------------------------
///  @brief Create a max compute pipeline.
//------------------------------------------------------------------------------
        template<class BACKEND>
        void create_max_pipeline() {
            const char *mangled_kernel_name;
            check_nvrtc_error(nvrtcGetLoweredName(kernel_program,
                                                  "max_reduction",
                                                  &mangled_kernel_name),
                              "nvrtcGetLoweredName");

            std::cout << "  Mangled Kernel Name      : " << mangled_kernel_name << std::endl;

            check_error(cuMemAllocManaged(&result_buffer, sizeof(typename BACKEND::base),
                                          CU_MEM_ATTACH_GLOBAL),
                        "cuMemAllocManaged");

            max_kernel_arguments.push_back(reinterpret_cast<void *> (&buffers.back()));
            max_kernel_arguments.push_back(reinterpret_cast<void *> (&result_buffer));

            check_error(cuModuleGetFunction(&max_function, module, mangled_kernel_name),
                        "cuModuleGetFunction");
        }

//------------------------------------------------------------------------------
///  @brief Perform a time step.
///
///  This calls dispatches a kernel instance to the command buffer and the
///  commits the job. This method is asynchronous.
//------------------------------------------------------------------------------
        void step() {
        void run() {
            check_error_async(cuLaunchKernel(function, thread_groups, 1, 1,
                                             threads_per_group, 1, 1, 0, stream,
                                             kernel_arguments.data(), NULL),
                              "cuLaunchKernel");
        }

//------------------------------------------------------------------------------
///  @brief Compute the max reduction.
///
///  @returns The maximum value from the input buffer.
//------------------------------------------------------------------------------
        template<class BACKEND>
        typename BACKEND::base max_reduction() {
            run();
            check_error_async(cuLaunchKernel(max_function, 1, 1, 1,
                                             threads_per_group, 1, 1, 0, stream,
                                             max_kernel_arguments.data(), NULL),
                              "cuLaunchKernel");
            wait();

            return reinterpret_cast<typename BACKEND::base *> (result_buffer)[0];
        }

//------------------------------------------------------------------------------
///  @brief Hold the current thread until the stream has completed.
//------------------------------------------------------------------------------
@@ -255,6 +331,20 @@ namespace gpu {
            }
            std::cout << std::endl;
        }

//------------------------------------------------------------------------------
///  @brief Copy buffer contents.
///
///  @param[in]     source_index Index of the GPU buffer.
///  @param[in,out] destination  Host side buffer to copy to.
//------------------------------------------------------------------------------
        template<typename BASE>
        void copy_buffer(const size_t source_index,
			 BASE *destination) {
	    size_t size;
	    check_error(cuMemGetAddressRange(NULL, &size, buffers[source_index]), "cuMemGetAddressRange");
            check_error_async(cuMemcpyDtoHAsync(destination, buffers[source_index], size, stream), "cuMemcpyDtoHAsync");
        }
    };
}

+4 −4
Original line number Diff line number Diff line
@@ -154,7 +154,7 @@ namespace dispersion {
                                                                                               setters);
                source->add_max_reduction(x_var);

                source->compile("loss_kernel", inputs, outputs, x_var->size());
                source->compile("loss_kernel", inputs, outputs, x_var->size(), true);
                source->compile_max();

                max_residule = source->max_reduction();
+67 −22
Original line number Diff line number Diff line
@@ -142,24 +142,64 @@ namespace jit {
            source_buffer << "    uint i [[thread_position_in_grid]]," << std::endl;
            source_buffer << "    uint j [[simdgroup_index_in_threadgroup]]," << std::endl;
            source_buffer << "    uint k [[thread_index_in_simdgroup]]) {" << std::endl;
            
#elif defined(USE_CUDA)
            source_buffer << ") {" << std::endl;
            source_buffer << "    const unsigned int i = threadIdx.x;" << std::endl;
            source_buffer << "    const unsigned int j = threadIdx.x/32;" << std::endl;
            source_buffer << "    const unsigned int k = threadIdx.x%32;" << std::endl;
#endif
            source_buffer << "    if (i < " << input->size() << ") {" << std::endl;
            source_buffer << "        float sub_max = input[i];" << std::endl;
            source_buffer << "        for (size_t index = i + 1024; index < " << input->size() << "; index += 1024) {" << std::endl;
            if constexpr (std::is_same<std::complex<float>, typename BACKEND::base>::value ||
                          std::is_same<std::complex<double>, typename BACKEND::base>:: value) {
                source_buffer << "            sub_max = max(abs(sub_max), abs(input[index]));" << std::endl;
            } else {
                source_buffer << "            sub_max = max(sub_max, input[index]);" << std::endl;
            }
            source_buffer << "        }" << std::endl;

            source_buffer << "        threadgroup float thread_max[32];" << std::endl;
#ifdef USE_METAL
            source_buffer << "        threadgroup ";
#elif defined(USE_CUDA)
            source_buffer << "        __shared__ ";
#endif
            add_type<BACKEND> (source_buffer);
            source_buffer << " thread_max[32];" << std::endl;
#ifdef USE_METAL
            source_buffer << "        thread_max[j] = simd_max(sub_max);" << std::endl;

            source_buffer << "        threadgroup_barrier(mem_flags::mem_threadgroup);" << std::endl;
#elif defined(USE_CUDA)
            source_buffer << "        for (int index = 16; index > 0; index /= 2) {" << std::endl;
            if constexpr (std::is_same<std::complex<float>, typename BACKEND::base>::value ||
                          std::is_same<std::complex<double>, typename BACKEND::base>:: value) {
                source_buffer << "            sub_max = max(abs(sub_max), abs(__shfl_down_sync(__activemask(), sub_max, index)));" << std::endl;
            } else {
                source_buffer << "            sub_max = max(sub_max, __shfl_down_sync(__activemask(), sub_max, index));" << std::endl;
            }
            source_buffer << "        }" << std::endl;
            source_buffer << "        thread_max[j] = sub_max;" << std::endl;

            source_buffer << "        __syncthreads();" << std::endl;
#endif
            source_buffer << "        if (j == 0) {"  << std::endl;
#ifdef USE_METAL
            source_buffer << "            *result = simd_max(thread_max[k]);"  << std::endl;
#elif defined(USE_CUDA)
            source_buffer << "            for (int index = 16; index > 0; index /= 2) {" << std::endl;
            if constexpr (std::is_same<std::complex<float>, typename BACKEND::base>::value ||
                          std::is_same<std::complex<double>, typename BACKEND::base>:: value) {
                source_buffer << "                thread_max[k] = max(abs(thread_max[k]), abs(__shfl_down_sync(__activemask(), thread_max[k], index)));" << std::endl;
            } else {
                source_buffer << "                thread_max[k] = max(thread_max[k], __shfl_down_sync(__activemask(), thread_max[k], index));" << std::endl;
            }
            source_buffer << "            }" << std::endl;
            source_buffer << "            *result = thread_max[0];" << std::endl;
#endif
            source_buffer << "        }"  << std::endl;
            source_buffer << "    }"  << std::endl;
            source_buffer << "}" << std::endl << std::endl;
#endif
        }

//------------------------------------------------------------------------------
@@ -184,6 +224,8 @@ namespace jit {
            source_buffer << "using namespace metal;" << std::endl
                          << "kernel ";
#else
            source_buffer << "#include <cuda/std/complex>" << std::endl;
//            source_buffer << "#include <complex>" << std::endl;
            source_buffer << "extern \"C\" __global__ ";
#endif
            source_buffer << "void " << name << "(";
@@ -217,12 +259,11 @@ namespace jit {
#ifdef USE_METAL
            source_buffer << "device ";
#endif
            add_type<graph::leaf_node<BACKEND>> (source_buffer);
            add_type<BACKEND> (source_buffer);
            source_buffer << " *" << name;
#ifdef USE_METAL
            source_buffer << " [[buffer("<< index <<")]]";
#endif
            
        }

//------------------------------------------------------------------------------
@@ -255,7 +296,7 @@ namespace jit {
        void load_variable(graph::variable_node<BACKEND> *pointer) {
            registers[pointer] = to_string('r', pointer);
            source_buffer << "        const ";
            add_type<graph::leaf_node<BACKEND>> (source_buffer);
            add_type<BACKEND> (source_buffer);
            source_buffer << " " << registers[pointer] << " = "
                          << to_string('v', pointer) << "[index];"
                          << std::endl;
@@ -299,14 +340,18 @@ namespace jit {
///  @param[in] inputs        Input variables of the kernel.
///  @param[in] outputs       Output nodes to calculate results of.
///  @param[in] num_rays      Number of rays.
///  @param[in] add_reduction Optional argument to generate the reduction
///                           kernel.
//------------------------------------------------------------------------------
        void compile(const std::string name,
                     graph::input_nodes<BACKEND> inputs,
                     graph::output_nodes<BACKEND> outputs,
                     const size_t num_rays) {
                     const size_t num_rays,
                     const bool add_reduction=false) {
#ifdef USE_GPU
            context.create_pipeline(source_buffer.str(), name,
                                    inputs, outputs, num_rays);
                                    inputs, outputs, num_rays, 
                                    add_reduction);
#endif
        }

+12 −12
Original line number Diff line number Diff line
@@ -130,7 +130,7 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<N> (stream);
                jit::add_type<typename N::backend> (stream);
                stream << " " << registers[this] << " = sqrt("
                       << registers[a.get()] << ");"
                       << std::endl;
@@ -270,7 +270,7 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<N> (stream);
                jit::add_type<typename N::backend> (stream);
                stream << " " << registers[this] << " = exp("
                       << registers[a.get()] << ");"
                       << std::endl;
@@ -406,7 +406,7 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<N> (stream);
                jit::add_type<typename N::backend> (stream);
                stream << " " << registers[this] << " = log("
                       << registers[a.get()] << ");"
                       << std::endl;
@@ -605,7 +605,7 @@ namespace graph {
            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<LN> (stream);
                jit::add_type<typename LN::backend> (stream);
                stream << " " << registers[this] << " = pow("
                       << registers[l.get()] << ", "
                       << registers[r.get()] << ");"
Loading