Commit fc5d4c2b authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'lmcafee/t5-free-pipe-output-true-fix' into 'main'

T5 fix for deallocating pipeline output tensors

See merge request ADLR/megatron-lm!380
parents 539e39b6 7f1c5917
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -684,9 +684,6 @@ def _add_distributed_args(parser):
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation.'
                       '0=off, 1=moderate, 2=aggressive.')
    group.add_argument('--deallocate-pipeline-outputs', action='store_true',
                       default=False, help='If set, pipeline output tensors '
                       'are deallocated during the forward pass.')
    return parser


+19 −17
Original line number Diff line number Diff line
@@ -42,24 +42,29 @@ def get_forward_backward_func():
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func

def free_output_tensor(output_tensors, deallocate_pipeline_outputs):
    '''Pseudo-free (i.e., set to scalar) the output tensor's '.data' field.
def deallocate_output_tensor(out):
    '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    '''
    if not deallocate_pipeline_outputs or output_tensors is None:
    if out is None:
        return
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]
    for output_tensor in output_tensors:
        output_tensor.data = torch.cuda.FloatTensor([0])
    assert isinstance(out, torch.Tensor), \
        "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, \
        "counter-productive to free a view of another tensor."
    out.data = torch.empty(
        (1,),
        device = out.device,
        dtype = out.dtype,
    )
        
def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.

    To make the 'free_output_tensor' (above) optimization work, the C++
    To make the 'deallocate_output_tensor' (above) optimization work, the C++
    autograd engine must be called directly, bypassing Pytorch's
    torch.autograd.backward. Pytorch's 'backward' checks that the output and
    grad have the same shape, while C++'s 'backward' does not.
@@ -91,6 +96,7 @@ def custom_backward(output, grad_output):
        accumulate_grad=True,
    )
        

def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

@@ -164,11 +170,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    # Backward pass.
    if output_tensor_grad[0] is None:
        output_tensor = optimizer.scale_loss(output_tensor[0])
    if args.deallocate_pipeline_outputs:
    custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0],
                                grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
@@ -376,8 +378,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    timers=timers)
        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        deallocate_output_tensor(output_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
@@ -441,7 +443,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers)
        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
        deallocate_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
@@ -617,7 +619,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
            deallocate_output_tensor(output_tensor[0])

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -646,7 +648,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
            deallocate_output_tensor(output_tensor[0])

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.