Commit 0be40526 authored by Jared Casper

Merge branch 'github_fused_softmax' into 'main'

Fused softmax checks and additions from Github (#133)

See merge request ADLR/megatron-lm!312
parents 23266c57 bc7b3539
+21 −1
@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
@@ -63,6 +69,14 @@ torch::Tensor bwd(
  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
    return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", 
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, 
	"Self Multihead Attention scaled, time masked softmax -- Forward.");

  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
	"Self Multihead Attention scaled, time masked softmax -- Backward.");

  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size."
  );
}
+14 −1
@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward(
        }
    }
}

} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst, 
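
The sizing rule that get_batch_per_block implements can be checked in isolation. The sketch below is not part of the patch: it restates the same arithmetic as a self-contained host program, assuming a warp size of 32 (the usual C10_WARP_SIZE on NVIDIA hardware) and using log2_ceil_local as a hypothetical stand-in for the kernel file's log2_ceil helper.

// Standalone sketch of the batch-per-block rule above. Assumes warp size 32
// (C10_WARP_SIZE); log2_ceil_local is a hypothetical stand-in for log2_ceil.
#include <cstdio>

static int log2_ceil_local(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

static int batch_per_block_sketch(int key_seq_len) {
    const int next_power_of_two = 1 << log2_ceil_local(key_seq_len);
    int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int threads_per_block = 128;
    return (threads_per_block / warp_size) * batches_per_warp;
}

int main() {
    // key_seq_len = 128 -> warp size 32, 2 batches/warp, 4 warps/block -> 8
    // key_seq_len = 512 -> warp size 32, 1 batch/warp,  4 warps/block -> 4
    printf("%d %d\n", batch_per_block_sketch(128), batch_per_block_sketch(512));
    return 0;
}
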
+5 −0
@@ -28,6 +28,11 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}


torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
+2 −0
@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
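
The two TORCH_INTERNAL_ASSERT additions guard the grid sizing that follows them: the flattened batch-times-heads count must divide evenly by batches_per_block, otherwise the integer division would silently drop the remainder and some batch/head rows would never be launched. Below is a minimal sketch of that sizing, with hypothetical names (LaunchGeometry, plan_launch) and a plain assert standing in for TORCH_INTERNAL_ASSERT; it only illustrates the check, it is not the patched code.

#include <cassert>

struct LaunchGeometry {
    int blocks_x, blocks_y, blocks_z;     // dim3 blocks(seq_len, blocks_per_seq, 1)
    int threads_x, threads_y, threads_z;  // dim3 threads(warp_size, warps_per_block, 1)
};

// Mirrors the dispatch logic above: the flattened batch count must be a
// multiple of batches_per_block (= warps_per_block * batches_per_warp).
LaunchGeometry plan_launch(int seq_len, int batches, int attn_heads,
                           int warp_size, int warps_per_block, int batches_per_warp) {
    int batches_per_block = warps_per_block * batches_per_warp;
    int attn_batches = batches * attn_heads;
    assert(attn_batches % batches_per_block == 0);  // the check the patch adds
    int blocks_per_seq = attn_batches / batches_per_block;
    return {seq_len, blocks_per_seq, 1, warp_size, warps_per_block, 1};
}

int main() {
    // e.g. 8 sequences x 16 heads with 8 batches per block -> 16 blocks per sequence
    LaunchGeometry g = plan_launch(/*seq_len=*/1024, /*batches=*/8, /*attn_heads=*/16,
                                   /*warp_size=*/32, /*warps_per_block=*/4, /*batches_per_warp=*/2);
    return g.blocks_y == 16 ? 0 : 1;
}
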
+0 −0

Empty file added.
