megatron/fused_kernels/scaled_masked_softmax.cpp  (+21 −1)

A forward declaration for the CUDA-side helper is added next to the existing `fwd_cuda`/`bwd_cuda` declarations:

```cpp
int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);
```

A thin C++ wrapper, placed after the existing `fwd`/`bwd` wrappers inside the `multihead_attn::fused_softmax::scaled_masked_softmax` namespace, forwards to it:

```cpp
int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
```

Finally, the wrapper is exposed to Python in the `PYBIND11_MODULE` block, alongside the existing `forward` and `backward` bindings:

```cpp
m.def("get_batch_per_block",
      &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
      "Return Batch per block size.");
```
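Exposing this value lets the host side see the kernel's batching granularity before a launch: the forward dispatch in scaled_masked_softmax.h (next file) launches `query_seq_len / batches_per_block` blocks along the first grid axis, so the query sequence length has to divide evenly by `batches_per_block`. A minimal sketch of such a host-side guard is shown below; the `check_query_seq_len` helper is hypothetical and not part of this change, and the declaration is shown at global scope for brevity (in the patch it lives in the `multihead_attn::fused_softmax::scaled_masked_softmax` namespace).

```cpp
#include <stdexcept>
#include <string>

// Same signature as the helper declared in scaled_masked_softmax.cpp and
// implemented in the CUDA translation unit.
int get_batch_per_block_cuda(
    int query_seq_len, int key_seq_len, int batches, int attn_heads);

// Hypothetical host-side guard: the fused forward kernel uses
// query_seq_len / batches_per_block blocks on the first grid axis, so
// query_seq_len must be a multiple of batches_per_block.
void check_query_seq_len(
    int query_seq_len, int key_seq_len, int batches, int attn_heads) {
  const int batches_per_block =
      get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
  if (query_seq_len % batches_per_block != 0) {
    throw std::runtime_error(
        "query_seq_len (" + std::to_string(query_seq_len) +
        ") must be a multiple of batches_per_block (" +
        std::to_string(batches_per_block) + ")");
  }
}
```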
megatron/fused_kernels/scaled_masked_softmax.h  (+16 −4)

One include is added near the top of the header; after the change the region reads:

```cpp
#pragma once

#include <stdio.h>
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
```

The anonymous namespace holding the warp kernels is now closed (with an `// end of anonymous namespace` comment) right after `scaled_masked_softmax_warp_backward`, and the new host-side helper is defined outside it so it has external linkage. It reports how many softmax batches a single thread block processes:

```cpp
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int batch_count = batches * attn_heads * query_seq_len;
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}
```

Three assertions are dropped from the dispatch functions: `dispatch_scaled_masked_softmax_forward` loses both `TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);` and `TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);` (the latter sat just before the `dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches)` launch configuration), and `dispatch_scaled_masked_softmax_backward` loses its matching `key_seq_len` range check.
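The divisibility requirement those asserts used to enforce depends on the value computed above. To make the arithmetic concrete, the standalone sketch below reproduces it with a local copy of the `log2_ceil` helper (assumed to round up to the next power of two, as in the header) and with the CUDA warp size `C10_WARP_SIZE` taken to be 32; it is an illustration, not code from the patch.

```cpp
#include <cstdio>

// Local stand-in for the header's log2_ceil: smallest e with (1 << e) >= value.
static int log2_ceil(int value) {
  int e = 0;
  while ((1 << e) < value) ++e;
  return e;
}

// Mirrors get_batch_per_block(): 128 threads per block, warps sized to the
// padded key length, and two batches per warp for short rows (<= 128).
static int batches_per_block(int key_seq_len, int warp_size_limit = 32) {
  const int next_power_of_two = 1 << log2_ceil(key_seq_len);
  const int warp_size =
      next_power_of_two < warp_size_limit ? next_power_of_two : warp_size_limit;
  const int batches_per_warp = next_power_of_two <= 128 ? 2 : 1;
  const int threads_per_block = 128;
  return (threads_per_block / warp_size) * batches_per_warp;
}

int main() {
  // key_seq_len = 64   -> warp_size 32, 2 batches/warp, 4 warps/block -> 8
  // key_seq_len = 1024 -> warp_size 32, 1 batch/warp,  4 warps/block -> 4
  std::printf("%d %d\n", batches_per_block(64), batches_per_block(1024));
  return 0;
}
```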
megatron/fused_kernels/scaled_masked_softmax_cuda.cu  (+5 −0)

Inside the same namespaces, the CUDA translation unit gains the pass-through that the C++ wrapper calls; it simply forwards to the header implementation:

```cpp
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
```
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h  (+0 −4)

No code is added here; four checks are dropped. Both `dispatch_scaled_upper_triang_masked_softmax_forward` and `dispatch_scaled_upper_triang_masked_softmax_backward` lose their `TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);` range check and the `TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);` divisibility check that sat just before `blocks_per_seq` and the `dim3 blocks(seq_len, blocks_per_seq, 1)` launch configuration are computed.

megatron/fused_kernels/tests/__init__.py  (new file, +0 −0)

Empty file added, making the new tests directory an importable Python package.