Commit 21a11ca0 authored by Lawrence McAfee
Merge branch 'main' of ssh://gitlab-master.nvidia.com:12051/ADLR/megatron-lm into lmcafee/copygrad-fix-v2
parents 6f2bff5c 6a680986
+2 −0
@@ -300,6 +300,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_<model>.py \
                --DDP-impl torch
</pre>

The interleaved pipelining schedule (more details in Section 2.2.2 of [our paper](https://arxiv.org/pdf/2104.04473.pdf)) can be enabled using the `--num-layers-per-virtual-pipeline-stage` argument, which controls the number of transformer layers in a virtual stage (by default with the non-interleaved schedule, each GPU will execute a single virtual stage with `NUM_LAYERS / PIPELINE_MP_SIZE` transformer layers). The total number of layers in the transformer model should be divisible by this argument value. Additionally, the number of microbatches in the pipeline (computed as `GLOBAL_BATCH_SIZE / (DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE)`) should be divisible by the `PIPELINE_MP_SIZE` when using this schedule (this condition is checked in an assertion in the code). The interleaved schedule is not supported for pipelines with 2 stages (`PIPELINE_MP_SIZE=2`).
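For example, with hypothetical sizes (not taken from any provided script), the microbatch count and the divisibility check work out as follows:

<pre>
# Hypothetical configuration, used only to illustrate the divisibility requirement.
GLOBAL_BATCH_SIZE = 1536
DATA_PARALLEL_SIZE = 24
MICRO_BATCH_SIZE = 2
PIPELINE_MP_SIZE = 8

# Number of microbatches in the pipeline, computed as described above.
num_microbatches = GLOBAL_BATCH_SIZE // (DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE)  # 32

# The interleaved schedule requires this to be divisible by the pipeline size.
assert num_microbatches % PIPELINE_MP_SIZE == 0
</pre>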

## GPT-3 Example

In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to configure Megatron to run [GPT-3](https://arxiv.org/abs/2005.14165) with 175 billion parameters on 1024 GPUs. The script is designed for [slurm](https://slurm.schedmd.com/documentation.html) with the [pyxis](https://github.com/NVIDIA/pyxis) plugin but can be easily adapted to any other scheduler. It uses 8-way tensor parallelism and 16-way pipeline parallelism. With the options `global-batch-size 1536` and `rampup-batch-size 16 16 5859375`, training starts with a global batch size of 16 and linearly increases it to 1536 over 5,859,375 samples in increments of 16. The training dataset can be either a single dataset or multiple datasets combined with a set of weights.
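The ramp-up arithmetic can be sketched as follows; this is only an illustration of the schedule implied by the arguments above, not the exact implementation of Megatron's microbatch calculator:

<pre>
# Illustrative sketch of the linear batch-size ramp-up (--rampup-batch-size 16 16 5859375).
START_BATCH_SIZE = 16         # starting global batch size
INCREMENT = 16                # batch-size increment
RAMPUP_SAMPLES = 5859375      # samples over which to ramp up
GLOBAL_BATCH_SIZE = 1536      # final global batch size

num_increments = (GLOBAL_BATCH_SIZE - START_BATCH_SIZE) // INCREMENT   # 95
samples_per_increment = RAMPUP_SAMPLES / num_increments                # ~61,678

def current_global_batch_size(consumed_samples):
    """Global batch size after a given number of consumed samples."""
    steps = int(consumed_samples / samples_per_increment)
    return min(START_BATCH_SIZE + steps * INCREMENT, GLOBAL_BATCH_SIZE)
</pre>
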
+3 −0
@@ -328,6 +328,9 @@ def _add_logging_args(parser):
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')

    return parser

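For context, a minimal sketch of how a flag like `--log-memory-to-tensorboard` might be consumed by the training loop; the writer setup and the helper below are illustrative assumptions, not code added by this commit:

<pre>
import torch
from torch.utils.tensorboard import SummaryWriter

def maybe_log_memory(writer: SummaryWriter, iteration: int, enabled: bool):
    # Illustrative helper: write current and peak CUDA allocator statistics to
    # tensorboard when the flag is set. Actual call sites may differ.
    if not enabled:
        return
    writer.add_scalar('memory/allocated-bytes',
                      torch.cuda.memory_allocated(), iteration)
    writer.add_scalar('memory/max-allocated-bytes',
                      torch.cuda.max_memory_allocated(), iteration)
</pre>
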
+6 −1
@@ -286,9 +286,14 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
                float16_groups & fp32_from_fp32_groups."""
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

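For context, a sketch of what `_zero_grad_group_helper` typically does for one parameter group, based on how it is called above; the actual helper may differ in detail:

<pre>
def _zero_grad_group_helper(group, set_to_none):
    """Zero the gradients of a group of parameters."""
    for param in group:
        if param.grad is None:
            continue
        if set_to_none:
            # Dropping the reference lets the allocator reclaim the space,
            # which is the fragmentation benefit described in the docstring.
            param.grad = None
        else:
            param.grad.detach_()
            param.grad.zero_()
</pre>
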
+44 −12
@@ -22,7 +22,9 @@ from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
                 use_ring_exchange=False, tensor_shape=None,
                 override_scatter_gather_tensors_in_pipeline=False,
                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.

        tensor_shape: optional, used when the input sequence contains fewer
                      tokens than the default sequence length.
        override_scatter_gather_tensors_in_pipeline: optional, used when
                                                     tensor_shape is provided,
                                                     to override the
                                                     scatter-gather of tensors
                                                     in the pipeline.
        dtype_: optional, used when tensor_shape is provided; specifies the
                dtype of the tensors to be received.
    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

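As an illustration of the scatter-gather chunking above, with hypothetical sizes (not taken from the code):

<pre>
import operator
from functools import reduce

# Hypothetical activation shape: (seq_length, micro_batch_size, hidden_size).
tensor_shape = (2048, 4, 12288)
tp_world_size = 8  # hypothetical tensor-model-parallel world size

# With scatter-gather enabled, each rank communicates a flat 1-D chunk that is
# 1/tp_world_size of the full activation; otherwise the full shape is used.
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // tp_world_size
assert tensor_chunk_shape * tp_world_size == reduce(operator.mul, tensor_shape, 1)
</pre>
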
@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    return tensor_recv_prev, tensor_recv_next


def recv_forward(timers=None):
def recv_forward(tensor_shape=None,
                 override_scatter_gather_tensors_in_pipeline=False,
                 dtype_=None, timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""

    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
@@ -135,7 +157,11 @@ def recv_forward(timers=None):
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False)
            recv_next=False,
            tensor_shape=tensor_shape,
            override_scatter_gather_tensors_in_pipeline=\
                override_scatter_gather_tensors_in_pipeline,
            dtype_=dtype_)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor
@@ -158,8 +184,11 @@ def recv_backward(timers=None):
    return output_tensor_grad


def send_forward(output_tensor, timers=None):
def send_forward(output_tensor, timers=None,
                 override_scatter_gather_tensors_in_pipeline=False,
                 dtype_=None):
    """Send tensor to next rank in pipeline (forward send)."""

    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False)
            recv_next=False,
            override_scatter_gather_tensors_in_pipeline=\
            override_scatter_gather_tensors_in_pipeline,
            dtype_=dtype_)
        if timers is not None:
            timers('forward-send').stop()

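Taken together, the new keyword arguments might be exercised as in the sketch below; the shapes and dtype are illustrative assumptions, and the calls assume an initialized pipeline-parallel run on a non-boundary stage:

<pre>
import torch
from megatron import p2p_communication

# Illustrative only: receive a tensor with a non-default shape and integer dtype,
# bypassing the scatter-gather optimization (the result will not require grad).
custom_shape = (128, 4)  # hypothetical (sequence_length, micro_batch_size)
recv_tensor = p2p_communication.recv_forward(
    tensor_shape=custom_shape,
    override_scatter_gather_tensors_in_pipeline=True,
    dtype_=torch.int64)

# Send to the next stage with matching options.
p2p_communication.send_forward(
    recv_tensor,
    override_scatter_gather_tensors_in_pipeline=True,
    dtype_=torch.int64)
</pre>
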
+17 −14
@@ -31,6 +31,9 @@ def get_forward_backward_func():
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
@@ -228,7 +231,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers))
        p2p_communication.recv_forward(timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

@@ -262,7 +265,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
                    output_tensor, recv_prev=recv_prev, timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
@@ -340,7 +343,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(timers))
                p2p_communication.recv_backward(timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -352,7 +355,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))
                    input_tensor_grad, recv_next=recv_next, timers=timers))

    return losses_reduced

@@ -386,10 +389,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        input_tensor = p2p_communication.recv_forward(timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)
        p2p_communication.send_forward(output_tensor, timers=timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
@@ -398,7 +401,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)
        input_tensor = p2p_communication.recv_forward(timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
@@ -407,11 +410,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
            p2p_communication.send_forward(output_tensor, timers=timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)
                                                             timers=timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
@@ -420,7 +423,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
                input_tensor = p2p_communication.recv_forward(timers=timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

@@ -430,11 +433,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
                p2p_communication.send_backward(input_tensor_grad, timers=timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)
                        input_tensor_grad, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
@@ -442,12 +445,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)
            output_tensor_grad = p2p_communication.recv_backward(timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)
            p2p_communication.send_backward(input_tensor_grad, timers=timers)

    return losses_reduced