Commit 600d1957 authored by Cianciosa, Mark
Browse files

After the initial reduction, treat everything as float or double when using complex types.

parent 7bca5f2c
Loading
Loading
Loading
Loading
+2 −7
Original line number Diff line number Diff line
@@ -163,8 +163,7 @@ namespace jit {
#elif defined(USE_CUDA)
            source_buffer << "        __shared__ ";
#endif
            add_type<BACKEND> (source_buffer);
            source_buffer << " thread_max[32];" << std::endl;
            source_buffer << jit::type_to_string<typename BACKEND::base> () << " thread_max[32];" << std::endl;
#ifdef USE_METAL
            source_buffer << "        thread_max[j] = simd_max(sub_max);" << std::endl;

@@ -186,11 +185,7 @@ namespace jit {
            source_buffer << "            *result = simd_max(thread_max[k]);"  << std::endl;
#elif defined(USE_CUDA)
            source_buffer << "            for (int index = 16; index > 0; index /= 2) {" << std::endl;
            if constexpr (jit::is_complex<typename BACKEND::base> ()) {
                source_buffer << "                thread_max[k] = max(abs(thread_max[k]), abs(__shfl_down_sync(__activemask(), thread_max[k], index)));" << std::endl;
            } else {
            source_buffer << "                thread_max[k] = max(thread_max[k], __shfl_down_sync(__activemask(), thread_max[k], index));" << std::endl;
            }
            source_buffer << "            }" << std::endl;
            source_buffer << "            *result = thread_max[0];" << std::endl;
#endif