Loading graph_framework/jit.hpp +2 −7 Original line number Diff line number Diff line Loading @@ -163,8 +163,7 @@ namespace jit { #elif defined(USE_CUDA) source_buffer << " __shared__ "; #endif add_type<BACKEND> (source_buffer); source_buffer << " thread_max[32];" << std::endl; source_buffer << jit::type_to_string<typename BACKEND::base> () << " thread_max[32];" << std::endl; #ifdef USE_METAL source_buffer << " thread_max[j] = simd_max(sub_max);" << std::endl; Loading @@ -186,11 +185,7 @@ namespace jit { source_buffer << " *result = simd_max(thread_max[k]);" << std::endl; #elif defined(USE_CUDA) source_buffer << " for (int index = 16; index > 0; index /= 2) {" << std::endl; if constexpr (jit::is_complex<typename BACKEND::base> ()) { source_buffer << " thread_max[k] = max(abs(thread_max[k]), abs(__shfl_down_sync(__activemask(), thread_max[k], index)));" << std::endl; } else { source_buffer << " thread_max[k] = max(thread_max[k], __shfl_down_sync(__activemask(), thread_max[k], index));" << std::endl; } source_buffer << " }" << std::endl; source_buffer << " *result = thread_max[0];" << std::endl; #endif Loading Loading
graph_framework/jit.hpp +2 −7 Original line number Diff line number Diff line Loading @@ -163,8 +163,7 @@ namespace jit { #elif defined(USE_CUDA) source_buffer << " __shared__ "; #endif add_type<BACKEND> (source_buffer); source_buffer << " thread_max[32];" << std::endl; source_buffer << jit::type_to_string<typename BACKEND::base> () << " thread_max[32];" << std::endl; #ifdef USE_METAL source_buffer << " thread_max[j] = simd_max(sub_max);" << std::endl; Loading @@ -186,11 +185,7 @@ namespace jit { source_buffer << " *result = simd_max(thread_max[k]);" << std::endl; #elif defined(USE_CUDA) source_buffer << " for (int index = 16; index > 0; index /= 2) {" << std::endl; if constexpr (jit::is_complex<typename BACKEND::base> ()) { source_buffer << " thread_max[k] = max(abs(thread_max[k]), abs(__shfl_down_sync(__activemask(), thread_max[k], index)));" << std::endl; } else { source_buffer << " thread_max[k] = max(thread_max[k], __shfl_down_sync(__activemask(), thread_max[k], index));" << std::endl; } source_buffer << " }" << std::endl; source_buffer << " *result = thread_max[0];" << std::endl; #endif Loading