megatron/arguments.py  +29 −22

@@ -183,13 +183,15 @@ def parse_args(extra_args_provider=None, defaults={},
         'for distribute-checkpointed-activations to work you '\
         'need to enable checkpoint-activations'
 
-    # load scaled_upper_triang_masked_softmax_fusion kernel
-    if args.scaled_upper_triang_masked_softmax_fusion:
-        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-
-    # load scaled_masked_softmax_fusion kernel
-    if args.scaled_masked_softmax_fusion:
-        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
+    if args.scaled_masked_softmax_fusion:
+        if args.scaled_upper_triang_masked_softmax_fusion:
+            fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+        else:
+            fused_kernels.load_scaled_masked_softmax_fusion_kernel()
+    else:
+        # This argument will eventually go away, for now make sure it is off
+        # if scaled_masked_softmax_fusion is off.
+        args.scaled_upper_triang_masked_softmax_fusion = False
 
     # Load mixed precision fused layer norm.
     if args.fp32_residual_connection:

@@ -328,18 +330,22 @@ def _add_training_args(parser):
                        help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
+    group.add_argument('--no-scaled-masked-softmax-fusion',
+                       action='store_false',
+                       help='Disable fusion of query_key_value scaling, '
+                       'masking, and softmax.',
+                       dest='scaled_masked_softmax_fusion')
     group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
-                       action='store_true',
-                       help='Enable fusion of query_key_value_scaling '
-                       'time (upper diagonal) masking and softmax.')
-    group.add_argument('--scaled-masked-softmax-fusion',
-                       action='store_true',
-                       help='Enable fusion of query_key_value_scaling '
-                       'general masking and softmax.')
-    group.add_argument('--bias-gelu-fusion', action='store_true',
-                       help='Enable bias and gelu fusion.')
-    group.add_argument('--bias-dropout-fusion', action='store_true',
-                       help='Enable bias and dropout fusion.')
+                       type=bool,
+                       help='Use upper triangular version of fused '
+                       'scale, mask, softmax fusion kernel (default for GPT). '
+                       '- DEPRECATED')
+    group.add_argument('--no-bias-gelu-fusion', action='store_false',
+                       help='Disable bias and gelu fusion.',
+                       dest='bias_gelu_fusion')
+    group.add_argument('--no-bias-dropout-fusion', action='store_false',
+                       help='Disable bias and dropout fusion.',
+                       dest='bias_dropout_fusion')
 
     return parser

@@ -447,12 +453,13 @@ def _add_mixed_precision_args(parser):
                        help='hysteresis for dynamic loss scaling')
     group.add_argument('--fp32-residual-connection', action='store_true',
                        help='Move residual connections to fp32.')
-    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
-                       help='Scale Q * K^T by 1 / layer-number. If this flag '
-                       'is set, then it will automatically set '
-                       'attention-softmax-in-fp32 to true')
+    group.add_argument('--no-query-key-layer-scaling', action='store_false',
+                       help='Do not scale Q * K^T by 1 / layer-number.',
+                       dest='apply_query_key_layer_scaling')
     group.add_argument('--attention-softmax-in-fp32', action='store_true',
-                       help='Run attention masking and softmax in fp32.')
+                       help='Run attention masking and softmax in fp32. '
+                       'This flag is ignored unless '
+                       '--no-query-key-layer-scaling is specified.')
     group.add_argument('--fp32-allreduce', action='store_true',
                        help='All-reduce in fp32')
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
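The renamed flags above follow the standard argparse pattern for negated boolean options: with action='store_false', the destination defaults to True, so bias-gelu fusion, bias-dropout fusion, the scaled-masked-softmax kernel, and query-key layer scaling are now enabled unless the user passes the corresponding --no-* flag. A minimal standalone sketch of that pattern (illustrative only, not Megatron code; the option name is taken from the diff above):

# Sketch of the store_false pattern used by the new --no-* flags.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-bias-gelu-fusion', action='store_false',
                    help='Disable bias and gelu fusion.',
                    dest='bias_gelu_fusion')

args = parser.parse_args([])
assert args.bias_gelu_fusion is True           # fusion enabled by default

args = parser.parse_args(['--no-bias-gelu-fusion'])
assert args.bias_gelu_fusion is False          # explicitly disabled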
pretrain_gpt2.py  +2 −1

@@ -141,4 +141,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":
 
     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+             args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
+                            'scaled_upper_triang_masked_softmax_fusion': True})
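With the upper-triangular flag deprecated on the command line, the GPT-2 script opts into the causal (upper triangular) kernel through args_defaults, matching the "default for GPT" note in the help string. Those per-script defaults are threaded into parse_args(extra_args_provider=None, defaults={}) in megatron/arguments.py; the sketch below shows one plausible merge rule and is an assumption, not the repository's exact logic:

# Hypothetical illustration of merging per-script defaults into parsed args.
# The real precedence rules live in megatron/arguments.py and may differ,
# e.g. in how explicit command-line values or warnings are handled.
def apply_script_defaults(args, defaults):
    for key, value in defaults.items():
        # Fill in only what the parser left unset, so values passed on the
        # command line keep priority over the script-level defaults.
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args

# e.g. apply_script_defaults(args, {'tokenizer_type': 'GPT2BPETokenizer',
#                                   'scaled_upper_triang_masked_softmax_fusion': True})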