Loading megatron/model/transformer.py +8 −4 Original line number Diff line number Diff line Loading @@ -392,14 +392,18 @@ def get_bias_dropout_add(training): @torch.jit.script def bias_dropout_add_fused_train(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) Loading Loading
megatron/model/transformer.py +8 −4 Original line number Diff line number Diff line Loading @@ -392,14 +392,18 @@ def get_bias_dropout_add(training): @torch.jit.script def bias_dropout_add_fused_train(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) Loading