megatron/arguments.py  +1 −0

@@ -236,6 +236,7 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
+        assert args.pipeline_model_parallel_size == 1
 
     _print_args(args)
     return args
megatron/mpu/random.py  +3 −2

@@ -289,7 +289,8 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads
 
 
-def checkpoint(function, *args):
+def checkpoint(function, distribute_checkpointed_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function,
+                                    distribute_checkpointed_activations, *args)
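Because distribute_checkpointed_activations is passed positionally through CheckpointFunction.apply, it arrives in forward as a non-tensor input, and backward must return a matching None for it; that is the (None, None) + grads context line in the hunk above. The sketch below is a minimal, self-contained illustration of that pattern, not the PR's actual implementation: in the real code the flag presumably controls whether the saved activations are partitioned across the tensor-model-parallel group, which would also explain the new pipeline_model_parallel_size == 1 assertion in arguments.py.

import torch

class CheckpointFunction(torch.autograd.Function):
    """Sketch of the pattern above: the first two arguments to apply()
    are the function to re-run and the distribute flag; only the
    trailing *args are tensor inputs."""

    @staticmethod
    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
        ctx.run_function = run_function
        # Plain bool, invisible to autograd; the real code would use it
        # to decide how the saved activations are stored (assumption).
        ctx.distribute = distribute_checkpointed_activations
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Recompute the forward pass with grad enabled, then backprop.
        inputs = [t.detach().requires_grad_(t.requires_grad)
                  for t in ctx.saved_tensors]
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)
        # One None each for run_function and the flag; this matches the
        # "return (None, None) + grads" context line in the diff.
        return (None, None) + tuple(inp.grad for inp in inputs)

def checkpoint(function, distribute_checkpointed_activations, *args):
    return CheckpointFunction.apply(
        function, distribute_checkpointed_activations, *args)

# Usage: the flag sits between the function and its inputs.
x = torch.randn(4, 4, requires_grad=True)
y = checkpoint(lambda t: t.tanh().sum(), False, x)
y.backward()

Note that placing the flag ahead of *args changes the public signature: every existing caller of checkpoint must be updated to supply it explicitly rather than picking up a default.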