Commit 60750922 authored by Jared Casper

Merge branch 'lmcafee/dst-chk-act' into 'main'

Fixed distributed checkpoint activations to work on all layers

See merge request ADLR/megatron-lm!365
parents 2ad00f4e 1b28a51b
+13 −11
@@ -245,17 +245,6 @@ def parse_args(extra_args_provider=None, defaults={},
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'checkpointed activations only across tensor model ' \
            'parallel groups'
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you ' \
            'need to use an activation-checkpoint method '
        assert args.num_layers_per_virtual_pipeline_stage is None, \
            'currently distributed checkpoint activations are only supported ' \
            'for non-interleaved pipeline parallelism'

    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
@@ -267,6 +256,19 @@ def parse_args(extra_args_provider=None, defaults={},
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')

    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'checkpointed activations only across tensor model ' \
            'parallel groups'
        assert args.activations_checkpoint_method is not None, \
            'for distributed checkpoint activations to work you '\
            'need to use an activation-checkpoint method '
        assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
            'distributed checkpoint activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

    _print_args(args)
    return args
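
For context, here is a minimal sketch (not part of this merge request) of the argument combinations the relocated checks accept: the interleaved-pipeline restriction from the removed block is dropped, and a PyTorch >= 1.10 requirement is added in its place. The passes_checks helper and the example namespaces below are hypothetical, written only to illustrate the assertions above.

# Hypothetical illustration of the relocated checks; not part of the commit.
from types import SimpleNamespace

def passes_checks(args, torch_major=1, torch_minor=12):
    # Mirrors the three assertions above for --distribute-checkpointed-activations.
    if not args.distribute_checkpointed_activations:
        return True
    return (args.tensor_model_parallel_size > 1
            and args.activations_checkpoint_method is not None
            and torch_major >= 1 and torch_minor >= 10)

# Tensor model parallel size 2 with the 'uniform' checkpoint method passes:
print(passes_checks(SimpleNamespace(distribute_checkpointed_activations=True,
                                    tensor_model_parallel_size=2,
                                    activations_checkpoint_method='uniform')))  # True
# Tensor model parallel size 1 would trip the first assertion:
print(passes_checks(SimpleNamespace(distribute_checkpointed_activations=True,
                                    tensor_model_parallel_size=1,
                                    activations_checkpoint_method='uniform')))  # False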

+2 −19
@@ -626,23 +626,6 @@ class ParallelTransformer(MegatronModule):
                return x_
            return custom_forward

        def distribute_checkpointed_activations_helper(layer_number):
            """Distribute checkpointed activations across the tensor model
               Parallel ranks if the `distribute-checkpointed-activations
               is on and either of the following conditions is met:
                 - it is not the first layer in the in the pipeline stage.
                   The first layer is used in the pipeline parallelism 
                   and changing its shape throws error in the backward pass.
                 - we are at the first pipline stage so the input tensor is
                   not used in pipeline parallelism. Note that no pipeline
                   parallelism is a special case of this.
            """
            not_first_layer_in_pipeline_stage = (layer_number > 0)
            is_first_pipeline_stage = (
                mpu.get_pipeline_model_parallel_rank() == 0)
            return self.distribute_checkpointed_activations and \
                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
@@ -651,7 +634,7 @@ class ParallelTransformer(MegatronModule):
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    distribute_checkpointed_activations_helper(l),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
@@ -662,7 +645,7 @@ class ParallelTransformer(MegatronModule):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        distribute_checkpointed_activations_helper(l),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
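
For context, a minimal sketch (not part of the repository) of what this change does: the removed helper skipped distribution for layer 0 of any non-first pipeline stage, while the new code passes the flag through for every checkpointed layer, which is the "all layers" behaviour in the commit title. The old_gate/new_gate functions below are hypothetical names used only for illustration.

# Hypothetical illustration of the gating change; not part of the repository.
def old_gate(distribute_flag, layer_number, pipeline_rank):
    # Removed helper: distribute only if this is not the first layer of the
    # stage, or if we are on the first pipeline stage.
    return distribute_flag and (layer_number > 0 or pipeline_rank == 0)

def new_gate(distribute_flag, layer_number, pipeline_rank):
    # After this merge the flag is passed through unchanged for every layer.
    return distribute_flag

for layer in range(3):
    print(layer,
          old_gate(True, layer, pipeline_rank=1),   # layer 0: False, others: True
          new_gate(True, layer, pipeline_rank=1))   # True for all layers
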
+0 −1
@@ -24,7 +24,6 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from megatron import get_args
from megatron.memory import allocate_mem_buff

from .initialize import get_data_parallel_rank