Commit 7b585440 authored by mshoeybi's avatar mshoeybi
Browse files

only support pp=1

parent e923ec52
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -236,6 +236,7 @@ def parse_args(extra_args_provider=None, defaults={},
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'
+        assert args.pipeline_model_parallel_size == 1

    _print_args(args)
    return args
+3 −2
Original line number Diff line number Diff line
@@ -289,7 +289,8 @@ class CheckpointFunction(torch.autograd.Function):
        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, distribute_checkpointed_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function,
+                                    distribute_checkpointed_activations, *args)