Commit 136d63cb authored by Jared Casper

Merge branch 'dist_chkpt_act' into 'main'

Revisited distributing checkpointed activations along the tensor parallel ranks

See merge request ADLR/megatron-lm!311
parents 0be40526 cb5e611d
megatron/arguments.py  +7 −1
@@ -240,9 +240,15 @@ def parse_args(extra_args_provider=None, defaults={},
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you '\
-            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
+            'need to use an activation-checkpoint method '
+        assert args.num_layers_per_virtual_pipeline_stage is None, \
+            'currently distributed checkpointed activations are only supported ' \
+            'for non-interleaved pipeline parallelism'

    _print_args(args)
    return args
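To make the constraint set concrete, here is a minimal standalone sketch of the same validation; the function name and the SimpleNamespace harness are illustrative, not part of the repository:

from types import SimpleNamespace

def validate_distribute_checkpointed_activations(args):
    """Sketch of the checks above: distributing checkpointed activations
    requires tensor parallelism, an activation-checkpoint method, and no
    interleaved (virtual) pipeline schedule."""
    if not args.distribute_checkpointed_activations:
        return
    assert args.tensor_model_parallel_size > 1, \
        'can distribute checkpointed activations only across ' \
        'tensor model parallel groups'
    assert args.activations_checkpoint_method in ('uniform', 'block'), \
        'distribute-checkpointed-activations requires an ' \
        'activation-checkpoint method'
    assert args.num_layers_per_virtual_pipeline_stage is None, \
        'not supported with the interleaved pipeline schedule'

# Example: a valid configuration passes silently.
validate_distribute_checkpointed_activations(SimpleNamespace(
    distribute_checkpointed_activations=True,
    tensor_model_parallel_size=2,
    activations_checkpoint_method='uniform',
    num_layers_per_virtual_pipeline_stage=None))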
megatron/initialize.py  +0 −11
@@ -77,9 +77,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

-        # Initialize memory buffers.
-        _initialize_mem_buffs()
-
        # Autoresume.
        _init_autoresume()

@@ -224,11 +221,3 @@ def write_args_to_tensorboard():
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)

-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
-
-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
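For context on what is deleted here: _initialize_mem_buffs pre-allocated one static buffer sized to hold every checkpointed activation shard on a pipeline stage (see the removed init_checkpointed_activations_memory_buffer further down). A rough sketch of that size computation; the standalone function and the example figures are illustrative:

def checkpointed_activations_numel(micro_batch_size, max_position_embeddings,
                                   hidden_size, tensor_parallel_size,
                                   layers_per_stage, checkpoint_num_layers,
                                   method='uniform'):
    # One hidden-state shard per checkpointed chunk ('uniform') or per
    # checkpointed layer ('block'), divided across tensor parallel ranks.
    per_layer = (micro_batch_size * max_position_embeddings *
                 hidden_size // tensor_parallel_size)
    if method == 'uniform':
        num_checkpoints = layers_per_stage // checkpoint_num_layers
    else:
        num_checkpoints = checkpoint_num_layers
    return per_layer * num_checkpoints

# Example: micro-batch 4, sequence 2048, hidden 1024, tensor parallel 2,
# 24 layers per stage checkpointed one at a time:
# 100,663,296 elements, about 192 MiB in fp16.
print(checkpointed_activations_numel(4, 2048, 1024, 2, 24, 1))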
megatron/model/transformer.py  +19 −2
@@ -544,6 +544,7 @@ class ParallelTransformer(MegatronModule):
        # Store activation checkpointing flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
+        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
@@ -607,8 +608,22 @@ class ParallelTransformer(MegatronModule):
                return x_
            return custom_forward

-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
+        def distribute_checkpointed_activations_helper(layer_number):
+            """Distribute checkpointed activations across the tensor model
+               parallel ranks if `distribute-checkpointed-activations`
+               is on and either of the following conditions is met:
+                 - it is not the first layer in the pipeline stage.
+                   The first layer's input is used in pipeline parallelism,
+                   and changing its shape throws an error in the backward pass.
+                 - we are at the first pipeline stage, so the input tensor is
+                   not used in pipeline parallelism. Note that no pipeline
+                   parallelism is a special case of this.
+            """
+            not_first_layer_in_pipeline_stage = (layer_number > 0)
+            is_first_pipeline_stage = (
+                mpu.get_pipeline_model_parallel_rank() == 0)
+            return self.distribute_checkpointed_activations and \
+                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
@@ -618,6 +633,7 @@ class ParallelTransformer(MegatronModule):
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
+                    distribute_checkpointed_activations_helper(l),
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
@@ -628,6 +644,7 @@ class ParallelTransformer(MegatronModule):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
+                        distribute_checkpointed_activations_helper(l),
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
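A toy illustration of the helper's predicate, with the pipeline rank passed in explicitly rather than read from mpu; names and harness are illustrative only:

def should_distribute(layer_number, pipeline_rank, distribute=True):
    # Mirrors distribute_checkpointed_activations_helper: skip only the
    # first layer of a non-first pipeline stage, whose input tensor also
    # carries pipeline communication and must keep its shape.
    not_first_layer = layer_number > 0
    first_stage = (pipeline_rank == 0)
    return distribute and (not_first_layer or first_stage)

# First pipeline stage (or no pipeline parallelism): all layers qualify.
assert [should_distribute(l, 0) for l in range(3)] == [True, True, True]
# Later stages: layer 0's input comes from the previous stage; keep it whole.
assert [should_distribute(l, 1) for l in range(3)] == [False, True, True]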
megatron/mpu/__init__.py  +0 −2
@@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region

from .random import checkpoint
from .random import get_cuda_rng_tracker
-from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
-from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks

megatron/mpu/random.py  +23 −55
@@ -37,46 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


-# Whether to apply model parallelism to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-
-
-def init_checkpointed_activations_memory_buffer():
-    """Initialize the memory buffer for the checkpointed activations."""
-    args = get_args()
-
-    per_layer = args.micro_batch_size * args.max_position_embeddings * \
-                args.hidden_size // args.tensor_model_parallel_size
-    num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
-    if args.virtual_pipeline_model_parallel_size is not None:
-        num_layers = num_layers // args.virtual_pipeline_model_parallel_size
-
-    if args.activations_checkpoint_method == 'uniform':
-        assert num_layers % args.activations_checkpoint_num_layers == 0, \
-            'total number of layers is not divisible by checkpoint chunk size'
-        num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers
-    elif args.activations_checkpoint_method == 'block':
-        assert args.activations_checkpoint_num_layers <= num_layers, \
-            'total number of layers is fewer than the number of layers to checkpoint'
-        num_checkpointer_layers = args.activations_checkpoint_num_layers
-    numel = per_layer * num_checkpointer_layers
-    dtype = torch.half
-    if not args.fp16:
-        dtype = torch.float
-
-    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-        'checkpointed activations memory buffer is already allocated.'
-    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-        'checkpointed activations', numel, dtype, track_usage=False)
-
-
-def reset_checkpointed_activations_memory_buffer():
-    """Reset the memory used for checkpointing."""
-    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

@@ -110,13 +70,20 @@ def _set_cuda_rng_state(new_state, device=-1):
    _lazy_call(cb)


-def split_tensor_into_1d_equal_chunks(tensor):
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
    start_index = partition_size * get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
-    return data[start_index:end_index]
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data

def gather_split_1d_tensor(tensor):
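The point of the new `new_buffer` flag: with `new_buffer=True` the local chunk is copied into freshly allocated storage, so the full activation it was sliced from can be freed, while the default still returns a view into the original tensor. A CPU-only round-trip sketch that emulates rank and world size without torch.distributed; split_1d and gather_1d are illustrative stand-ins for the mpu functions:

import torch

def split_1d(tensor, rank, world_size, new_buffer=False):
    # Keep only this rank's equal 1D chunk of the flattened tensor.
    partition_size = tensor.numel() // world_size
    start = partition_size * rank
    chunk = tensor.view(-1)[start:start + partition_size]
    if new_buffer:
        return chunk.clone()  # Fresh storage; the full tensor can be freed.
    return chunk              # A view aliasing the original storage.

def gather_1d(chunks):
    # Stand-in for gather_split_1d_tensor's all-gather across ranks.
    return torch.cat(chunks)

x = torch.arange(8.)
parts = [split_1d(x, r, 4, new_buffer=True) for r in range(4)]
assert torch.equal(gather_1d(parts), x)  # Round trip recovers the input.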
@@ -259,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function):
              tracked/set/reset.
    """
    @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
        ctx.run_function = run_function
+        ctx.distribute_checkpointed_activations \
+            = distribute_checkpointed_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -272,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function):

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if distribute_checkpointed_activations:
            ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
-                args[0].data)
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
+                                                             new_buffer=True)

        # Store everything.
        ctx.save_for_backward(*args)


        return outputs

    @staticmethod
@@ -290,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function):
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if ctx.distribute_checkpointed_activations:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

@@ -319,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function):
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, distribute_checkpointed_activations, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function,
+                                    distribute_checkpointed_activations, *args)
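A compressed toy of this pattern (no RNG bookkeeping and no activation distribution), showing why backward now returns two leading Nones: one per non-tensor argument to forward. TinyCheckpoint is an illustrative miniature, not the repository's CheckpointFunction:

import torch

class TinyCheckpoint(torch.autograd.Function):
    """Forward takes two non-tensor arguments (the function and the
    distribute flag), so backward returns two leading Nones before the
    input gradients."""
    @staticmethod
    def forward(ctx, run_function, distribute_flag, *args):
        ctx.run_function = run_function
        ctx.save_for_backward(*args)
        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Recompute the forward pass with gradients enabled.
        inputs = [t.detach().requires_grad_(True) for t in ctx.saved_tensors]
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        torch.autograd.backward((outputs,), grad_outputs)
        return (None, None) + tuple(t.grad for t in inputs)

x = torch.ones(4, requires_grad=True)
# Old call shape: apply(function, *args); new: the flag precedes the inputs.
y = TinyCheckpoint.apply(lambda t: (3 * t).sum(), False, x)
y.backward()
print(x.grad)  # tensor([3., 3., 3., 3.])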