Commit 65eeb427 authored by Deepak Narayanan

Support Torch DDP for single-stage, num_microbatches() > 1

parent 8cd16667
+29 −7
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_num_microbatches
@@ -74,6 +76,14 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    return input_tensor_grad


@contextmanager
def dummy_handler():
    try:
        yield
    finally:
        pass


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
@@ -83,15 +93,27 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
    assert len(model) == 1
    model = model[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    for i in range(get_num_microbatches()):
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for i in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced
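
For context, the new code path above uses the standard gradient-accumulation idiom for torch.nn.parallel.DistributedDataParallel: backward passes issued inside no_sync() only accumulate gradients locally, and the single backward pass issued outside the context triggers the gradient all-reduce. A minimal standalone sketch of that pattern (toy loss and a hypothetical microbatches iterable; not the Megatron code itself):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def accumulate_and_sync(model: DDP, optimizer, microbatches):
    """Accumulate gradients over several microbatches, all-reducing only once.

    `microbatches` is a hypothetical list of (inputs, targets) pairs; this
    helper only illustrates the no_sync() pattern used in the diff above.
    """
    optimizer.zero_grad()
    *head, (last_inputs, last_targets) = microbatches
    with model.no_sync():                        # defer gradient all-reduce
        for inputs, targets in head:
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            loss.backward()                      # gradients accumulate locally
    loss = torch.nn.functional.mse_loss(model(last_inputs), last_targets)
    loss.backward()                              # last microbatch: all-reduce fires here
    optimizer.step()

The dummy_handler added in this commit fills the same role as contextlib.nullcontext() from Python 3.7+: it lets the loop always enter "with context_handler():", whether or not the model is a torchDDP instance.
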


+0 −2
@@ -309,8 +309,6 @@ def setup_model_and_optimizer(model_provider_func):
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local'
    if len(model) > 1:
        assert args.DDP_impl == 'local'
    if mpu.get_pipeline_model_parallel_world_size() > 1:
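
Side note on the relaxed check above: one way to sanity-check, in isolation, that torch DDP with no_sync()-based accumulation matches plain gradient accumulation is a single-process run. The script below is not part of the commit; it assumes a gloo process group of world size 1 and a toy linear model.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # Single-process process group, just so a DDP wrapper can be constructed.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    base = torch.nn.Linear(4, 1)                 # reference model
    ddp_model = DDP(torch.nn.Linear(4, 1))
    ddp_model.module.load_state_dict(base.state_dict())

    microbatches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(4)]

    # Reference: plain gradient accumulation on the unwrapped model.
    for inputs, targets in microbatches:
        torch.nn.functional.mse_loss(base(inputs), targets).backward()

    # DDP: all but the last microbatch inside no_sync(), last one outside.
    with ddp_model.no_sync():
        for inputs, targets in microbatches[:-1]:
            torch.nn.functional.mse_loss(ddp_model(inputs), targets).backward()
    inputs, targets = microbatches[-1]
    torch.nn.functional.mse_loss(ddp_model(inputs), targets).backward()

    # With world size 1 the all-reduce is a no-op, so the grads should agree.
    for p_ref, p_ddp in zip(base.parameters(), ddp_model.parameters()):
        assert torch.allclose(p_ref.grad, p_ddp.grad, atol=1e-6)
    print("gradients match")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()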