Commit 2f25c570 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

working: interleaving; free_output_tensor() now handles none/tensor/list

parent 86da10e9
Loading
Loading
Loading
Loading
+11 −4
Original line number Diff line number Diff line
@@ -42,8 +42,13 @@ def get_forward_backward_func():
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func

def free_output_tensor(t):
    t.data = torch.FloatTensor([0]).to(t.data)
def free_output_tensor(output_tensors):
    if output_tensors is None:
        return
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]
    for output_tensor in output_tensors:
        output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)

def custom_backward(output, grad_output):

@@ -354,6 +359,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    timers=timers)
        free_output_tensor(output_tensor)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
@@ -418,6 +424,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers)
        free_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
@@ -590,9 +597,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            [ free_output_tensor(t) for t in output_tensor ]
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -619,9 +626,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            [ free_output_tensor(t) for t in output_tensor ]
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.