Commit 270d6412 authored by Lawrence McAfee

Loss matches the baseline; memory savings for multi-node runs (tested n3, n16).

parent b6d4369b
+6 −9
@@ -699,15 +699,12 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor
 
-        # hidden_states = make_standalone_tensor(hidden_states)
-        # hidden_states = MakeStandaloneTensor.apply(hidden_states)
-        # hidden_states = MakeViewlessTensor.apply(hidden_states)
-        hidden_states = make_viewless_tensor(hidden_states)
-        # hidden_states = hidden_states.clone()
-        # >>>
-        # from lutil import pax
-        # pax(0, {"hidden_states": hidden_states})
-        # <<<
+        # Viewless tensor
+        hidden_states = make_viewless_tensor(
+            hidden_states,
+            requires_grad = True,
+            keep_graph = True,
+        )
 
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
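
A note on the change above: a "viewless" tensor in this codebase is one whose ._base attribute is None. PyTorch views (from transpose(), view(), slicing, etc.) keep a back-reference to the tensor they were created from, so storing a view keeps the entire base buffer alive. A minimal standalone sketch of the underlying idea, in plain PyTorch (illustration only, not part of the commit):

    import torch

    base = torch.randn(4, 4)
    view = base.transpose(0, 1)   # a view: shares storage with `base`
    assert view._base is base     # views hold a reference to their base

    # The trick this commit builds on: allocate a fresh 1-element tensor
    # (which has no base), then point its .data at the view's storage.
    # The result shares the data but drops the back-reference.
    out = torch.empty((1,), dtype=view.dtype, device=view.device)
    out.data = view.data
    assert out._base is None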
+58 −19
@@ -99,7 +99,7 @@ def gather_split_1d_tensor(tensor):
     return gathered
 
 # >>>
-# from lutil import pax
+from lutil import pax # ****************
 
 # def make_standalone_tensor(a):
 #     assert a._base is not None
@@ -107,26 +107,66 @@ def gather_split_1d_tensor(tensor):
 #     b.data = a.data
 #     return b
 # class MakeStandaloneTensor(torch.autograd.Function):
-class MakeViewlessTensor_(torch.autograd.Function):
+# class MakeViewlessTensor_(torch.autograd.Function):
+class MakeViewlessTensor(torch.autograd.Function):
+    # @staticmethod
+    # def forward(ctx, inp):
+    #     assert inp._base is not None
+    #     out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
+    #     out.data = inp.data
+    #     # pax(0, {"inp": inp, "out": out})
+    #     return out
     @staticmethod
-    def forward(ctx, inp):
-        assert inp._base is not None
-        out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
-        out.data = inp.data
-        # pax(0, {"inp": inp, "out": out})
-        return out
+    def forward(ctx, inp, requires_grad):
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    # @staticmethod
+    # def forward(ctx, args):
+    #     return [_kernel_make_viewless_tensor(*args)]
     @staticmethod
     def backward(ctx, grad_output):
-        # pax(0, {"grad_output": grad_output})
-        return grad_output
+        # return grad_output
+        return grad_output, None
 
+def _kernel_make_viewless_tensor(inp, requires_grad):
+    out = torch.empty(
+        (1,),
+        dtype = inp.dtype,
+        device = inp.device,
+        requires_grad = requires_grad,
+    )
+    out.data = inp.data
+    # >>>
+    # pax(0, {"inp": inp, "out": out})
+    # assert out.requires_grad
+    # <<<
+    return out
+
-def make_viewless_tensor(tensor):
-    if tensor._base is None:
-        return tensor
+# def make_viewless_tensor(tensor):
+#     if tensor._base is None:
+#         return tensor
+#     else:
+#         return MakeViewlessTensor_.apply(tensor)
+def make_viewless_tensor(inp, requires_grad, keep_graph):
+
+    # return tensor as-is, if not a 'view'
+    if inp._base is None:
+        return inp
+
+    # create viewless tensor
+    if keep_graph:
+        # return MakeViewlessTensor.apply((inp, requires_grad))[0]
+        return MakeViewlessTensor.apply(inp, requires_grad)
     else:
-        return MakeViewlessTensor_.apply(tensor)
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    # return MakeViewlessTensor.apply((inp, requires_grad))[0]
+    # return MakeViewlessTensor.apply(inp, requires_grad)
+    # return MakeViewlessTensor.apply(inp)
+    # return MakeViewlessTensor.apply(inp, 7)
+    # return MakeViewlessTensor.apply(inp, 7)[0]
 
 
-def assert_viewless_tensor(tensor):
+def assert_viewless_tensor(tensor, extra_msg = None):
     if isinstance(tensor, list):
         [ assert_viewless_tensor(t) for t in tensor ]
         return
@@ -137,13 +177,12 @@ def assert_viewless_tensor(tensor):
     assert tensor._base is None, (
         "Ensure tensor._base is None before setting tensor.data or storing "
         "tensor to memory buffer. Otherwise, a memory leak will occur (and "
-        "likely accumulate over iterations). FYI, tensor._base has shape "
-        "%s, and new_data_tensor has shape %s."
-    ) % (tensor._base.shape, new_data_tensor.shape)
+        "likely accumulate over iterations). %s"
+    ) % extra_msg
 
+# def set_viewless_tensor_data_attr(tensor, new_data_tensor):
 def safely_set_tensor_data_attr(tensor, new_data_tensor):
-    assert_viewless_tensor(tensor)
+    assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
     tensor.data = new_data_tensor
 # <<<
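
The new make_viewless_tensor() dispatches on keep_graph. With keep_graph = True the conversion goes through MakeViewlessTensor.apply(), so the output stays wired into the autograd graph; backward() returns (grad_output, None), where the None corresponds to the non-tensor requires_grad argument. With keep_graph = False the raw _kernel_make_viewless_tensor() is called and no graph node is created. A hedged usage sketch of both paths (assuming make_viewless_tensor as defined in this diff is in scope):

    import torch

    hidden = torch.randn(8, 8, requires_grad=True)
    view = hidden.transpose(0, 1)   # non-leaf view inside the autograd graph

    # keep_graph=True: gradients must still reach `hidden` through the
    # custom autograd.Function.
    out = make_viewless_tensor(view, requires_grad=True, keep_graph=True)
    out.sum().backward()
    assert hidden.grad is not None

    # keep_graph=False: the input is a view of a fresh leaf with no
    # autograd history, so the plain kernel suffices.
    leaf_view = torch.randn(64).view(8, 8)
    out2 = make_viewless_tensor(leaf_view, requires_grad=True, keep_graph=False)
    assert out2._base is None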

+6 −2
@@ -145,12 +145,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
+            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev,
+                                                    requires_grad = True,
+                                                    keep_graph = False)
 
         if recv_next:
             tensor_recv_next = mpu.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = make_viewless_tensor(tensor_recv_next)
+            tensor_recv_next = make_viewless_tensor(tensor_recv_next,
+                                                    requires_grad = True,
+                                                    keep_graph = False)
 
     return tensor_recv_prev, tensor_recv_next
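
In this file the received activations are rebuilt as a .view() of the flat buffer returned by mpu.gather_split_1d_tensor(), so they carry no upstream autograd history; keep_graph = False therefore takes the kernel path, dropping the view's back-reference so the flat gather buffer is not pinned. A rough standalone analogue (shapes simplified; the real code all-gathers across ranks):

    import torch

    flat = torch.randn(64)                     # stand-in for the gathered flat buffer
    recv = flat.view(4, 16).requires_grad_()   # a view: pins `flat` via ._base
    assert recv._base is flat

    # keep_graph=False path: no graph to preserve, so the kernel removes
    # the back-reference and `flat` can be freed once out of scope.
    recv = make_viewless_tensor(recv, requires_grad=True, keep_graph=False)
    assert recv._base is None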

+0 −16
@@ -631,13 +631,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
 
         if not forward_only:
-            # >>>
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax({
-            #         "input_tensor" : input_tensor,
-            #     })
-            # <<<
             assert_viewless_tensor(input_tensor)
             assert_viewless_tensor(output_tensor)
             input_tensors.append(input_tensor)
@@ -669,15 +662,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                            timers=timers)
 
             # Add input_tensor and output_tensor to end of list.
-            # >>>
-            # assert input_tensor[0]._base is None, \
-            #     "rank %s; uh oh." % torch.distributed.get_rank()
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax(4, {
-            #         "input_tensor[0]" : input_tensor[0],
-            #     })
-            # <<<
             assert_viewless_tensor(input_tensor)
             assert_viewless_tensor(output_tensor)
             input_tensors.append(input_tensor)
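
The deleted blocks were debug scaffolding for the same invariant that the retained assert_viewless_tensor() calls enforce: tensors appended to input_tensors / output_tensors must not be views, since each stored view would pin its base buffer until the matching backward pass of the pipeline schedule. A tiny illustration of the leak pattern the assert rejects (plain PyTorch, not from the commit):

    import torch

    stored = []
    base = torch.randn(1024, 1024)
    stored.append(base[:1])   # a view: keeps the whole 4 MiB `base` alive

    assert stored[0]._base is not None   # this is what assert_viewless_tensor flags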