Commit 1dccefd8 authored by Mostofa Patwary, committed by Deepak Narayanan

Make it possible to pass in tensor shapes to communication methods in p2p_communication.py

parent 3db6517a
+44 −12
@@ -22,7 +22,9 @@ from megatron import mpu
 
 
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+                 use_ring_exchange=False, tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
 
@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                    next rank.
         use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                            API should be used.
-
+        tensor_shape: optional, use when the input sequence contains fewer
+                      tokens than the default sequence length
+        override_scatter_gather_tensors_in_pipeline: optional, used when
+                                                     tensor_shape is provided
+                                                     to bypass the
+                                                     scatter-gather optimization
+        dtype_: optional, used when tensor_shape is provided to set the
+                dtype of the tensors to be received
     Returns:
         (tensor_recv_prev, tensor_recv_next)
     """
@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    if args.scatter_gather_tensors_in_pipeline:
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
             mpu.get_tensor_model_parallel_world_size()
     else:
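
The scatter-gather branch above flattens tensor_shape to a total element count and splits it evenly across tensor-parallel ranks. A runnable sketch of that arithmetic with illustrative sizes:

import operator
from functools import reduce

tensor_shape = (1024, 8, 4096)  # (seq_length, micro_batch_size, hidden_size), illustrative
tp_world_size = 8               # stand-in for mpu.get_tensor_model_parallel_world_size()

# Total element count, then an even 1-D split across tensor-parallel ranks.
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // tp_world_size
print(tensor_chunk_shape)       # 4194304 elements sent/received per rank
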
@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     dtype = args.params_dtype
     if args.fp32_residual_connection:
         dtype = torch.float
+
+    requires_grad = True
+    if dtype_ is not None:
+        dtype = dtype_
+        requires_grad = False
+
     if recv_prev:
         tensor_recv_prev = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
     if recv_next:
         tensor_recv_next = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
 
     # Split tensor into smaller chunks if using scatter-gather optimization.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
 
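One reason an explicit dtype_ also flips requires_grad to False: PyTorch only allows floating-point and complex tensors to be autograd leaves, so a caller passing, say, an integer dtype_ needs a non-grad buffer. A runnable illustration (CPU, illustrative shape):

import torch

# Floating-point buffers may carry requires_grad=True, as in the default path.
torch.empty((4, 4), dtype=torch.float16, requires_grad=True)

# Integer buffers may not; torch.empty raises
# "Only Tensors of floating point and complex dtype can require gradients".
try:
    torch.empty((4, 4), dtype=torch.int64, requires_grad=True)
except RuntimeError as err:
    print(err)
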
@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     torch.cuda.synchronize()
 
     # If using scatter-gather optimization, gather smaller chunks.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
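
A single-process sketch of the split/gather round trip that the scatter-gather path performs across ranks; the helper below is a stand-in for mpu.split_tensor_into_1d_equal_chunks / mpu.gather_split_1d_tensor, not Megatron's actual distributed implementation:

import torch

def split_1d_chunk(t, world_size, rank):
    # Stand-in: flatten, then take this rank's contiguous 1-D slice.
    flat = t.reshape(-1)
    n = flat.numel() // world_size
    return flat[rank * n:(rank + 1) * n]

world_size = 4
full = torch.arange(2 * 2 * 8, dtype=torch.float32).reshape(2, 2, 8)
chunks = [split_1d_chunk(full, world_size, r) for r in range(world_size)]
# Gathering concatenates the per-rank chunks and restores the original shape,
# as the diff does with .view(tensor_shape).requires_grad_().
restored = torch.cat(chunks).view(full.shape)
assert torch.equal(restored, full)
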
@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(timers=None):
+def recv_forward(tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None, timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
+
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:
@@ -135,7 +157,11 @@ def recv_forward(timers=None):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor
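
A hedged usage sketch of the extended recv_forward; the shape and dtype are illustrative (for example, a shorter sequence during incremental decoding), not values from the commit:

import torch

input_tensor = recv_forward(
    tensor_shape=(32, 4, 1024),  # (seq_length, micro_batch_size, hidden_size), illustrative
    override_scatter_gather_tensors_in_pipeline=True,
    dtype_=torch.float16)        # buffer is allocated with requires_grad=False
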
@@ -158,8 +184,11 @@
     return output_tensor_grad
 
 
-def send_forward(output_tensor, timers=None):
+def send_forward(output_tensor, timers=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Send tensor to next rank in pipeline (forward send)."""
+
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send').start()
@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False)
+            recv_next=False,
+            override_scatter_gather_tensors_in_pipeline=\
+            override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
         if timers is not None:
             timers('forward-send').stop()
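
The sending side must mirror the receiver's settings, since both stages have to agree on whether scatter-gather is bypassed and on the buffer dtype. A hedged sketch using the same illustrative values as above (note send_forward takes no tensor_shape; the shape is inferred from output_tensor):

import torch

send_forward(output_tensor,
             override_scatter_gather_tensors_in_pipeline=True,
             dtype_=torch.float16)
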