Commit c6b77639 authored by hyunwoongko's avatar hyunwoongko Committed by mshoeybi
Browse files

change PR per reviews

parent 0d350c8d
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -339,6 +339,7 @@ void dispatch_scaled_masked_softmax_forward(
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
    if (key_seq_len == 0) {
        return;
    } else {
@@ -357,6 +358,7 @@ void dispatch_scaled_masked_softmax_forward(

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -426,6 +428,7 @@ void dispatch_scaled_masked_softmax_backward(
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
    if (key_seq_len == 0) {
       return;
    } else {
+6 −0
Original line number Diff line number Diff line
@@ -340,6 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
    int softmax_elements_stride, 
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
    if (softmax_elements == 0) {
        return;
    } else {
@@ -359,6 +360,8 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
@@ -428,6 +431,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
    int softmax_elements_stride, 
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
    if (softmax_elements == 0) {
       return;
    } else {
@@ -447,6 +451,8 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
+2 −2
Original line number Diff line number Diff line
@@ -138,8 +138,8 @@ class FusedScaleMaskSoftmax(nn.Module):
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sq <= 2048  # sq must be 16 ~ 2048
            and sk % 4 == 0  # sk must be divisible by 4
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
        ):
            if 0 <= sk <= 2048: