Commit 806422e5 authored by Lawrence McAfee

partially cleaned

parent 05042081
+4 −5
@@ -27,10 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

-# >>>
-from megatron.mpu.random import make_viewless_tensor
-# <<<
-
 """ We use the following notation throughout this file:
      h: hidden size
      n: number of attention heads
@@ -700,7 +696,10 @@ class ParallelTransformer(MegatronModule):
             hidden_states = self.input_tensor

         # Viewless tensor
-        hidden_states = make_viewless_tensor(
+        # >>>
+        assert hidden_states is not None, "rank == %d." % torch.distributed.get_rank()
+        # <<<
+        hidden_states = mpu.make_viewless_tensor(
             hidden_states,
             requires_grad = True,
             keep_graph = True,
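Context for the hunk above, not part of the commit: a "viewless" tensor is one whose `._base` attribute is `None`, i.e. it is not a view holding a hidden reference to a larger backing tensor. A minimal sketch of the distinction, using the same `.data`-swap trick as `_kernel_make_viewless_tensor` below (tensor names here are illustrative):

    import torch

    base = torch.zeros(4, 4)
    view = base.view(16)            # a view: shares storage with `base`
    assert view._base is base      # `view` keeps `base` alive

    # Viewless copy of the same storage: allocate a fresh tensor object,
    # then point its .data at the existing storage.
    viewless = torch.empty((1,), dtype=view.dtype, device=view.device)
    viewless.data = view.data
    assert viewless._base is None  # no hidden reference to `base`

`hidden_states` must be viewless here because activation checkpointing later rebinds its `.data`; if `._base` were still set, the viewed buffer could never be freed and would accumulate across iterations (see the assert message in `assert_viewless_tensor` below).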
+3 −0
@@ -65,6 +65,9 @@ from .random import get_cuda_rng_tracker
 from .random import model_parallel_cuda_manual_seed
 from .random import gather_split_1d_tensor
 from .random import split_tensor_into_1d_equal_chunks
+from .random import make_viewless_tensor
+from .random import assert_viewless_tensor
+from .random import safely_set_viewless_tensor_data

 from .utils import divide
 from .utils import split_tensor_along_last_dim
+14 −60
@@ -98,34 +98,12 @@ def gather_split_1d_tensor(tensor):
                                  group=get_tensor_model_parallel_group())
     return gathered

-# >>>
-from lutil import pax # ****************
-
-# def make_standalone_tensor(a):
-#     assert a._base is not None
-#     b = torch.empty((1,), dtype = a.dtype, device = a.device)
-#     b.data = a.data
-#     return b
-# class MakeStandaloneTensor(torch.autograd.Function):
-# class MakeViewlessTensor_(torch.autograd.Function):
 class MakeViewlessTensor(torch.autograd.Function):
-    # @staticmethod
-    # def forward(ctx, inp):
-    #     assert inp._base is not None
-    #     out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
-    #     out.data = inp.data
-    #     # pax(0, {"inp": inp, "out": out})
-    #     return out
     @staticmethod
     def forward(ctx, inp, requires_grad):
         return _kernel_make_viewless_tensor(inp, requires_grad)
-    # @staticmethod
-    # def forward(ctx, args):
-    #     return [_kernel_make_viewless_tensor(*args)]
     @staticmethod
     def backward(ctx, grad_output):
-        # pax(0, {"grad_output": grad_output})
-        # return grad_output
         return grad_output, None

 def _kernel_make_viewless_tensor(inp, requires_grad):
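A note on the `backward` above: autograd requires one returned gradient per `forward` argument, so the non-tensor `requires_grad` argument gets `None`. A self-contained toy `Function` illustrating the same convention (the class and names here are mine, not Megatron's):

    import torch

    class ScaleByConstant(torch.autograd.Function):
        """Toy Function with one tensor and one non-tensor input."""

        @staticmethod
        def forward(ctx, inp, scale):
            ctx.scale = scale
            return inp * scale

        @staticmethod
        def backward(ctx, grad_output):
            # One gradient per forward() argument; the non-tensor `scale`
            # gets None, exactly like `requires_grad` in MakeViewlessTensor.
            return grad_output * ctx.scale, None

    x = torch.ones(3, requires_grad=True)
    y = ScaleByConstant.apply(x, 2.0)
    y.sum().backward()
    assert torch.equal(x.grad, torch.full((3,), 2.0))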
@@ -136,17 +114,8 @@ def _kernel_make_viewless_tensor(inp, requires_grad):
         requires_grad = requires_grad,
     )
     out.data = inp.data
-    # >>>
-    # pax(0, {"inp": inp, "out": out})
-    # assert out.requires_grad
-    # <<<
     return out

-# def make_viewless_tensor(tensor):
-#     if tensor._base is None:
-#         return tensor
-#     else:
-#         return MakeViewlessTensor_.apply(tensor)
 def make_viewless_tensor(inp, requires_grad, keep_graph):

     # return tensor as-is, if not a 'view'
@@ -155,36 +124,27 @@ def make_viewless_tensor(inp, requires_grad, keep_graph):

     # create viewless tensor
     if keep_graph:
-        # return MakeViewlessTensor.apply((inp, requires_grad))[0]
         return MakeViewlessTensor.apply(inp, requires_grad)
     else:
         return _kernel_make_viewless_tensor(inp, requires_grad)
-    # return MakeViewlessTensor.apply((inp, requires_grad))[0]
-    # return MakeViewlessTensor.apply(inp, requires_grad)
-    # return MakeViewlessTensor.apply(inp)
-    # return MakeViewlessTensor.apply(inp, 7)
-    # return MakeViewlessTensor.apply(inp, 7)[0]


 def assert_viewless_tensor(tensor, extra_msg = None):
     if isinstance(tensor, list):
         [ assert_viewless_tensor(t) for t in tensor ]
-        return
-    # assert isinstance(tensor, torch.Tensor), \
-    #     "expected Tensor; found %s." % type(tensor).__name__
+        return tensor
     if not isinstance(tensor, torch.Tensor):
-        return
+        return tensor
     assert tensor._base is None, (
         "Ensure tensor._base is None before setting tensor.data or storing "
         "tensor to memory buffer. Otherwise, a memory leak will occur (and "
         "likely accumulate over iterations). %s"
     ) % extra_msg
+    return tensor

-# def set_viewless_tensor_data_attr(tensor, new_data_tensor):
-def safely_set_tensor_data_attr(tensor, new_data_tensor):
+def safely_set_viewless_tensor_data(tensor, new_data_tensor):
     assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
     tensor.data = new_data_tensor
-# <<<


 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
@@ -328,19 +288,10 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
         if distribute_checkpointed_activations:
-            # >>>
-            # raise Exception("distrib.")
-            # from lutil import data_leak_ctx
-            # with data_leak_ctx(args[0]):
-            # <<<
             ctx.input_0_shape = args[0].data.shape
-            # >>>
-            # args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
-            #                                                  new_buffer=True)
-            safely_set_tensor_data_attr(
+            safely_set_viewless_tensor_data(
                 args[0],
                 split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
-            # <<<

         # Store everything.
         ctx.save_for_backward(*args)
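The hunk above replaces a bare `args[0].data = ...` assignment with `safely_set_viewless_tensor_data`, which first asserts `args[0]._base is None`. The failure mode it guards against, as a hedged sketch (assuming a checkout where the `mpu` exports added in this commit are available):

    import torch
    from megatron import mpu

    big = torch.zeros(1024, 1024)   # e.g. a large activation buffer
    t = big.view(-1)                # t._base is big

    # Unsafe: rebinding t.data would not release `big`, because t._base
    # still references it; the buffer leaks and accumulates per iteration.
    # t.data = torch.zeros(16)

    # Safe: make t viewless first, then swap its data under the assert.
    t = mpu.make_viewless_tensor(t, requires_grad=False, keep_graph=False)
    mpu.safely_set_viewless_tensor_data(t, torch.zeros(16))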
@@ -357,12 +308,15 @@
             # >>>
             # inputs[0].data = gather_split_1d_tensor(inputs[0].data)
             # inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
-            safely_set_tensor_data_attr(
-                inputs[0],
-                gather_split_1d_tensor(inputs[0].data))
-            safely_set_tensor_data_attr(
+            # safely_set_tensor_data_attr(
+            #     inputs[0],
+            #     gather_split_1d_tensor(inputs[0].data))
+            # safely_set_tensor_data_attr(
+            #     inputs[0],
+            #     inputs[0].data.view(ctx.input_0_shape))
+            safely_set_viewless_tensor_data(
                 inputs[0],
-                inputs[0].data.view(ctx.input_0_shape))
+                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
             # <<<

         # Store the current states.
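The backward hunk folds the old two-step restore (gather, then reshape) into one call: gather the 1-D shards, view the result back to `ctx.input_0_shape`, and install it with a single `.data` swap. A single-process, shape-level sketch of that round trip (`split_1d`/`gather_1d` are stand-ins for the real distributed helpers, not Megatron APIs):

    import torch

    def split_1d(t, rank, world_size):
        # Stand-in for split_tensor_into_1d_equal_chunks: keep one rank's chunk.
        flat = t.view(-1)
        chunk = flat.numel() // world_size
        return flat[rank * chunk:(rank + 1) * chunk].clone()

    def gather_1d(chunks):
        # Stand-in for gather_split_1d_tensor: concatenate all ranks' chunks.
        return torch.cat(chunks)

    x = torch.arange(24.).view(2, 3, 4)
    shape = x.shape                                  # saved as ctx.input_0_shape
    parts = [split_1d(x, r, 4) for r in range(4)]    # forward: shard args[0].data
    restored = gather_1d(parts).view(shape)          # backward: gather + view
    assert torch.equal(restored, x)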
+6 −9
@@ -20,9 +20,6 @@ import torch
 from megatron import get_args
 from megatron import mpu

-# >>>
-from megatron.mpu.random import make_viewless_tensor
-# <<<

 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  tensor_shape,
@@ -145,14 +142,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev,
+            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
                                                         requires_grad = True,
                                                         keep_graph = False)

         if recv_next:
             tensor_recv_next = mpu.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = make_viewless_tensor(tensor_recv_next,
+            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
                                                         requires_grad = True,
                                                         keep_graph = False)

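In both receive paths above, `gather_split_1d_tensor(...).view(tensor_shape)` produces a view, so the result is passed through `mpu.make_viewless_tensor` with `keep_graph = False`: a freshly received tensor is an autograd leaf with no upstream graph worth preserving, so the cheap non-autograd path suffices. A sketch of that choice (buffer contents are illustrative):

    import torch
    from megatron import mpu

    recv_buf = torch.zeros(8)                    # gathered P2P receive buffer
    recv = recv_buf.view(2, 4).requires_grad_()  # a view, like tensor_recv_prev

    # keep_graph=False takes the plain data-pointer-swap path; fine here
    # because `recv` is a graph leaf. keep_graph=True (used on hidden_states
    # in ParallelTransformer) goes through MakeViewlessTensor.apply so
    # gradients still flow to upstream producers.
    recv = mpu.make_viewless_tensor(recv, requires_grad=True, keep_graph=False)
    assert recv._base is None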
+5 −41
@@ -29,7 +29,7 @@ from megatron.model import Float16Module
 from megatron.model import ModelType

 # >>>
-from megatron.mpu.random import assert_viewless_tensor
+# from megatron.mpu.random import assert_viewless_tensor
 # <<<

 def get_forward_backward_func():
@@ -115,17 +115,7 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
         unwrap_output_tensor = True

     unwrapped_model.set_input_tensor(input_tensor)
-    # >>>
-    # if input_tensor[0] is not None:
-    #     from lutil import pax, tp
-    #     pax({"input_tensor": tp(input_tensor)})
-    # <<<
     output_tensor, loss_func = forward_step_func(data_iterator, model)
-    # >>>
-    # if input_tensor[0] is not None:
-    #     from lutil import pax, tp
-    #     pax({"input_tensor": tp(input_tensor)})
-    # <<<
     if mpu.is_pipeline_last_stage():
         output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor
@@ -530,7 +520,6 @@ def recv_forward(tensor_shapes, timers):
         else:
             input_tensors.append(p2p_communication.recv_forward(tensor_shape,
                                                                 timers=timers))
-            assert_viewless_tensor(input_tensors[-1])
     return input_tensors


@@ -542,7 +531,6 @@
         else:
             output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
                                                                        timers=timers))
-            assert_viewless_tensor(output_tensor_grads[-1])
     return output_tensor_grads


@@ -575,7 +563,6 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
         output_tensor_grad = p2p_communication.send_forward_recv_backward(
                 output_tensor, tensor_shape, timers=timers)
         output_tensor_grads.append(output_tensor_grad)
-        assert_viewless_tensor(output_tensor_grad)
     return output_tensor_grads


@@ -590,7 +577,6 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
         input_tensor = p2p_communication.send_backward_recv_forward(
                 input_tensor_grad, tensor_shape, timers=timers)
         input_tensors.append(input_tensor)
-        assert_viewless_tensor(input_tensor)
     return input_tensors


@@ -636,33 +622,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
-        # >>>
-        # if input_tensor[0] is not None:
-        #     from lutil import pax
-        #     pax({"input_tensor": input_tensor})
-        # <<<
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
-        # >>>
-        # if True or input_tensor[0] is not None:
-        #     from lutil import pax
-        #     pax({"input_tensor": input_tensor})
-        # <<<
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
-            # >>>
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax({"input_tensor": input_tensor})
-            # if output_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax(0, {"output_tensor / 0": output_tensor[0]})
-            # <<<
-            assert_viewless_tensor(input_tensor)
-            assert_viewless_tensor(output_tensor)
-            input_tensors.append(input_tensor)
-            output_tensors.append(output_tensor)
+            input_tensors.append(mpu.assert_viewless_tensor(input_tensor))
+            output_tensors.append(mpu.assert_viewless_tensor(output_tensor))
             free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)

     # Before running 1F1B, need to receive first forward tensor.
@@ -690,10 +656,8 @@
                                            timers=timers)

             # Add input_tensor and output_tensor to end of list.
-            assert_viewless_tensor(input_tensor)
-            assert_viewless_tensor(output_tensor)
-            input_tensors.append(input_tensor)
-            output_tensors.append(output_tensor)
+            input_tensors.append(mpu.assert_viewless_tensor(input_tensor))
+            output_tensors.append(mpu.assert_viewless_tensor(output_tensor))
             free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)

             # Pop input_tensor and output_tensor from the start of the list for
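The replacement lines above rely on `assert_viewless_tensor` returning its argument (the `return tensor` fix earlier in this commit), so the check and the append collapse into one expression. It recurses into lists, which matters because `input_tensor` here is a list of per-chunk tensors (possibly containing `None`). A hedged mini-demo, assuming the `mpu` exports from this commit:

    import torch
    from megatron import mpu

    input_tensors = []
    t = torch.ones(4)  # viewless: t._base is None

    # assert-and-append in one expression, as in the schedule code above.
    input_tensors.append(mpu.assert_viewless_tensor(t))

    # Lists are validated element-wise and returned as-is; None entries
    # (non-tensors) pass through untouched.
    input_tensors.append(mpu.assert_viewless_tensor([torch.zeros(2), None]))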