Commit d10f81c5 authored by Lawrence McAfee

removed uses of args.deallocate_pipeline_outputs

parent 18846a0a
+9 −23
@@ -76,7 +76,7 @@ def get_forward_backward_func():
 #         )
 #         # <<<
 # <<<
-def free_output_tensor(out, deallocate_pipeline_outputs):
+def free_output_tensor(out):
     '''Pseudo-free (i.e., set to scalar) the output tensor's '.data' field.
 
     This method should be called right after the output tensor has been
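For context: "pseudo-free" here means swapping the output's large '.data' buffer for a tiny placeholder so its memory can be reclaimed, while the tensor object and its grad_fn stay alive for the later backward pass. A minimal sketch of what the new single-argument free_output_tensor(out) can look like (the placeholder shape and dtype handling are illustrative assumptions, not code copied from this commit):

import torch

def free_output_tensor(out):
    # Sketch only: replace the activation buffer with a 1-element tensor.
    # The autograd graph node (out.grad_fn) is untouched, so backward can
    # still reach it through the reference kept in output_tensors.
    if out is None:
        return
    assert isinstance(out, torch.Tensor)
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)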
@@ -216,14 +216,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # Backward pass.
     if output_tensor_grad[0] is None:
         output_tensor = optimizer.scale_loss(output_tensor[0])
-    if args.deallocate_pipeline_outputs:
-        # >>>
-        # pax(4, {"output_tensor": output_tensor})
-        # <<<
-        custom_backward(output_tensor[0], output_tensor_grad[0])
-    else:
-        torch.autograd.backward(output_tensor[0],
-                                grad_tensors=output_tensor_grad[0])
+    custom_backward(output_tensor[0], output_tensor_grad[0])
 
     # Collect the grad of the input_tensor.
     input_tensor_grad = [None]
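The switch to an unconditional custom_backward call matters because torch.autograd.backward verifies that each grad tensor matches the shape of its output, and that check fails once the output has been pseudo-freed to a 1-element placeholder. custom_backward itself is not shown in this diff; a hedged sketch of the usual pattern is below (it drives the C++ autograd engine directly, and the keyword names follow the engine binding, which can vary across PyTorch versions):

import torch
from torch.autograd import Variable

def custom_backward(output, grad_output):
    # Sketch: run backward through the C++ engine directly, bypassing the
    # Python-level output/grad shape check that would reject a pseudo-freed
    # output tensor.
    assert output.numel() == 1, "output is expected to be pseudo-freed"
    if grad_output is None:
        # Scalar loss on the last pipeline stage: use an implicit grad of 1.
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )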
@@ -431,8 +424,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                     output_tensor, recv_prev=recv_prev,
                     tensor_shape=tensor_shape,
                     timers=timers)
-        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
+        # >>>
+        pax({"output_tensor": output_tensor})
+        # <<<
+        free_output_tensor(output_tensor)
 
     # Run 1F1B in steady state.
     for k in range(num_microbatches_remaining):
@@ -496,7 +492,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
                     tensor_shape=tensor_shape, timers=timers)
-        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
+        free_output_tensor(output_tensor)
 
         # Put input_tensor and output_tensor_grad in data structures in the
         # right location.
@@ -672,10 +668,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         if not forward_only:
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
-            # >>>
-            # pax(2, {"output_tensor": output_tensor})
-            # <<<
-            free_output_tensor(output_tensor[0], args.deallocate_pipeline_outputs)
+            free_output_tensor(output_tensor[0])
 
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -704,14 +697,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
             # Add input_tensor and output_tensor to end of list.
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
-            # >>>
-            # if torch.distributed.get_rank() == 3:
-            #     pax({"output_tensor": output_tensor})
-            # <<<
-            # >>>
-            # free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
-            free_output_tensor(output_tensor[0], args.deallocate_pipeline_outputs)
-            # <<<
+            free_output_tensor(output_tensor[0])
 
             # Pop input_tensor and output_tensor from the start of the list for
             # the backward pass.
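A design note inferred from the call sites above (not stated in the commit itself): free_output_tensor is only ever invoked after the output tensor has been appended to output_tensors and its send to the next stage has been issued, so the Python reference and the autograd graph survive; only the large '.data' buffer is released, which is what makes the later custom_backward on the stored handle safe.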