Commit b0b017ff authored by hyunwoongko's avatar hyunwoongko Committed by mshoeybi
Browse files

Modify comments

parent c6b77639
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -138,8 +138,8 @@ class FusedScaleMaskSoftmax(nn.Module):
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sq must be 16 ~ 2048
            and sq % 4 == 0  # sk must be divisor of 4
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
        ):
            if 0 <= sk <= 2048: