megatron/model/fused_softmax.py (+3 −1)

@@ -119,11 +119,13 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
+        attn_batch_size = data_size[0] * data_size[1]
         assert input.dim() == 4

         # invoke custom kernel
         if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-            query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+            query_seq_len % 4 == 0 and key_seq_len > 16 and \
+            attn_batch_size % 4 == 0 and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
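The change tightens the guard around the fused scaled-masked-softmax CUDA kernel: in addition to the existing requirements (fp16 input, an explicit mask, key_seq_len <= 2048, query_seq_len divisible by 4, and fusion enabled), the fused path is now skipped when key_seq_len is 16 or less or when the flattened batch * heads dimension is not a multiple of 4. A minimal sketch of the resulting predicate, read off the diff above (the standalone `is_kernel_available` helper and its signature are illustrative assumptions, not part of the upstream file):

```python
import torch


def is_kernel_available(inp: torch.Tensor, mask, input_in_fp16: bool,
                        fusion_enabled: bool) -> bool:
    """Sketch of the post-change dispatch condition in
    FusedScaleMaskSoftmax.forward (hypothetical helper, for illustration)."""
    assert inp.dim() == 4  # expected layout: [batch, heads, query_len, key_len]
    b, heads, query_seq_len, key_seq_len = inp.size()
    attn_batch_size = b * heads
    return (
        input_in_fp16                   # kernel only handles fp16 inputs
        and fusion_enabled              # corresponds to self.scaled_masked_softmax_fusion
        and mask is not None            # fused path requires an explicit mask
        and 16 < key_seq_len <= 2048    # new lower bound on top of the existing upper bound
        and query_seq_len % 4 == 0      # existing divisibility constraint
        and attn_batch_size % 4 == 0    # new batch * heads divisibility constraint
    )
```

With that reading, an [8, 16, 1024, 1024] fp16 input with a mask still takes the fused path, while a key_seq_len of 16 or an attn_batch_size of 6 now falls back to the unfused softmax.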