Commit ed0c8714 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

more iterating on 'viewless tensor' methods

parent 5422d23a
Loading
Loading
Loading
Loading
+8 −10
Original line number Diff line number Diff line
@@ -27,6 +27,9 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<

""" We use the following notation throughout this file:
     h: hidden size
@@ -696,19 +699,14 @@ class ParallelTransformer(MegatronModule):
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # >>>
        def make_standalone_tensor(a):
            assert a._base is not None
            b = torch.empty((1,), dtype = a.dtype, device = a.device)
            b.data = a.data
            return b
        # <<<

        # hidden_states = make_standalone_tensor(hidden_states)
        hidden_states = hidden_states.clone()
        # hidden_states = MakeStandaloneTensor.apply(hidden_states)
        # hidden_states = MakeViewlessTensor.apply(hidden_states)
        hidden_states = make_viewless_tensor(hidden_states)
        # hidden_states = hidden_states.clone()
        # >>>
        # from lutil import pax
        # pax({"hidden_states": hidden_states})
        # pax(0, {"hidden_states": hidden_states})
        # <<<

        if encoder_output is not None:
+57 −6
Original line number Diff line number Diff line
@@ -98,13 +98,54 @@ def gather_split_1d_tensor(tensor):
                                 group=get_tensor_model_parallel_group())
    return gathered

def safely_set_tensor_data_attr(tensor, new_data_tensor):
# >>>
# from lutil import pax

# def make_standalone_tensor(a):
#     assert a._base is not None
#     b = torch.empty((1,), dtype = a.dtype, device = a.device)
#     b.data = a.data
#     return b
# class MakeStandaloneTensor(torch.autograd.Function):
class MakeViewlessTensor_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        assert inp._base is not None
        out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
        out.data = inp.data
        # pax(0, {"inp": inp, "out": out})
        return out
    @staticmethod
    def backward(ctx, grad_output):
        # pax(0, {"grad_output": grad_output})
        return grad_output

def make_viewless_tensor(tensor):
    if tensor._base is None:
        return tensor
    else:
        return MakeViewlessTensor_.apply(tensor)

def assert_viewless_tensor(tensor):
    if isinstance(tensor, list):
        [ assert_viewless_tensor(t) for t in tensor ]
        return
    # assert isinstance(tensor, torch.Tensor), \
    #     "expected Tensor; found %s." % type(tensor).__name__
    if not isinstance(tensor, torch.Tensor):
        return
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data. Otherwise, "
        "a memory leak will occur (and likely accumulate over iterations). "
        "FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). FYI, tensor._base has shape "
        "%s, and new_data_tensor has shape %s."
    ) % (tensor._base.shape, new_data_tensor.shape)

# def set_viewless_tensor_data_attr(tensor, new_data_tensor):
def safely_set_tensor_data_attr(tensor, new_data_tensor):
    assert_viewless_tensor(tensor)
    tensor.data = new_data_tensor
# <<<

class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
@@ -253,11 +294,13 @@ class CheckpointFunction(torch.autograd.Function):
            # with data_leak_ctx(args[0]):
            # <<<
            ctx.input_0_shape = args[0].data.shape
            # >>>
            # args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
            #                                                  new_buffer=True)
            safely_set_tensor_data_attr(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
            # <<<

        # Store everything.
        ctx.save_for_backward(*args)
@@ -271,8 +314,16 @@ class CheckpointFunction(torch.autograd.Function):
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_checkpointed_activations:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
            # >>>
            # inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            # inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
            safely_set_tensor_data_attr(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data))
            safely_set_tensor_data_attr(
                inputs[0],
                inputs[0].data.view(ctx.input_0_shape))
            # <<<

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
+5 −0
Original line number Diff line number Diff line
@@ -20,6 +20,9 @@ import torch
from megatron import get_args
from megatron import mpu

# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<

def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 tensor_shape,
@@ -142,10 +145,12 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()
            tensor_recv_next = make_viewless_tensor(tensor_recv_next)

    return tensor_recv_prev, tensor_recv_next

+36 −0
Original line number Diff line number Diff line
@@ -28,6 +28,10 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType

# >>>
from megatron.mpu.random import assert_viewless_tensor
# <<<

def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
@@ -306,6 +310,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)
        assert_viewless_tensor(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
@@ -339,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    assert_viewless_tensor(input_tensors[0][-1])
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

@@ -370,6 +376,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                        tensor_shape=tensor_shape,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
            assert_viewless_tensor(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
@@ -378,6 +385,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    timers=timers)
        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        assert_viewless_tensor(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
@@ -447,15 +455,18 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
            assert_viewless_tensor(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)
            assert_viewless_tensor(output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
            assert_viewless_tensor(output_tensor_grads[num_model_chunks-1][-1])
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -470,6 +481,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers))
            assert_viewless_tensor(output_tensor_grads[next_backward_model_chunk_id][-1])

    return losses_reduced

@@ -508,6 +520,7 @@ def recv_forward(tensor_shapes, timers):
        else:
            input_tensors.append(p2p_communication.recv_forward(tensor_shape,
                                                                timers=timers))
            assert_viewless_tensor(input_tensors[-1])
    return input_tensors


@@ -519,6 +532,7 @@ def recv_backward(tensor_shapes, timers):
        else:
            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
                                                                       timers=timers))
            assert_viewless_tensor(output_tensor_grads[-1])
    return output_tensor_grads


@@ -551,6 +565,7 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
        output_tensor_grad = p2p_communication.send_forward_recv_backward(
                output_tensor, tensor_shape, timers=timers)
        output_tensor_grads.append(output_tensor_grad)
        assert_viewless_tensor(output_tensor_grad)
    return output_tensor_grads


@@ -565,6 +580,7 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
        input_tensor = p2p_communication.send_backward_recv_forward(
                input_tensor_grad, tensor_shape, timers=timers)
        input_tensors.append(input_tensor)
        assert_viewless_tensor(input_tensor)
    return input_tensors


@@ -615,6 +631,15 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            # >>>
            if input_tensor[0] is not None:
                from lutil import pax
                pax({
                    "input_tensor" : input_tensor,
                })
            # <<<
            assert_viewless_tensor(input_tensor)
            assert_viewless_tensor(output_tensor)
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
@@ -644,6 +669,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            # >>>
            # assert input_tensor[0]._base is None, \
            #     "rank %s; uh oh." % torch.distributed.get_rank()
            # if input_tensor[0] is not None:
            #     from lutil import pax
            #     pax(4, {
            #         "input_tensor[0]" : input_tensor[0],
            #     })
            # <<<
            assert_viewless_tensor(input_tensor)
            assert_viewless_tensor(output_tensor)
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)