megatron/fused_kernels/scaled_masked_softmax.cpp  (+21 −1)

@@ -32,6 +32,12 @@
torch::Tensor bwd_cuda(
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,

@@ -63,6 +69,14 @@ torch::Tensor bwd(
  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size.");
}
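The new binding makes `get_batch_per_block` callable from Python alongside `forward` and `backward`. A minimal usage sketch, assuming the extension has been built and is importable as `scaled_masked_softmax_cuda`; the module name and the example shapes are illustrative, not taken from this diff:

import scaled_masked_softmax_cuda  # assumed name of the compiled extension module

# Hypothetical shapes: 1024 query/key positions, batch of 8, 16 attention heads.
batches_per_block = scaled_masked_softmax_cuda.get_batch_per_block(1024, 1024, 8, 16)
print(batches_per_block)  # how many (batch, head) rows one CUDA block covers

A caller can use this value ahead of the kernel launch to verify that the total number of rows divides evenly into blocks, which is the same condition the new asserts below enforce inside the dispatcher.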
megatron/fused_kernels/scaled_masked_softmax.h  (+14 −1)

@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward(
      }
    }
  }
}
} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
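The result depends only on `key_seq_len`: the element count is rounded up to the next power of two, which fixes the warp size and how many rows each warp handles. A small Python sketch of the same arithmetic, assuming `C10_WARP_SIZE` is 32 (the usual value on NVIDIA GPUs); the worked examples below are illustrative, not part of the diff:

def log2_ceil(value):
    # Smallest exponent such that (1 << exponent) >= value, mirroring the C++ helper.
    exponent = 0
    while (1 << exponent) < value:
        exponent += 1
    return exponent

def get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads):
    next_power_of_two = 1 << log2_ceil(key_seq_len)
    warp_size = min(next_power_of_two, 32)            # assumes C10_WARP_SIZE == 32
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    threads_per_block = 128
    warps_per_block = threads_per_block // warp_size
    return warps_per_block * batches_per_warp

# Worked examples: a short key sequence keeps two rows per warp, a long one drops to one.
assert get_batch_per_block(128, 128, 8, 16) == 8      # 4 warps/block * 2 rows/warp
assert get_batch_per_block(1024, 1024, 8, 16) == 4    # 4 warps/block * 1 row/warp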
megatron/fused_kernels/scaled_masked_softmax_cuda.cu  (+5 −0)

@@ -28,6 +28,11 @@
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h  (+2 −0)

@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
    int blocks_per_seq = attn_batches / batches_per_block;
    dim3 blocks(seq_len, blocks_per_seq, 1);
    dim3 threads(warp_size, warps_per_block, 1);

@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
    int blocks_per_seq = attn_batches / batches_per_block;
    dim3 blocks(seq_len, blocks_per_seq, 1);
    dim3 threads(warp_size, warps_per_block, 1);

megatron/fused_kernels/tests/__init__.py  (0 → 100644, +0 −0)
Empty file added.
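The two added `TORCH_INTERNAL_ASSERT` calls guard the integer division that follows them: if `attn_batches` were not a multiple of `batches_per_block`, the truncated `blocks_per_seq` would launch too few blocks and leave some rows unprocessed. A rough Python analogue of that check, with a hypothetical helper name and the assumption that `attn_batches` is one row per (batch, head) pair, as in the callers of these dispatch functions:

def launch_grid(seq_len, batches, attn_heads, batches_per_block):
    # Assumed: one softmax row per (batch, head) pair.
    attn_batches = batches * attn_heads
    # Condition enforced by the new TORCH_INTERNAL_ASSERT lines.
    assert attn_batches % batches_per_block == 0, \
        "attn_batches must be a multiple of batches_per_block"
    blocks_per_seq = attn_batches // batches_per_block
    return (seq_len, blocks_per_seq, 1)  # matches dim3 blocks(seq_len, blocks_per_seq, 1)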