Commit 4916bae6 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

conditioning fused kernels

parent 872e38ea
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -119,11 +119,13 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
        data_size = input.size()
        query_seq_len = data_size[-2]
        key_seq_len = data_size[-1]
        attn_batch_size = data_size[0] * data_size[1]
        assert input.dim() == 4

        # invoke custom kernel
        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
           query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
           query_seq_len % 4 == 0 and key_seq_len > 16 and \
           attn_batch_size % 4 == 0 and self.scaled_masked_softmax_fusion:

            scale = self.scale if self.scale is not None else 1.0