Loading megatron/fused_kernels/scaled_masked_softmax.h +80 −42 Original line number Diff line number Diff line Loading @@ -26,6 +26,21 @@ namespace { template <typename Datatype, int ELEMENTS_PER_LDG> __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } template <> __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; Loading Loading @@ -90,6 +105,7 @@ __global__ void scaled_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) Loading @@ -110,29 +126,40 @@ __global__ void scaled_masked_softmax_warp_forward( // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; src += first_batch * element_count + local_idx; dst += first_batch * element_count + local_idx; mask += pad_first_batch * element_count + local_idx; src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; // load data from global memory acc_t elements[WARP_BATCH][WARP_ITERATIONS]; input_t temp_data[ELEMENTS_PER_LDG_STG]; uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; int itr_idx = i*element_count+it*WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { if (mask[itr_idx] != 1) { elements[i][it] = (acc_t)src[itr_idx] * scale; int itr_idx = i*element_count+it*WARP_SIZE; copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx); copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (temp_mask[element] != 1) { elements[i][it + element] = (acc_t)temp_data[element] * scale; } else { elements[i][it] = -10000.0; elements[i][it + element] = -10000.0; } } } else { elements[i][it] = -std::numeric_limits<acc_t>::infinity(); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { elements[i][it + element] = -std::numeric_limits<acc_t>::infinity(); } } } } Loading Loading @@ -161,15 +188,20 @@ __global__ void scaled_masked_softmax_warp_forward( warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum); // store result output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { dst[i*element_count+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out); } else { break; } Loading @@ -192,6 +224,7 @@ __global__ void scaled_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) Loading @@ -207,35 +240,35 @@ __global__ void scaled_masked_softmax_warp_backward( int local_idx = threadIdx.x; // the first element to process by the current thread int thread_offset = first_batch * element_count + local_idx; int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; grad += thread_offset; output += thread_offset; gradInput += thread_offset; // load data from global memory acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; input_t temp_grad[ELEMENTS_PER_LDG_STG]; input_t temp_output[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { output_reg[i][it] = output[i*element_count+it*WARP_SIZE]; } else { output_reg[i][it] = acc_t(0); } } copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE); copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE); #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { grad_reg[i][it] = (acc_t)grad[i*element_count+it*WARP_SIZE] * output_reg[i][it]; } else { grad_reg[i][it] = acc_t(0); for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { output_reg[i][it + element] = (acc_t)temp_output[element]; } #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; } } } } Loading @@ -257,11 +290,16 @@ __global__ void scaled_masked_softmax_warp_backward( if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients gradInput[i*element_count+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i])); output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out); } } } Loading megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +97 −36 Original line number Diff line number Diff line Loading @@ -26,6 +26,31 @@ namespace { template <typename Datatype, int ELEMENTS_PER_LDG> __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } template <> __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } template <typename Datatype, int ELEMENTS_PER_LDG> __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> __device__ __inline__ void copy_zero_vector<__half, 1>(__half *dst) { *dst = 0.0; } template <> __device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; Loading Loading @@ -89,10 +114,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int local_seq = blockIdx.x + 1; int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE; int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. Loading @@ -103,22 +129,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; src += first_batch * stride + local_idx; dst += first_batch * stride + local_idx; src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; // load data from global memory acc_t elements[WARP_BATCH][WARP_ITERATIONS]; input_t temp_data[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : local_seq; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { elements[i][it] = (acc_t)src[i*element_count*stride+it*WARP_SIZE] * scale; copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if ((element_index + element) < batch_element_count) { elements[i][it+element] = (acc_t)temp_data[element] * scale; } else { elements[i][it] = -std::numeric_limits<acc_t>::infinity(); elements[i][it + element] = -std::numeric_limits<acc_t>::infinity(); } } } else { #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { elements[i][it + element] = -std::numeric_limits<acc_t>::infinity(); } } } } Loading Loading @@ -149,17 +189,28 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum); // store result output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < local_seq) { dst[i*element_count*stride+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < local_seq) { out[element] = elements[i][it + element] / sum[i]; } else { out[element] = 0; } } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out); } else if (element_index < element_count) { dst[i*element_count*stride+it*WARP_SIZE] = 0; copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE); } else { break; } Loading @@ -183,6 +234,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int local_seq = blockIdx.x + 1; Loading @@ -197,35 +249,39 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( int local_idx = threadIdx.x; // the first element to process by the current thread int thread_offset = first_batch * stride + local_idx; int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; grad += thread_offset; output += thread_offset; gradInput += thread_offset; // load data from global memory acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; input_t temp_grad[ELEMENTS_PER_LDG_STG]; input_t temp_output[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : local_seq; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { output_reg[i][it] = output[i*element_count*stride+it*WARP_SIZE]; } else { output_reg[i][it] = acc_t(0); copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < batch_element_count) { output_reg[i][it + element] = (acc_t)temp_output[element]; } } #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { grad_reg[i][it] = (acc_t)grad[i*element_count*stride+it*WARP_SIZE] * output_reg[i][it]; } else { grad_reg[i][it] = acc_t(0); for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < batch_element_count) { grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; } } } } } Loading @@ -247,11 +303,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients gradInput[i*element_count*stride+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i])); output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out); } } } Loading Loading
megatron/fused_kernels/scaled_masked_softmax.h +80 −42 Original line number Diff line number Diff line Loading @@ -26,6 +26,21 @@ namespace { template <typename Datatype, int ELEMENTS_PER_LDG> __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } template <> __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; Loading Loading @@ -90,6 +105,7 @@ __global__ void scaled_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) Loading @@ -110,29 +126,40 @@ __global__ void scaled_masked_softmax_warp_forward( // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; src += first_batch * element_count + local_idx; dst += first_batch * element_count + local_idx; mask += pad_first_batch * element_count + local_idx; src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; // load data from global memory acc_t elements[WARP_BATCH][WARP_ITERATIONS]; input_t temp_data[ELEMENTS_PER_LDG_STG]; uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; int itr_idx = i*element_count+it*WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { if (mask[itr_idx] != 1) { elements[i][it] = (acc_t)src[itr_idx] * scale; int itr_idx = i*element_count+it*WARP_SIZE; copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx); copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (temp_mask[element] != 1) { elements[i][it + element] = (acc_t)temp_data[element] * scale; } else { elements[i][it] = -10000.0; elements[i][it + element] = -10000.0; } } } else { elements[i][it] = -std::numeric_limits<acc_t>::infinity(); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { elements[i][it + element] = -std::numeric_limits<acc_t>::infinity(); } } } } Loading Loading @@ -161,15 +188,20 @@ __global__ void scaled_masked_softmax_warp_forward( warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum); // store result output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { dst[i*element_count+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out); } else { break; } Loading @@ -192,6 +224,7 @@ __global__ void scaled_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) Loading @@ -207,35 +240,35 @@ __global__ void scaled_masked_softmax_warp_backward( int local_idx = threadIdx.x; // the first element to process by the current thread int thread_offset = first_batch * element_count + local_idx; int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; grad += thread_offset; output += thread_offset; gradInput += thread_offset; // load data from global memory acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; input_t temp_grad[ELEMENTS_PER_LDG_STG]; input_t temp_output[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { output_reg[i][it] = output[i*element_count+it*WARP_SIZE]; } else { output_reg[i][it] = acc_t(0); } } copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE); copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE); #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { grad_reg[i][it] = (acc_t)grad[i*element_count+it*WARP_SIZE] * output_reg[i][it]; } else { grad_reg[i][it] = acc_t(0); for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { output_reg[i][it + element] = (acc_t)temp_output[element]; } #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; } } } } Loading @@ -257,11 +290,16 @@ __global__ void scaled_masked_softmax_warp_backward( if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients gradInput[i*element_count+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i])); output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out); } } } Loading
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +97 −36 Original line number Diff line number Diff line Loading @@ -26,6 +26,31 @@ namespace { template <typename Datatype, int ELEMENTS_PER_LDG> __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } template <> __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } template <typename Datatype, int ELEMENTS_PER_LDG> __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> __device__ __inline__ void copy_zero_vector<__half, 1>(__half *dst) { *dst = 0.0; } template <> __device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; Loading Loading @@ -89,10 +114,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int local_seq = blockIdx.x + 1; int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE; int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. Loading @@ -103,22 +129,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; src += first_batch * stride + local_idx; dst += first_batch * stride + local_idx; src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; // load data from global memory acc_t elements[WARP_BATCH][WARP_ITERATIONS]; input_t temp_data[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : local_seq; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { elements[i][it] = (acc_t)src[i*element_count*stride+it*WARP_SIZE] * scale; copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if ((element_index + element) < batch_element_count) { elements[i][it+element] = (acc_t)temp_data[element] * scale; } else { elements[i][it] = -std::numeric_limits<acc_t>::infinity(); elements[i][it + element] = -std::numeric_limits<acc_t>::infinity(); } } } else { #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { elements[i][it + element] = -std::numeric_limits<acc_t>::infinity(); } } } } Loading Loading @@ -149,17 +189,28 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum); // store result output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < local_seq) { dst[i*element_count*stride+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < local_seq) { out[element] = elements[i][it + element] / sum[i]; } else { out[element] = 0; } } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out); } else if (element_index < element_count) { dst[i*element_count*stride+it*WARP_SIZE] = 0; copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE); } else { break; } Loading @@ -183,6 +234,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int local_seq = blockIdx.x + 1; Loading @@ -197,35 +249,39 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( int local_idx = threadIdx.x; // the first element to process by the current thread int thread_offset = first_batch * stride + local_idx; int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; grad += thread_offset; output += thread_offset; gradInput += thread_offset; // load data from global memory acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; input_t temp_grad[ELEMENTS_PER_LDG_STG]; input_t temp_output[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : local_seq; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { output_reg[i][it] = output[i*element_count*stride+it*WARP_SIZE]; } else { output_reg[i][it] = acc_t(0); copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < batch_element_count) { output_reg[i][it + element] = (acc_t)temp_output[element]; } } #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { grad_reg[i][it] = (acc_t)grad[i*element_count*stride+it*WARP_SIZE] * output_reg[i][it]; } else { grad_reg[i][it] = acc_t(0); for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < batch_element_count) { grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; } } } } } Loading @@ -247,11 +303,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { int element_index = local_idx + it * WARP_SIZE; for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients gradInput[i*element_count*stride+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i])); output_t out[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out); } } } Loading