Commit 0865c4dc authored by Lawrence McAfee
Browse files

Stop saving input/output tensors during forward-only passes of the pipeline schedules

parent 6a680986
Loading
Loading
Loading
Loading
+26 −13
Original line number Diff line number Diff line
@@ -194,6 +194,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
@@ -205,6 +206,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
@@ -383,6 +389,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []
@@ -394,6 +404,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

@@ -411,21 +422,23 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers=timers)

            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers=timers)

        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers=timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers=timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,