Commit 86da10e9 authored by Lawrence McAfee

working for pure pipeline parallelism, w/ no interleaving

parent d4169684
+34 −2
@@ -15,6 +15,7 @@

from contextlib import contextmanager
import torch
from torch.autograd.variable import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
@@ -27,7 +28,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType


def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
@@ -42,6 +42,36 @@ def get_forward_backward_func():
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func

def free_output_tensor(t):
    t.data = torch.FloatTensor([0]).to(t.data)

def custom_backward(output, grad_output):

    assert output.numel() == 1, \
        "output should be pseudo-'freed' in schedule, to optimize memory"
    assert isinstance(output, torch.Tensor), \
        "output == '%s'." % type(output).__name__
    assert isinstance(grad_output, (torch.Tensor, type(None))), \
        "grad_output == '%s'." % type(grad_output).__name__

    # Handle scalar output
    if grad_output is None:
        assert output.numel() == 1, "implicit grad requires scalar output."
        grad_output = torch.ones_like(
            output,
            memory_format = torch.preserve_format,
        )

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors = (output,),
        grad_tensors = (grad_output,),
        keep_graph = False,
        create_graph = False,
        inputs = tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )

def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.
@@ -116,7 +146,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    # Backward pass.
    if output_tensor_grad[0] is None:
        output_tensor = optimizer.scale_loss(output_tensor[0])
    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
    custom_backward(output_tensor[0], output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
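The swap from torch.autograd.backward to custom_backward in backward_step above follows from that pseudo-free: PyTorch's Python-level backward checks grad_tensors against the output tensor's current shape, which is now a single element, so it would reject the full-sized gradient received from the next stage. A hedged sketch of that failure mode (names are illustrative, not from the commit):

import torch

x = torch.randn(8, 8, requires_grad=True)
out = x * 3
out.data = torch.FloatTensor([0]).to(out.data)    # pseudo-freed, as in the schedule

try:
    # Rejected in the Python wrapper: the gradient is checked against out's
    # *current* one-element shape, not the 8x8 shape recorded in the graph.
    torch.autograd.backward(out, grad_tensors=torch.ones(8, 8))
except RuntimeError as err:
    print("torch.autograd.backward refused the pseudo-freed output:", err)

# custom_backward (added above) avoids this check by handing the same tensors
# straight to the C++ engine, which validates grads against the graph metadata.
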
@@ -560,6 +590,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            [ free_output_tensor(t) for t in output_tensor ]
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

@@ -588,6 +619,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            [ free_output_tensor(t) for t in output_tensor ]
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
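
Taken together, the two schedule hunks above apply the same pattern in both the warmup and steady-state loops: run the stage forward, send the activation downstream, pseudo-free it, and stash the now tiny tensor so backward_step can still walk its graph later. A simplified, self-contained sketch of one micro-batch on one stage, with the communication calls stubbed out and the commit's helpers copied locally (asserts omitted) so it runs standalone; the private engine call mirrors the commit and may differ across PyTorch versions:

import torch
from torch.autograd.variable import Variable

# Simplified local copies of the commit's helpers (asserts omitted).
def free_output_tensor(t):
    t.data = torch.FloatTensor([0]).to(t.data)

def custom_backward(output, grad_output):
    Variable._execution_engine.run_backward(
        tensors=(output,), grad_tensors=(grad_output,),
        keep_graph=False, create_graph=False, inputs=tuple(),
        allow_unreachable=True, accumulate_grad=True)

# One micro-batch on one pipeline stage; communication is stubbed with local tensors.
weight = torch.randn(16, 16, requires_grad=True)
input_tensor = torch.randn(4, 16, requires_grad=True)

output_tensor = [input_tensor @ weight]                    # forward_step(...)
_sent = [t.detach().clone() for t in output_tensor]        # send_forward(...) stand-in
[free_output_tensor(t) for t in output_tensor]             # storage released, graph kept
output_tensors = [output_tensor]                           # stored until backward runs

output_tensor_grad = [torch.ones(4, 16)]                   # recv_backward(...) stand-in
custom_backward(output_tensor[0], output_tensor_grad[0])   # backward_step(...)
assert weight.grad is not None and input_tensor.grad is not None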