Commit 0cb36de2 authored by Vijay Korthikanti

address review comments

parent 4916bae6
megatron/arguments.py +17 −1
@@ -203,6 +203,22 @@ def parse_args(extra_args_provider=None, defaults={},
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
 
+    # custom kernel constraints check
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+
+    # constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
+    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
+        print('WARNING: constraints for invoking optimized'
+              ' fused softmax kernel are not met. We default back to unfused'
+              ' kernel invocations.')
+
     # Load scaled_masked_softmax_fusion_kernels
     if args.masked_softmax_fusion:
         fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
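To make the arithmetic concrete, here is a minimal standalone sketch of the same eligibility check; the argument values below are hypothetical, chosen only to satisfy the constraints:

# Standalone sketch of the check above; these values are hypothetical.
seq_len = 1024                    # needs 16 < seq_len <= 2048 and seq_len % 4 == 0
num_attention_heads = 16
tensor_model_parallel_size = 2
micro_batch_size = 4

# attention heads per tensor-parallel rank, times the micro batch size
attn_batch_size = \
    (num_attention_heads / tensor_model_parallel_size) * \
    micro_batch_size              # 32.0

custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
    seq_len % 4 == 0 and attn_batch_size % 4 == 0
print(custom_kernel_constraint)   # True -> the fused kernel can be invoked

Note that the true division leaves attn_batch_size as a float (32.0 here); the % 4 == 0 test still behaves as expected because 32.0 % 4 evaluates to 0.0.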
megatron/model/fused_softmax.py +9 −6
@@ -116,17 +116,20 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
 
     def forward(self, input, mask):
         # [b, np, sq, sk]
+        assert input.dim() == 4
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
         attn_batch_size = data_size[0] * data_size[1]
-        assert input.dim() == 4
 
-        # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-           query_seq_len % 4 == 0 and key_seq_len > 16 and \
-           attn_batch_size % 4 == 0 and self.scaled_masked_softmax_fusion:
+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
+        # invoke custom kernel
+        if self.input_in_fp16 and mask is not None and \
+           custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
 
             if self.attn_mask_type == AttnMaskType.causal:
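As a quick sanity check of the gate in forward(), the sketch below evaluates the same constraint on a hypothetical scores tensor in the [b, np, sq, sk] layout (the tensor and its sizes are illustrative, not part of the commit):

import torch

# Hypothetical attention scores in [b, np, sq, sk] layout:
# batch 2, 8 heads, query/key sequence length 128.
scores = torch.randn(2, 8, 128, 128).half()

data_size = scores.size()
query_seq_len = data_size[-2]                  # sq = 128
key_seq_len = data_size[-1]                    # sk = 128
attn_batch_size = data_size[0] * data_size[1]  # b * np = 16

custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
    query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
print(custom_kernel_constraint)  # True -> the fused kernel path would be taken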