Loading megatron/fused_kernels/scaled_masked_softmax.h +3 −3 Original line number Diff line number Diff line Loading @@ -32,11 +32,9 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } template <> __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } Loading Loading @@ -250,6 +248,8 @@ __global__ void scaled_masked_softmax_warp_backward( // load data from global memory acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; input_t temp_grad[ELEMENTS_PER_LDG_STG]; input_t temp_output[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; Loading megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +8 −7 Original line number Diff line number Diff line Loading @@ -32,21 +32,22 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } template <> __device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = 0; } template <> __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } template <typename Datatype, int ELEMENTS_PER_LDG> __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> __device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; Loading Loading @@ -199,7 +200,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out); } else if (element_index < element_count) { copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE) copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE); } else { break; } Loading Loading
megatron/fused_kernels/scaled_masked_softmax.h +3 −3 Original line number Diff line number Diff line Loading @@ -32,11 +32,9 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } template <> __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } Loading Loading @@ -250,6 +248,8 @@ __global__ void scaled_masked_softmax_warp_backward( // load data from global memory acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; input_t temp_grad[ELEMENTS_PER_LDG_STG]; input_t temp_output[ELEMENTS_PER_LDG_STG]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; Loading
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +8 −7 Original line number Diff line number Diff line Loading @@ -32,21 +32,22 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); } template <> __device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = 0; } template <> __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } template <typename Datatype, int ELEMENTS_PER_LDG> __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> __device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; Loading Loading @@ -199,7 +200,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( } copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out); } else if (element_index < element_count) { copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE) copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE); } else { break; } Loading