Commit 40becfc9 authored by hyunwoongko's avatar hyunwoongko Committed by mshoeybi
Browse files

Improve the fused softmax layer and fix bugs in it

parent 23266c57
Loading
Loading
Loading
Loading
+21 −1
Original line number Diff line number Diff line
@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
    torch::Tensor const& softmax_results,
    float scale_factor);

// Forward declaration of the CUDA-side query (defined in the .cu translation
// unit alongside fwd_cuda/bwd_cuda): reports how many batches a single thread
// block processes for the given attention problem shape.
int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
@@ -63,6 +69,14 @@ torch::Tensor bwd(
  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

// Extension-facing wrapper: forwards the launch-geometry query to the CUDA
// implementation so Python callers can size their inputs to match the kernel's
// batches-per-block requirement.
int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  const int batches_per_block =
      get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
  return batches_per_block;
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", 
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, 
	"Self Multihead Attention scaled, time masked softmax -- Forward.");

  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
	"Self Multihead Attention scaled, time masked softmax -- Backward.");

  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size."
  );
}
+16 −4
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@

#pragma once

#include <stdio.h>
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
@@ -310,9 +311,23 @@ __global__ void scaled_masked_softmax_warp_backward(
        }
    }
}

} // end of anonymous namespace

// Returns how many batches (rows of the softmax) one CUDA thread block
// processes for the scaled-masked-softmax kernels, given the softmax width
// key_seq_len. The math mirrors the launch configuration computed inside
// dispatch_scaled_masked_softmax_forward, so callers can pre-check that
// query_seq_len divides evenly by the returned value.
//
// NOTE: query_seq_len, batches, and attn_heads are accepted for interface
// symmetry with the dispatch functions but do not influence the result —
// only key_seq_len does. (The previous version computed an unused
// batch_count = batches * attn_heads * query_seq_len; that dead local has
// been removed.)
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    // Round the softmax width up to the next power of two; this fixes the
    // per-warp element layout used by the kernels.
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    // Rows narrower than a hardware warp use a partial warp; each warp
    // covers two rows when the row fits in 128 elements, otherwise one.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // Fixed block size of 128 threads, matching the dispatch functions.
    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst, 
@@ -325,7 +340,6 @@ void dispatch_scaled_masked_softmax_forward(
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
    if (key_seq_len == 0) {
        return;
    } else {
@@ -344,7 +358,6 @@ void dispatch_scaled_masked_softmax_forward(

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -414,7 +427,6 @@ void dispatch_scaled_masked_softmax_backward(
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
    if (key_seq_len == 0) {
       return;
    } else {
+5 −0
Original line number Diff line number Diff line
@@ -28,6 +28,11 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

// Thin host-side shim exposing the header implementation of
// get_batch_per_block to the extension's binding layer.
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    const int batches_per_block =
        get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
    return batches_per_block;
}


torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
+0 −4
Original line number Diff line number Diff line
@@ -340,7 +340,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
    int softmax_elements_stride, 
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
    if (softmax_elements == 0) {
        return;
    } else {
@@ -360,7 +359,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
@@ -430,7 +428,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
    int softmax_elements_stride, 
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
    if (softmax_elements == 0) {
       return;
    } else {
@@ -450,7 +447,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
+0 −0

Empty file added.

Loading