Commit 5422d23a authored by Lawrence McAfee
Browse files

debugging make_standalone_tensor(), safely_set_tensor_data_attr()

parent 9a8b89ac
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -696,6 +696,21 @@ class ParallelTransformer(MegatronModule):
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # >>>
        def make_standalone_tensor(a):
            """Return a fresh tensor whose ``.data`` is ``a.data``.

            ``a`` must be a view (``a._base`` is set); anything else trips
            the assertion. The result is allocated independently as a
            one-element placeholder with ``a``'s dtype and device, and its
            ``.data`` is then pointed at ``a.data``.

            NOTE(review): presumably the goal is a tensor that does not
            reference the viewed base tensor -- confirm against callers.
            """
            assert a._base is not None
            # Placeholder allocation; its contents are never read because
            # ``.data`` is replaced on the next line.
            standalone = torch.empty((1,), dtype=a.dtype, device=a.device)
            standalone.data = a.data
            return standalone
        # <<<

        # hidden_states = make_standalone_tensor(hidden_states)
        hidden_states = hidden_states.clone()
        # >>>
        # from lutil import pax
        # pax({"hidden_states": hidden_states})
        # <<<

        if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

+16 −2
Original line number Diff line number Diff line
@@ -98,6 +98,13 @@ def gather_split_1d_tensor(tensor):
                                 group=get_tensor_model_parallel_group())
    return gathered

def safely_set_tensor_data_attr(tensor, new_data_tensor):
    """Assign ``new_data_tensor`` to ``tensor.data`` after a view check.

    The assignment is refused when ``tensor._base`` is set. Per the guard
    message below, replacing ``.data`` on such a tensor leaks memory, and
    the leak is likely to accumulate over iterations.

    Args:
        tensor: Object whose ``.data`` attribute is overwritten. Its
            ``._base`` attribute must be ``None``.
        new_data_tensor: Value assigned to ``tensor.data``. Its ``.shape``
            is only read to build the error message.

    Raises:
        AssertionError: If ``tensor._base`` is not ``None``. The exception
            is raised explicitly rather than through an ``assert``
            statement so that the guard still runs under ``python -O``.
    """
    if tensor._base is not None:
        raise AssertionError(
            (
                "Ensure tensor._base is None before setting tensor.data. Otherwise, "
                "a memory leak will occur (and likely accumulate over iterations). "
                "FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
            ) % (tensor._base.shape, new_data_tensor.shape)
        )
    tensor.data = new_data_tensor

class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
@@ -241,9 +248,16 @@ class CheckpointFunction(torch.autograd.Function):
        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_checkpointed_activations:
            # >>>
            # from lutil import data_leak_ctx
            # with data_leak_ctx(args[0]):
            # <<<
            ctx.input_0_shape = args[0].data.shape
            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
                                                             new_buffer=True)
            # args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
            #                                                  new_buffer=True)
            safely_set_tensor_data_attr(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))

        # Store everything.
        ctx.save_for_backward(*args)