megatron/fused_kernels/scaled_upper_triang_masked_softmax.h (+21 −7)

@@ -44,6 +44,9 @@ __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *
 
 template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_zero_vector(Datatype *dst);
 
+template <>
+__device__ __inline__ void copy_zero_vector<__half, 1>(__half *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
 
@@ -115,7 +118,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
-    int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
+    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
 
     // micro_batch_size might not be a multiple of WARP_BATCH. Check how
     // many batches have to computed within this WARP.
 
@@ -141,12 +144,15 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
         int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
 
         if (element_index < batch_element_count) {
-            int itr_idx = i*element_count*stride+it*WARP_SIZE;
-            copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
+            copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
 
             #pragma unroll
             for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                if ((element_index + element) < batch_element_count) {
                     elements[i][it+element] = (acc_t)temp_data[element] * scale;
+                } else {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
             }
         } else {
             #pragma unroll
 
@@ -196,7 +202,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
             #pragma unroll
             for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                if (element_index + element < local_seq) {
                     out[element] = elements[i][it + element] / sum[i];
+                } else {
+                    out[element] = 0;
+                }
             }
             copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
         } else if (element_index < element_count) {
 
@@ -262,15 +272,19 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
                 #pragma unroll
                 for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
                         output_reg[i][it + element] = (acc_t)temp_output[element];
+                    }
                 }
                 #pragma unroll
                 for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
                         grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                    }
                 }
             }
         }
     }
 
     acc_t sum[WARP_BATCH];
     #pragma unroll
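The common thread in these hunks is that the kernel moves four elements per memory transaction (ELEMENTS_PER_LDG_STG = 4), so a single vector can straddle the causal-mask boundary: its leading lanes are valid while its trailing lanes fall past batch_element_count or local_seq. Each hunk adds a per-lane guard around the vector operation: masked lanes are filled with -std::numeric_limits<acc_t>::infinity() on load (so they contribute exp(-inf) = 0 to the softmax sum), written as exact zeros on store, and skipped in the backward accumulation; the warp_iteration_limit change widens the per-warp iteration bound so the last, partially valid vector is still processed. The sketch below isolates the load-side pattern in a minimal standalone kernel; it is an illustration under stated assumptions, not code from the PR, and load_row_masked, valid_count, and the one-chunk-per-thread launch shape are hypothetical.

// Guarded vectorized load: one 4-wide load, then per-lane bounds checks so
// lanes past the valid prefix become -inf instead of stale or scaled garbage.
// Assumes rows are physically padded to a multiple of 4 elements, so the
// vector load may pass the logical mask boundary but never the allocation.
#include <cuda_fp16.h>
#include <math.h>

constexpr int ELEMENTS_PER_LDG_STG = 4;

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

// Four halves move as a single 8-byte transaction; src and dst must be
// 8-byte aligned, which the padding above guarantees.
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) {
    *((float2 *)dst) = *((const float2 *)src);
}

// Each thread converts one 4-element chunk of a row whose logical length is
// valid_count, scaling valid lanes and writing -inf to masked ones.
__global__ void load_row_masked(const __half *src, float *elements,
                                int valid_count, float scale) {
    __half temp_data[ELEMENTS_PER_LDG_STG];
    int element_index = ELEMENTS_PER_LDG_STG * threadIdx.x;

    if (element_index < valid_count) {
        copy_vector<__half, ELEMENTS_PER_LDG_STG>(temp_data, src + element_index);
        #pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
            // The chunk may straddle valid_count: keep in-range lanes,
            // overwrite the rest, exactly as the patched forward pass does.
            if (element_index + element < valid_count) {
                elements[element_index + element] = (float)temp_data[element] * scale;
            } else {
                elements[element_index + element] = -INFINITY;
            }
        }
    } else {
        // Chunk is entirely masked: skip the load, just write the identity fill.
        #pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
            elements[element_index + element] = -INFINITY;
        }
    }
}

The copy_zero_vector<__half, 4> specialization added in the first hunk is the store-side counterpart of the same idea: a single make_float2 store zeroes four halves in one 8-byte transaction for output vectors that lie entirely past local_seq.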
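For reference, here is what one row of the patched forward kernel computes, as a host-side sketch (a reconstruction of the semantics, not code from the PR; masked_softmax_row is a hypothetical name). Row q of the attention matrix keeps key positions 0..q, so local_seq = q + 1 as in the kernel; the -inf lanes drop out of the sum because exp(-inf) = 0, and every position at or past local_seq must come out as an exact zero, which is what the new else branch in the store loop guarantees.

// Reference semantics of scaled upper-triangular masked softmax for one row.
#include <cmath>
#include <vector>

std::vector<float> masked_softmax_row(const std::vector<float> &logits,
                                      int q, float scale) {
    const int local_seq = q + 1;                  // valid prefix length, as in the kernel
    std::vector<float> out(logits.size(), 0.0f);  // masked tail stays exactly zero

    // Max-subtraction for numerical stability, over valid positions only;
    // in the kernel the -inf fill excludes masked lanes from this reduction.
    float max_val = -INFINITY;
    for (int k = 0; k < local_seq; ++k)
        max_val = std::fmax(max_val, logits[k] * scale);

    float sum = 0.0f;
    for (int k = 0; k < local_seq; ++k)
        sum += std::exp(logits[k] * scale - max_val);

    for (int k = 0; k < local_seq; ++k)
        out[k] = std::exp(logits[k] * scale - max_val) / sum;
    return out;
}

The backward hunks apply the same boundary logic in the other direction: output_reg and grad_reg are written only for lanes below batch_element_count, so the trailing lanes of a straddling vector load never leak stale values into the gradient reduction.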