Commit cbad126a authored by Deepak Narayanan

Bring back call to ring_exchange() in _communicate()

parent 78cf869f
+36 −26
@@ -21,7 +21,8 @@ from megatron import get_args
from megatron import mpu


-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
+def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
+                 use_ring_exchange=False):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

@@ -34,6 +35,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
+        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
+                           API should be used.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
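
Note on usage: a caller that only needs to receive from the previous pipeline stage could opt into the new flag as sketched below. The wrapper name and argument handling here are illustrative assumptions for this note, not code from this commit.

    # Hypothetical caller sketch (illustrative only, not part of this commit).
    def recv_forward_example(use_ring_exchange=False):
        # Receive the activation tensor from the previous pipeline stage;
        # nothing is sent, and nothing is received from the next stage.
        tensor_recv_prev, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False,
            use_ring_exchange=use_ring_exchange)
        return tensor_recv_prev
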
@@ -73,6 +76,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
-    ops = []
-    if tensor_send_prev is not None:
-        send_prev_op = torch.distributed.P2POp(
+    if use_ring_exchange:
+        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
+                                        tensor_recv_prev=tensor_recv_prev,
+                                        tensor_send_next=tensor_send_next,
+                                        tensor_recv_next=tensor_recv_next,
+                                        group=mpu.get_pipeline_model_parallel_group())
+    else:
+        ops = []
+        if tensor_send_prev is not None:
+            send_prev_op = torch.distributed.P2POp(
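
Note on the else branch (truncated above): it continues by collecting torch.distributed.P2POp objects and launching them together. Below is a self-contained sketch of that batched pattern against the public PyTorch API; the helper name, peer ranks, and tensor arguments are assumptions for illustration, not code from this commit. torch.distributed.ring_exchange(), by contrast, is not part of stock PyTorch and is only available in builds that provide it.

    # Minimal sketch of batched point-to-point exchange, assuming an
    # initialized process group and pre-allocated receive buffers.
    import torch.distributed as dist

    def batched_exchange(tensor_send_prev, tensor_recv_prev,
                         tensor_send_next, tensor_recv_next,
                         prev_rank, next_rank):
        # Collect every send/recv as a P2POp, then launch them as one batch.
        ops = []
        if tensor_send_prev is not None:
            ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank))
        if tensor_recv_prev is not None:
            ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
        if tensor_send_next is not None:
            ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
        if tensor_recv_next is not None:
            ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
        if ops:
            # batch_isend_irecv() returns one request per op; wait on all.
            reqs = dist.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

Batching the isend/irecv operations with batch_isend_irecv() lets the backend schedule both directions concurrently instead of serializing blocking sends and receives.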