Commit fd5469aa authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'lmcafee/distrib-chkpt-fix-v2' into 'main'

Distributed checkpointing memory fix

See merge request ADLR/megatron-lm!379
parents fc5d4c2b d16e2a24
Loading
Loading
Loading
Loading
+24 −2
Original line number Diff line number Diff line
@@ -27,7 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu


""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
@@ -696,9 +695,32 @@ class ParallelTransformer(MegatronModule):
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        # 
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad = True,
            keep_graph = True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        # Forward pass.
        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
+3 −0
Original line number Diff line number Diff line
@@ -66,6 +66,9 @@ from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .random import make_viewless_tensor
from .random import assert_viewless_tensor
from .random import safely_set_viewless_tensor_data

from .utils import divide
from .utils import split_tensor_along_last_dim
+84 −4
Original line number Diff line number Diff line
@@ -98,6 +98,84 @@ def gather_split_1d_tensor(tensor):
                                 group=get_tensor_model_parallel_group())
    return gathered

def _kernel_make_viewless_tensor(inp, requires_grad):
    '''Make a viewless tensor.

    View tensors have the undesirable side-affect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without linking the viewed tensor, referenced via the '._base'
    field.
    '''
    out = torch.empty(
        (1,),
        dtype = inp.dtype,
        device = inp.device,
        requires_grad = requires_grad,
    )
    out.data = inp.data
    return out

class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd function to make a viewless tensor.

    This function should be used in cases where the computation graph needs
    to be propagated, but we only want a viewless tensor (e.g.,
    ParallelTransformer's hidden_states). Call this function by passing
    'keep_graph = True' to 'make_viewless_tensor()'.
    '''
    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _kernel_make_viewless_tensor(inp, requires_grad)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entry-point for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.
    '''

    # return tensor as-is, if not a 'view'
    if inp._base is None:
        return inp

    # create viewless tensor
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    else:
        return _kernel_make_viewless_tensor(inp, requires_grad)

def assert_viewless_tensor(tensor, extra_msg = None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).'''
    if isinstance(tensor, list):
        [ assert_viewless_tensor(t) for t in tensor ]
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). %s"
    ) % extra_msg
    return tensor

def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception.
    '''
    assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
    tensor.data = new_data_tensor


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
@@ -242,8 +320,9 @@ class CheckpointFunction(torch.autograd.Function):
        # the chunk corresponding to the current rank.
        if distribute_checkpointed_activations:
            ctx.input_0_shape = args[0].data.shape
            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
                                                             new_buffer=True)
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))

        # Store everything.
        ctx.save_for_backward(*args)
@@ -257,8 +336,9 @@ class CheckpointFunction(torch.autograd.Function):
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_checkpointed_activations:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
+6 −0
Original line number Diff line number Diff line
@@ -142,10 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
                                                        requires_grad = True,
                                                        keep_graph = False)

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()
            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
                                                        requires_grad = True,
                                                        keep_graph = False)

    return tensor_recv_prev, tensor_recv_next

+1 −0
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType


def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1: