Commit 0d350c8d authored by hyunwoongko's avatar hyunwoongko Committed by mshoeybi
Browse files

fix bugs in fused softmax

parent 116820a5
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -186,7 +186,7 @@ class FusedScaleMaskSoftmax(nn.Module):
        return probs

    @staticmethod
    def get_batch_per_block(b, np, sq, sk):
    def get_batch_per_block(sq, sk, b, np):
        import scaled_masked_softmax_cuda

        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)