Loading megatron/model/fused_softmax.py +1 −1 Original line number Diff line number Diff line Loading @@ -186,7 +186,7 @@ class FusedScaleMaskSoftmax(nn.Module): return probs @staticmethod def get_batch_per_block(b, np, sq, sk): def get_batch_per_block(sq, sk, b, np): import scaled_masked_softmax_cuda return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) Loading
megatron/model/fused_softmax.py +1 −1 Original line number Diff line number Diff line Loading @@ -186,7 +186,7 @@ class FusedScaleMaskSoftmax(nn.Module): return probs @staticmethod def get_batch_per_block(b, np, sq, sk): def get_batch_per_block(sq, sk, b, np): import scaled_masked_softmax_cuda return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)