Commit 042f98bb authored by cianciosa's avatar cianciosa Committed by Cianciosa, Mark
Browse files

Initial support for complex metal kernels.

parent 513bea6a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ target_compile_definitions (graph_framework
                            $<$<BOOL:${SHOW_USE_COUNT}>:SHOW_USE_COUNT>
                            $<$<BOOL:${USE_INDEX_CACHE}>:USE_INDEX_CACHE>
                            $<IF:$<BOOL:${USE_VERBOSE}>,USE_VERBOSE=true,USE_VERBOSE=false>
                            $<$<BOOL:${USE_METAL}>:HEADER_DIR="$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>">
)

target_include_directories (graph_framework
+4 −2
Original line number Diff line number Diff line
@@ -972,11 +972,13 @@ namespace gpu {
                source_buffer << "input[i];" << std::endl;
            }
            source_buffer << "        for (size_t index = i + 1024; index < " << size <<"; index += 1024) {" << std::endl;
            source_buffer << "            sub_max = max(sub_max, ";
            if constexpr (jit::complex_scalar<T>) {
                source_buffer << "            sub_max = max(sub_max, abs(input[index]));" << std::endl;
                source_buffer << "abs(input[index]";
            } else {
                source_buffer << "            sub_max = max(sub_max, input[index]);" << std::endl;
                source_buffer << "input[index]";
            }
            source_buffer << ");" << std::endl;
            source_buffer << "        }" << std::endl;
            source_buffer << "        __shared__ " << jit::type_to_string<T> () << " thread_max[32];" << std::endl;
            source_buffer << "        for (int index = 16; index > 0; index /= 2) {" << std::endl;
+1 −1
Original line number Diff line number Diff line
@@ -64,7 +64,7 @@ namespace jit {
#ifdef USE_CUDA
                                                           gpu::cuda_context<T, SAFE_MATH>,
#elif defined(USE_METAL)
                                                           gpu::metal_context<SAFE_MATH>,
                                                           gpu::metal_context<T, SAFE_MATH>,
#else
                                                           gpu::cpu_context<T, SAFE_MATH>,
#endif
+74 −24
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@
#define metal_context_h

#include <unordered_set>
#include <stdlib.h>

#import <Metal/Metal.h>

@@ -19,9 +20,10 @@ namespace gpu {
//------------------------------------------------------------------------------
///  @brief Class representing a metal gpu context.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use @ref general_concepts_safe_math operations.
//------------------------------------------------------------------------------
    template<bool SAFE_MATH=false>
    template<jit::float_scalar T, bool SAFE_MATH=false>
    class metal_context {
    private:
///  The metal device.
@@ -82,6 +84,7 @@ namespace gpu {
                     std::vector<std::string> names,
                     const bool add_reduction=false) {
            NSError *error;
            setenv("MTL_HEADER_SEARCH_PATHS", HEADER_DIR, 1);
            library = [device newLibraryWithSource:[NSString stringWithCString:kernel_source.c_str()
                                                                      encoding:NSUTF8StringEncoding]
                                           options:compile_options()
@@ -141,9 +144,9 @@ namespace gpu {
            std::set<graph::leaf_node<float, SAFE_MATH> *> needed_buffers;

            const size_t buffer_element_size = sizeof(float);
            for (graph::shared_variable<float, SAFE_MATH> &input : inputs) {
            for (graph::shared_variable<T, SAFE_MATH> &input : inputs) {
                if (!kernel_arguments.contains(input.get())) {
                    backend::buffer<float> buffer = input->evaluate();
                    backend::buffer<T> buffer = input->evaluate();
                    kernel_arguments[input.get()] = [device newBufferWithBytes:buffer.data()
                                                                        length:buffer.size()*buffer_element_size
                                                                       options:MTLResourceStorageModeShared];
@@ -155,7 +158,7 @@ namespace gpu {
                    needed_buffers.insert(input.get());
                }
            }
            for (graph::shared_leaf<float, SAFE_MATH> &output : outputs) {
            for (graph::shared_leaf<T, SAFE_MATH> &output : outputs) {
                if (!kernel_arguments.contains(output.get())) {
                    kernel_arguments[output.get()] = [device newBufferWithLength:num_rays*sizeof(float)
                                                                         options:MTLResourceStorageModeShared];
@@ -296,7 +299,7 @@ namespace gpu {
///  @param[in] run      Function to run before reduction.
///  @returns A lambda function to run the kernel.
//------------------------------------------------------------------------------
        std::function<float(void)> create_max_call(graph::shared_leaf<float, SAFE_MATH> &argument,
        std::function<T(void)> create_max_call(graph::shared_leaf<T, SAFE_MATH> &argument,
                                               std::function<void(void)> run) {
            MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new];
            compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
@@ -375,7 +378,7 @@ namespace gpu {
///  @param[in] nodes Nodes to output.
//------------------------------------------------------------------------------
        void print_results(const size_t index,
                           const graph::output_nodes<float, SAFE_MATH> &nodes) {
                           const graph::output_nodes<T, SAFE_MATH> &nodes) {
            wait();
            for (auto &out : nodes) {
                std::cout << static_cast<float *> ([kernel_arguments[out.get()] contents])[index] << " ";
@@ -390,10 +393,10 @@ namespace gpu {
///  @param[in] node  Node to check the value for.
///  @returns The value at the index.
//------------------------------------------------------------------------------
        float check_value(const size_t index,
                          const graph::shared_leaf<float, SAFE_MATH> &node) {
        T check_value(const size_t index,
                      const graph::shared_leaf<T, SAFE_MATH> &node) {
            wait();
            return static_cast<float *> ([kernel_arguments[node.get()] contents])[index];
            return static_cast<T *> ([kernel_arguments[node.get()] contents])[index];
        }

//------------------------------------------------------------------------------
@@ -402,8 +405,8 @@ namespace gpu {
///  @param[in] node   Node to copy buffer to.
///  @param[in] source Host side buffer to copy from.
//------------------------------------------------------------------------------
        void copy_to_device(graph::shared_leaf<float, SAFE_MATH> node,
                            float *source) {
        void copy_to_device(graph::shared_leaf<T, SAFE_MATH> node,
                            T *source) {
            const size_t size = [kernel_arguments[node.get()] length];
            memcpy([kernel_arguments[node.get()] contents],
                   source, size);
@@ -415,8 +418,8 @@ namespace gpu {
///  @param[in]     node        Node to copy buffer from.
///  @param[in,out] destination Host side buffer to copy to.
//------------------------------------------------------------------------------
        void copy_to_host(graph::shared_leaf<float, SAFE_MATH> node,
                          float *destination) {
        void copy_to_host(graph::shared_leaf<T, SAFE_MATH> node,
                          T *destination) {
            command_buffer = [queue commandBuffer];

            [command_buffer commit];
@@ -436,6 +439,10 @@ namespace gpu {
            source_buffer << "#include <metal_stdlib>" << std::endl;
            source_buffer << "#include <metal_simdgroup>" << std::endl;
            source_buffer << "using namespace metal;" << std::endl;
            if constexpr (jit::complex_scalar<T>) {
                source_buffer << "#define METAL_DEVICE_CODE" << std::endl;
                source_buffer << "#include <special_functions.hpp>" << std::endl;
            }
        }

//------------------------------------------------------------------------------
@@ -595,8 +602,22 @@ namespace gpu {
                                  << jit::to_string('v',  in.get())
                                  << "[index] = ";
                    if constexpr (SAFE_MATH) {
                        if constexpr (jit::complex_scalar<T>) {
                            jit::add_type<T> (source_buffer);
                            source_buffer << " (";
                            source_buffer << "isnan(real(" << registers[a.get()]
                                          << ")) ? 0.0 : real(" << registers[a.get()]
                                          << "), ";
                            source_buffer << "isnan(imag(" << registers[a.get()]
                                          << ")) ? 0.0 : imag(" << registers[a.get()]
                                          << "));" << std::endl;
                        } else {
                            source_buffer << "isnan(" << registers[a.get()]
                                      << ") ? 0.0 : ";
                                          << ") ? 0.0 : " << registers[a.get()]
                                          << ";" << std::endl;
                        }
                    } else {
                        source_buffer << registers[a.get()] << ";" << std::endl;
                    }
                    source_buffer << registers[a.get()] << ";" << std::endl;
                    out_registers.insert(out.get());
@@ -613,8 +634,22 @@ namespace gpu {
                    source_buffer << "        " << jit::to_string('o',  out.get())
                                  << "[index] = ";
                    if constexpr (SAFE_MATH) {
                        if constexpr (jit::complex_scalar<T>) {
                            jit::add_type<T> (source_buffer);
                            source_buffer << " (";
                            source_buffer << "isnan(real(" << registers[a.get()]
                                          << ")) ? 0.0 : real(" << registers[a.get()]
                                          << "), ";
                            source_buffer << "isnan(imag(" << registers[a.get()]
                                          << ")) ? 0.0 : imag(" << registers[a.get()]
                                          << "));" << std::endl;
                        } else {
                            source_buffer << "isnan(" << registers[a.get()]
                                      << ") ? 0.0 : ";
                                          << ") ? 0.0 : " << registers[a.get()]
                                          << ";" << std::endl;
                        }
                    } else {
                        source_buffer << registers[a.get()] << ";" << std::endl;
                    }
                    source_buffer << registers[a.get()] << ";" << std::endl;
                    out_registers.insert(out.get());
@@ -634,15 +669,30 @@ namespace gpu {
                              const size_t size) {
            source_buffer << std::endl;
            source_buffer << "kernel void max_reduction(" << std::endl;
            source_buffer << "    constant float *input [[buffer(0)]]," << std::endl;
            source_buffer << "    device float *result [[buffer(1)]]," << std::endl;
            source_buffer << "    constant ";
            jit::add_type<T> (source_buffer);
            source_buffer << " *input [[buffer(0)]]," << std::endl;
            source_buffer << "    device ";
            jit::add_type<T> (source_buffer);
            source_buffer << " *result [[buffer(1)]]," << std::endl;
            source_buffer << "    uint i [[thread_position_in_grid]]," << std::endl;
            source_buffer << "    uint j [[simdgroup_index_in_threadgroup]]," << std::endl;
            source_buffer << "    uint k [[thread_index_in_simdgroup]]) {" << std::endl;
            source_buffer << "    if (i < " << size << ") {" << std::endl;
            source_buffer << "        float sub_max = input[i];" << std::endl;
            source_buffer << "        " << jit::type_to_string<T> () << " sub_max = ";
            if constexpr (jit::complex_scalar<T>) {
                source_buffer << "abs(input[i]);" << std::endl;
            } else {
                source_buffer << "input[i];" << std::endl;
            }
            source_buffer << "        for (size_t index = i + 1024; index < " << size <<"; index += 1024) {" << std::endl;
            source_buffer << "            sub_max = max(sub_max, input[index]);" << std::endl;
            source_buffer << "            sub_max = max(sub_max, ";
            if constexpr (jit::complex_scalar<T>) {
                source_buffer << "abs(input[index]";
            } else {
                source_buffer << "input[index]";
            }
            source_buffer << ");" << std::endl;
            source_buffer << "        }" << std::endl;
            source_buffer << "        threadgroup float thread_max[32];" << std::endl;
            source_buffer << "        thread_max[j] = simd_max(sub_max);" << std::endl;
@@ -659,8 +709,8 @@ namespace gpu {
///
///  @param[in] node Node to get the buffer for.
//------------------------------------------------------------------------------
        float *get_buffer(graph::shared_leaf<float, SAFE_MATH> &node) {
            return static_cast<float *> ([kernel_arguments[node.get()] contents]);
        T *get_buffer(graph::shared_leaf<T, SAFE_MATH> &node) {
            return static_cast<T *> ([kernel_arguments[node.get()] contents]);
        }
    };
}
+251 −57

File changed.

Preview size limit exceeded, changes collapsed.