Commit e727de99 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Use timers kwargs correctly to prevent bug with new p2p_communication API

parent a676bc2d
Loading
Loading
Loading
Loading
+14 −14
Original line number Diff line number Diff line
@@ -231,7 +231,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers))
        p2p_communication.recv_forward(timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

@@ -265,7 +265,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
                    output_tensor, recv_prev=recv_prev, timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
@@ -343,7 +343,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(timers))
                p2p_communication.recv_backward(timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -355,7 +355,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))
                    input_tensor_grad, recv_next=recv_next, timers=timers))

    return losses_reduced

@@ -389,10 +389,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        input_tensor = p2p_communication.recv_forward(timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)
        p2p_communication.send_forward(output_tensor, timers=timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
@@ -401,7 +401,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)
        input_tensor = p2p_communication.recv_forward(timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
@@ -410,11 +410,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
            p2p_communication.send_forward(output_tensor, timers=timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)
                                                             timers=timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
@@ -423,7 +423,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
                input_tensor = p2p_communication.recv_forward(timers=timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

@@ -433,11 +433,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
                p2p_communication.send_backward(input_tensor_grad, timers=timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)
                        input_tensor_grad, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
@@ -445,12 +445,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)
            output_tensor_grad = p2p_communication.recv_backward(timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)
            p2p_communication.send_backward(input_tensor_grad, timers=timers)

    return losses_reduced