megatron/fused_kernels/scaled_masked_softmax.h  +3 −0

@@ -339,6 +339,7 @@ void dispatch_scaled_masked_softmax_forward(
     int attn_heads,
     int pad_batches)
 {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
     if (key_seq_len == 0) {
         return;
     } else {
@@ -357,6 +358,7 @@ void dispatch_scaled_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
         dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -426,6 +428,7 @@ void dispatch_scaled_masked_softmax_backward(
     int batches,
     int attn_heads)
 {
+    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
     if (key_seq_len == 0) {
         return;
     } else {
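For context on what these two asserts protect: the dispatcher derives its launch geometry from key_seq_len, and grid.x is query_seq_len / batches_per_block, so an inexact division would silently drop trailing query rows. Below is a minimal standalone sketch of that arithmetic, not code from the header; it assumes the kernel's usual scheme (threads_per_block = 128, hardware warp size 32), and the function names are invented for illustration.

#include <cstdio>

// Invented helper mirroring the header's log2_ceil.
static int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

// Returns true iff (query_seq_len, key_seq_len) would pass both new
// asserts in dispatch_scaled_masked_softmax_forward.
static bool masked_softmax_launch_ok(int query_seq_len, int key_seq_len) {
    if (key_seq_len < 0 || key_seq_len > 2048) return false;  // first assert
    if (key_seq_len == 0) return true;                        // early return

    const int next_power_of_two = 1 << log2_ceil(key_seq_len);
    const int warp_size = next_power_of_two < 32 ? next_power_of_two : 32;
    const int batches_per_warp = next_power_of_two <= 128 ? 2 : 1;
    constexpr int threads_per_block = 128;
    const int warps_per_block = threads_per_block / warp_size;
    const int batches_per_block = warps_per_block * batches_per_warp;

    // Second assert: grid.x = query_seq_len / batches_per_block, so the
    // division must be exact or trailing query rows would never launch.
    return query_seq_len % batches_per_block == 0;
}

int main() {
    std::printf("%d\n", masked_softmax_launch_ok(1024, 1024));  // 1: 1024 % 4 == 0
    std::printf("%d\n", masked_softmax_launch_ok(1022, 1024));  // 0: 1022 % 4 != 0
    std::printf("%d\n", masked_softmax_launch_ok(1024, 4096));  // 0: sk > 2048
}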
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h  +6 −0

@@ -340,6 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int softmax_elements_stride,
     int attn_batches)
 {
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
     if (softmax_elements == 0) {
         return;
     } else {
@@ -359,6 +360,8 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
@@ -428,6 +431,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int softmax_elements_stride,
     int attn_batches)
 {
+    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
     if (softmax_elements == 0) {
         return;
     } else {
@@ -447,6 +451,8 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
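The upper-triangular dispatchers get the same treatment, but here the divisibility requirement falls on attn_batches (batches * attn_heads), since grid.y = attn_batches / batches_per_block while grid.x covers seq_len. A sketch under the same assumptions as above; it reuses log2_ceil from the previous sketch, and the function name is again invented.

// Returns true iff (seq_len, attn_batches) would pass both new asserts in
// dispatch_scaled_upper_triang_masked_softmax_forward/_backward. Here
// softmax_elements == seq_len, because row i attends to columns 0..i.
static bool upper_triang_launch_ok(int seq_len, int attn_batches) {
    if (seq_len < 0 || seq_len > 2048) return false;  // first assert
    if (seq_len == 0) return true;                    // early return

    const int next_power_of_two = 1 << log2_ceil(seq_len);
    const int warp_size = next_power_of_two < 32 ? next_power_of_two : 32;
    const int batches_per_warp = next_power_of_two <= 128 ? 2 : 1;
    constexpr int threads_per_block = 128;
    const int batches_per_block =
        (threads_per_block / warp_size) * batches_per_warp;

    // Second assert: grid.y = attn_batches / batches_per_block, so an
    // inexact division would skip whole (batch, head) slices.
    return attn_batches % batches_per_block == 0;
}

For example, seq_len = 1024 gives batches_per_block = 4, so attn_batches (= b * np) must be a multiple of 4, which is exactly what the Python-side gate below checks.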
megatron/model/fused_softmax.py  +2 −2

@@ -138,8 +138,8 @@ class FusedScaleMaskSoftmax(nn.Module):
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
             and mask is not None  # mask tensor must not be None
-            and 16 < sq <= 2048  # sq must be 16 ~ 2048
-            and sk % 4 == 0  # sk must be divisor of 4
+            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and sq % 4 == 0  # sq must be divisible by 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
             if 0 <= sk <= 2048:
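The Python fix swaps which dimension each check applies to: the 16 ~ 2048 window belongs to sk (the key length, the reduction dimension one warp handles), and the % 4 requirement belongs to sq (query rows packed per block), matching the C++ asserts above. Rendered as a C++ predicate to stay consistent with the earlier sketches; the function name and bool parameters are invented, and this is an illustration of the gate's logic, not the Python source itself.

// Illustrative rendering of the corrected availability gate in
// FusedScaleMaskSoftmax; mirrors the Python condition line by line.
static bool fusion_available(bool fusion_requested, bool input_in_fp16,
                             bool has_mask, int sq, int sk, int attn_batches) {
    return fusion_requested          // user wants to fuse
        && input_in_fp16             // input must be fp16
        && has_mask                  // mask tensor must not be None
        && 16 < sk && sk <= 2048     // kernel templates cover key lengths up to 2048
        && sq % 4 == 0               // batches_per_block is a multiple of 4
        && attn_batches % 4 == 0;    // np * b, for the upper-triangular path
}

Per the arithmetic above, batches_per_block is 4 or 8 for sk in this range, so sq % 4 == 0 is necessary but not always sufficient; the new TORCH_INTERNAL_ASSERT calls catch the sk <= 128 case where a multiple of 8 is required.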