Commit 7bca5f2c authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Refactor: add helper functions to check whether a type is complex and to deduce the underlying base types.

parent 25b1ac85
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@
#define arithmetic_h

#include "node.hpp"
#include "register.hpp"

namespace graph {
//------------------------------------------------------------------------------
@@ -1545,8 +1546,7 @@ namespace graph {
                stream << "        const ";
                jit::add_type<typename LN::backend> (stream);
                stream << " " << registers[this] << " = ";
                if constexpr (std::is_same<std::complex<float>, typename LN::backend::base>::value ||
                              std::is_same<std::complex<double>, typename LN::backend::base>::value) {
                if constexpr (jit::is_complex<typename LN::backend::base> ()) {
                    stream << registers[l.get()] << "*"
                           << registers[m.get()] << " + "
                           << registers[r.get()] << ";"
+2 −2
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@
#include <complex>

#include "backend_protocall.hpp"
#include "register.hpp"

namespace backend {
//******************************************************************************
@@ -114,8 +115,7 @@ namespace backend {
///  @returns The maximum value.
//------------------------------------------------------------------------------
        virtual BASE max() const final {
            if constexpr (std::is_same<BASE, std::complex<float>>::value ||
                          std::is_same<BASE, std::complex<double>>::value) {
            if constexpr (jit::is_complex<BASE> ()) {
                return *std::max_element(buffer.cbegin(), buffer.cend(),
                                         [] (const BASE a, const BASE b) {
                    return std::abs(a) < std::abs(b);
+1 −3
Original line number Diff line number Diff line
@@ -144,8 +144,6 @@ namespace gpu {
                             graph::output_nodes<BACKEND> outputs,
                             const size_t num_rays,
                             const bool add_reduction=false) {
//            std::vector<const char *> headers_path({"/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/include/cuda/std/"});
//            std::vector<const char *> headers({"complex"});
            check_nvrtc_error(nvrtcCreateProgram(&kernel_program,
                                                 kernel_source.c_str(),
                                                 NULL, 0, NULL, NULL),
@@ -178,7 +176,7 @@ namespace gpu {
            check_error(cuDeviceGetName(device_name, 100, device), "cuDeviceGetName");
            std::cout << "  Device name              : " << device_name << std::endl;

//  FIXME: Hardcoded for ada gpus for now.
//  FIXME: Hardcoded for ada gpus for now. Also hardcoded for perlmutter.
            std::array<const char *, 3> options({
                "--gpu-architecture=compute_80",
                "--std=c++17",
+4 −7
Original line number Diff line number Diff line
@@ -149,10 +149,9 @@ namespace jit {
            source_buffer << "    const unsigned int k = threadIdx.x%32;" << std::endl;
#endif
            source_buffer << "    if (i < " << input->size() << ") {" << std::endl;
            source_buffer << "        float sub_max = input[i];" << std::endl;
            source_buffer << "        " << jit::type_to_string<typename BACKEND::base> () << " sub_max = input[i];" << std::endl;
            source_buffer << "        for (size_t index = i + 1024; index < " << input->size() << "; index += 1024) {" << std::endl;
            if constexpr (std::is_same<std::complex<float>, typename BACKEND::base>::value ||
                          std::is_same<std::complex<double>, typename BACKEND::base>:: value) {
            if constexpr (jit::is_complex<typename BACKEND::base> ()) {
                source_buffer << "            sub_max = max(abs(sub_max), abs(input[index]));" << std::endl;
            } else {
                source_buffer << "            sub_max = max(sub_max, input[index]);" << std::endl;
@@ -172,8 +171,7 @@ namespace jit {
            source_buffer << "        threadgroup_barrier(mem_flags::mem_threadgroup);" << std::endl;
#elif defined(USE_CUDA)
            source_buffer << "        for (int index = 16; index > 0; index /= 2) {" << std::endl;
            if constexpr (std::is_same<std::complex<float>, typename BACKEND::base>::value ||
                          std::is_same<std::complex<double>, typename BACKEND::base>:: value) {
            if constexpr (jit::is_complex<typename BACKEND::base> ()) {
                source_buffer << "            sub_max = max(abs(sub_max), abs(__shfl_down_sync(__activemask(), sub_max, index)));" << std::endl;
            } else {
                source_buffer << "            sub_max = max(sub_max, __shfl_down_sync(__activemask(), sub_max, index));" << std::endl;
@@ -188,8 +186,7 @@ namespace jit {
            source_buffer << "            *result = simd_max(thread_max[k]);"  << std::endl;
#elif defined(USE_CUDA)
            source_buffer << "            for (int index = 16; index > 0; index /= 2) {" << std::endl;
            if constexpr (std::is_same<std::complex<float>, typename BACKEND::base>::value ||
                          std::is_same<std::complex<double>, typename BACKEND::base>:: value) {
            if constexpr (jit::is_complex<typename BACKEND::base> ()) {
                source_buffer << "                thread_max[k] = max(abs(thread_max[k]), abs(__shfl_down_sync(__activemask(), thread_max[k], index)));" << std::endl;
            } else {
                source_buffer << "                thread_max[k] = max(thread_max[k], __shfl_down_sync(__activemask(), thread_max[k], index));" << std::endl;
+0 −1
Original line number Diff line number Diff line
@@ -99,7 +99,6 @@ namespace gpu {
                }

                const size_t buffer_element_size = sizeof(typename BACKEND::base);
                time_offset = 0;
                for (graph::shared_variable<BACKEND> &input : inputs) {
                    BACKEND buffer = input->evaluate();
                    buffers.push_back([device newBufferWithBytes:buffer.data()
Loading