Commit 057b086c authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'lmcafee/free-pipe-output' into 'main'

Deallocate pipeline stage output tensors after sending

See merge request ADLR/megatron-lm!370
parents 5833553a 85876092
Loading
Loading
Loading
Loading
+54 −2
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@

from contextlib import contextmanager
import torch
from torch.autograd.variable import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
@@ -27,7 +28,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType


def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
@@ -42,6 +42,54 @@ def get_forward_backward_func():
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func

def free_output_tensor(output_tensors):
    '''Pseudo-free (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    '''
    if output_tensors is None:
        return
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]
    for output_tensor in output_tensors:
        output_tensor.data = torch.cuda.FloatTensor([0])
        
def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.

    To make the 'free_output_tensor' (above) optimization work, the C++
    autograd engine must be called directly, bypassing Pytorch's
    torch.autograd.backward. Pytorch's 'backward' checks that the output and
    grad have the same shape, while C++'s 'backward' does not.
    '''

    assert output.numel() == 1, \
        "output should be pseudo-'freed' in schedule, to optimize memory"
    assert isinstance(output, torch.Tensor), \
        "output == '%s'." % type(output).__name__
    assert isinstance(grad_output, (torch.Tensor, type(None))), \
        "grad_output == '%s'." % type(grad_output).__name__

    # Handle scalar output
    if grad_output is None:
        assert output.numel() == 1, "implicit grad requires scalar output."
        grad_output = torch.ones_like(
            output,
            memory_format = torch.preserve_format,
        )

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors = (output,),
        grad_tensors = (grad_output,),
        keep_graph = False,
        create_graph = False,
        inputs = tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )

def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.
@@ -116,7 +164,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    # Backward pass.
    if output_tensor_grad[0] is None:
        output_tensor = optimizer.scale_loss(output_tensor[0])
    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
    custom_backward(output_tensor[0], output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
@@ -324,6 +372,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    timers=timers)
        free_output_tensor(output_tensor)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
@@ -388,6 +437,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers)
        free_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
@@ -562,6 +612,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -590,6 +641,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.