Commit a1fe4805 authored by Lawrence McAfee

added comments

parent 806422e5
+0 −3
@@ -696,9 +696,6 @@ class ParallelTransformer(MegatronModule):
            hidden_states = self.input_tensor

        # Viewless tensor
        # >>>
        assert hidden_states is not None, "rank == %d." % torch.distributed.get_rank()
        # <<<
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad = True,
+39 −18
@@ -98,15 +98,15 @@ def gather_split_1d_tensor(tensor):
                                 group=get_tensor_model_parallel_group())
    return gathered

class MakeViewlessTensor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _kernel_make_viewless_tensor(inp, requires_grad)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

def _kernel_make_viewless_tensor(inp, requires_grad):
    '''Make a viewless tensor.

    View tensors have the undesirable side-effect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without retaining the reference to the viewed tensor that would
    otherwise be stored in the '._base' field.
    '''
    out = torch.empty(
        (1,),
        dtype = inp.dtype,
@@ -116,7 +116,31 @@ def _kernel_make_viewless_tensor(inp, requires_grad):
    out.data = inp.data
    return out
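
# A minimal sketch (not from this commit) of the '._base' retention the
# docstring above describes: setting '.data' on a view does not drop the
# view's reference to the original tensor, while the throwaway-tensor
# trick used by _kernel_make_viewless_tensor avoids the reference entirely.
import torch

base = torch.arange(6.)
view = base.view(2, 3)
assert view._base is base           # the view pins the original tensor
view.data = torch.zeros(2, 3)       # replace the storage ...
assert view._base is base           # ... yet 'base' is still referenced

out = torch.empty((1,), dtype=base.dtype, device=base.device)
out.data = base.view(2, 3).data     # share storage without a '._base' link
assert out._base is None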

class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd function to make a viewless tensor.

    This function should be used in cases where the computation graph needs
    to be propagated, but we only want a viewless tensor (e.g.,
    ParallelTransformer's hidden_states). Call this function by passing
    'keep_graph = True' to 'make_viewless_tensor()'.
    '''
    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _kernel_make_viewless_tensor(inp, requires_grad)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None
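
# A minimal sketch (not from this commit) of the autograd path: because
# backward() is the identity, gradients still flow through the viewless
# copy, which is why ParallelTransformer can use it on hidden_states.
import torch

x = torch.randn(4, 4, requires_grad=True)
h = x.view(16)                      # 'h._base' is set
z = MakeViewlessTensor.apply(h, True)
assert z._base is None              # viewless ...
z.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))  # ... and grads reach 'x'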

def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entry-point for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly, since it acts as a switch
    that determines whether an autograd function or a regular method
    should be used to create the tensor.
    '''

    # return tensor as-is, if not a 'view'
    if inp._base is None:
@@ -129,6 +153,8 @@ def make_viewless_tensor(inp, requires_grad, keep_graph):
        return _kernel_make_viewless_tensor(inp, requires_grad)
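
# A minimal sketch (not from this commit) exercising all three branches
# of make_viewless_tensor above.
import torch

p = torch.randn(8, requires_grad=True)
a = make_viewless_tensor(p.view(2, 4), requires_grad=True, keep_graph=True)
assert a._base is None              # autograd path: graph is preserved

d = torch.randn(8)                  # no graph needed for plain data
b = make_viewless_tensor(d.view(2, 4), requires_grad=False, keep_graph=False)
assert b._base is None              # kernel path: no autograd overhead

c = make_viewless_tensor(d, requires_grad=False, keep_graph=False)
assert c is d                       # non-views are returned as-is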

def assert_viewless_tensor(tensor, extra_msg = None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).'''
    if isinstance(tensor, list):
        [ assert_viewless_tensor(t) for t in tensor ]
        return tensor
@@ -142,6 +168,11 @@ def assert_viewless_tensor(tensor, extra_msg = None):
    return tensor

def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception.
    '''
    assert_viewless_tensor(
        tensor,
        extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % (
            "--" if tensor._base is None else tensor._base.shape,
            new_data_tensor.shape,
        ))
    tensor.data = new_data_tensor
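
# A minimal sketch (not from this commit) of the guard: the setter accepts
# viewless tensors and rejects views, whose '._base' would keep the
# original tensor alive after '.data' is replaced.
import torch

t = torch.randn(3, 3)               # viewless: '._base' is None
safely_set_viewless_tensor_data(t, torch.zeros(3, 3))   # accepted

v = torch.randn(9).view(3, 3)       # a view: '._base' is set
try:
    safely_set_viewless_tensor_data(v, torch.zeros(3, 3))
except AssertionError as e:
    print("rejected:", e)           # assert_viewless_tensor fires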

@@ -305,19 +336,9 @@ class CheckpointFunction(torch.autograd.Function):
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_checkpointed_activations:
            # >>>
            # inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            # inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
            # safely_set_tensor_data_attr(
            #     inputs[0],
            #     gather_split_1d_tensor(inputs[0].data))
            # safely_set_tensor_data_attr(
            #     inputs[0],
            #     inputs[0].data.view(ctx.input_0_shape))
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
            # <<<

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
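
# A minimal single-rank sketch (not from this commit) of the fused pattern
# above: gather the flattened activation, restore its original shape, and
# install it through the viewless-safety check. A clone() stands in for
# gather_split_1d_tensor(), and a tuple for ctx.input_0_shape.
import torch

ckpt_input = torch.randn(6)         # flattened checkpointed activation
input_0_shape = (2, 3)              # plays the role of ctx.input_0_shape

gathered = ckpt_input.clone()       # stand-in for the all-gather
safely_set_viewless_tensor_data(
    ckpt_input,
    gathered.view(input_0_shape))
assert ckpt_input.shape == torch.Size(input_0_shape)
assert ckpt_input._base is None     # still safe for future '.data' sets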
+0 −3
@@ -28,9 +28,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType

# >>>
# from megatron.mpu.random import assert_viewless_tensor
# <<<

def get_forward_backward_func():
    args = get_args()