Commit b68b510c authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add update kernel prefix for match cuda and metal backends.

parent 661e6c1c
Loading
Loading
Loading
Loading
+40 −13
Original line number Diff line number Diff line
@@ -242,7 +242,7 @@ namespace gpu {
                        "hipModuleGetFunction");

            std::vector<void *> buffers;
            std::set<graph::leaf_node<float, SAFE_MATH> *> needed_buffers;
            std::set<graph::leaf_node<T, SAFE_MATH> *> needed_buffers;
            
            for (auto &input : inputs) {
                if (!kernel_arguments.contains(input.get())) {
@@ -495,27 +495,54 @@ namespace gpu {
                                  const std::string name,
                                  graph::input_nodes<T, SAFE_MATH> &inputs,
                                  graph::output_nodes<T, SAFE_MATH> &outputs,
                                  graph::shared_random_state<T, SAFE_MATH> state,
                                  const size_t size,
                                  jit::register_map &registers) {
                                  const std::vector<bool> &is_constant,
                                  jit::register_map &registers,
                                  jit::texture1d_list &textures1d,
                                  jit::texture2d_list &textures2d) {
            source_buffer << std::endl;
            source_buffer << "extern \"C\" __global__ void " << name << "("
                          << std::endl;

            std::unordered_set<void *> used_args;
            if (inputs.size()) {
                source_buffer << "    ";
                if (is_constant[0]) {
                    source_buffer << "const ";
                }
                jit::add_type<T> (source_buffer);
                source_buffer << " *" << jit::to_string('v', inputs[0].get());
                used_args.insert(inputs[0].get());
            }
            for (size_t i = 1, ie = inputs.size(); i < ie; i++) {
                if (!used_args.contains(inputs[i].get())) {
                    source_buffer << "," << std::endl;
                    source_buffer << "    ";
                    if (is_constant[i]) {
                        source_buffer << "const ";
                    }
                    jit::add_type<T> (source_buffer);
                    source_buffer << " *" << jit::to_string('v', inputs[i].get());
                    used_args.insert(inputs[i].get());
                }
            }

            for (size_t i = 0, ie = outputs.size(); i < ie; i++) {
                if (!used_args.contains(outputs[i].get())) {
                    source_buffer << "," << std::endl;
                    source_buffer << "    ";
                    jit::add_type<T> (source_buffer);
                    source_buffer << " *" << jit::to_string('o', outputs[i].get());
                    used_args.insert(outputs[i].get());
                }
            }
            if (state.get()) {
                source_buffer << "," << std::endl
                              << "    mt_state * __restrict__ "
                              << jit::to_string('s', state.get())
                              << "," << std::endl
                              << "    const uint32_t *offset"
                              << std::endl;
            }
            source_buffer << ") {" << std::endl;