Commit 788b59e7 authored by Jared Casper

Merge branch 'remove_ring_exchange' into 'main'

Get PyTorch batched communication API working for interleaved schedule

See merge request ADLR/megatron-lm!242
parents 1acac4e3 cbad126a
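
This change drops the `use_ring_exchange` flag from every point-to-point helper and routes all pipeline traffic through PyTorch's standard batched primitives instead. A minimal sketch of that API, assuming an already-initialized process group and hypothetical `prev_rank`/`next_rank` peers (this is not the Megatron code, just the pattern it relies on):

import torch
import torch.distributed as dist

def exchange(send_next, recv_prev_shape, prev_rank, next_rank):
    # Queue paired point-to-point ops, then launch them as one batch so
    # the sends and receives are issued together rather than one by one.
    ops = []
    if send_next is not None:
        ops.append(dist.P2POp(dist.isend, send_next, next_rank))
    recv_prev = torch.empty(recv_prev_shape)
    ops.append(dist.P2POp(dist.irecv, recv_prev, prev_rank))
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()  # block until every queued send/recv has completed
    return recv_prev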
+19 −27
@@ -104,6 +104,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 torch.distributed.irecv, tensor_recv_next,
                 mpu.get_pipeline_model_parallel_next_rank())
             ops.append(recv_next_op)
+        if len(ops) > 0:
             reqs = torch.distributed.batch_isend_irecv(ops)
             for req in reqs:
                 req.wait()
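
The one functional addition in this file is the `if len(ops) > 0:` guard above: on the first and last pipeline stages a call can queue nothing at all, and the guard suggests `batch_isend_irecv` does not accept an empty op list, so the batched launch is skipped in that case. A hedged sketch of the resulting pattern:

import torch.distributed as dist

def launch(ops):
    # `ops` is a possibly-empty list of dist.P2POp built by the caller;
    # skip the launch when nothing was queued (e.g. a boundary stage
    # with neither sends nor receives pending).
    if len(ops) > 0:
        for req in dist.batch_isend_irecv(ops):
            req.wait()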
@@ -123,7 +124,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(timers=None, use_ring_exchange=False):
+def recv_forward(timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -134,14 +135,13 @@ def recv_forward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor
 
 
-def recv_backward(timers=None, use_ring_exchange=False):
+def recv_backward(timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -152,14 +152,13 @@ def recv_backward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('backward-recv').stop()
     return output_tensor_grad
 
 
-def send_forward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward(output_tensor, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
@@ -168,13 +167,12 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-send').stop()
 
 
-def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward(input_tensor_grad, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
@@ -183,13 +181,12 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('backward-send').stop()
 
 
-def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward_recv_backward(output_tensor, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -200,14 +197,13 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=Fal
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('forward-send-backward-recv').stop()
     return output_tensor_grad
 
 
-def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward_recv_forward(input_tensor_grad, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -218,8 +214,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('backward-send-forward-recv').stop()
     return input_tensor
@@ -233,8 +228,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
         tensor_send_next=output_tensor,
         tensor_send_prev=None,
         recv_prev=recv_prev,
-        recv_next=False,
-        use_ring_exchange=True)
+        recv_next=False)
     if timers is not None:
         timers('forward-send-forward-recv').stop()
     return input_tensor
@@ -248,8 +242,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,
         recv_prev=False,
-        recv_next=recv_next,
-        use_ring_exchange=True)
+        recv_next=recv_next)
     if timers is not None:
         timers('backward-send-backward-recv').stop()
     return output_tensor_grad
@@ -265,8 +258,7 @@ def send_forward_backward_recv_forward_backward(
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
-        recv_next=recv_next,
-        use_ring_exchange=True)
+        recv_next=recv_next)
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').stop()
     return input_tensor, output_tensor_grad
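
Net effect on this file: every helper loses its `use_ring_exchange` parameter, so all call sites collapse onto the same batched `_communicate` path. A hypothetical forward-pass call site after the change (the helper names come from the diff above; `timers` and `forward_step` are assumed scaffolding):

# One pipeline stage's forward hop, as callers now write it.
input_tensor = p2p_communication.recv_forward(timers)
output_tensor = forward_step(input_tensor)
p2p_communication.send_forward(output_tensor, timers)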
+2 −2
@@ -210,7 +210,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(timers, use_ring_exchange=True))
+        p2p_communication.recv_forward(timers))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
 
@@ -322,7 +322,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
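
These two receives in the interleaved schedule were the last call sites still forcing `use_ring_exchange=True`; with the flag removed they take the same `batch_isend_irecv` path as the rest of the pipeline, which is what the commit title promises. As a fragment (names from the diff; the enclosing schedule is assumed), the warmup receive now reduces to:

# Post the first forward input for model chunk 0; recv_forward returns
# None on the first pipeline stage, so no special-casing is needed here.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(timers))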