Commit 72c5f666 authored by Jared Casper's avatar Jared Casper
Browse files

Use DDP command line argument instead of source flag in pretrain_bert.py.

Note that there is currently an issue with bert using Torch DDP.
parent a54978bb
Loading
Loading
Loading
Loading
+23 −17
Original line number Diff line number Diff line
@@ -15,9 +15,6 @@

"""Pretrain BERT"""

# Flag to use Pytorch ddp which uses overlapping communication and computation.
USE_TORCH_DDP = False

from datetime import datetime
import os
import random
@@ -33,10 +30,7 @@ from learning_rates import AnnealingLR
from model import BertModel
from model import get_params_for_weight_decay_optimization
from model import gpt2_get_params_for_weight_decay_optimization
if USE_TORCH_DDP:
    from torch.nn.parallel.distributed import DistributedDataParallel as DDP
else:
    from model import DistributedDataParallel as DDP
from model import DistributedDataParallel as LocalDDP
import mpu
from apex.optimizers import FusedAdam as Adam
from utils import Timers
@@ -78,12 +72,18 @@ def get_model(args):
                    _module.float()

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
        model = args.DDP_type(model, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
    elif args.DDP_impl == 'local':
        args.DDP_type = LocalDDP
        model = args.DDP_type(model)
    else:
        model = DDP(model)
        print_rank_0('Unknown DDP implementation specified: {}. '
                     'Exiting.'.format(args.DDP_impl))
        exit()

    return model

@@ -92,7 +92,7 @@ def get_optimizer(model, args):
    """Set up the optimizer."""

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (DDP, FP16_Module)):
    while isinstance(model, (args.DDP_type, FP16_Module)):
        model = model.module
    layers = model.model.bert.encoder.layer
    pooler = model.model.bert.pooler
@@ -232,7 +232,7 @@ def forward_step(data_iterator, model, args, timers):
    return lm_loss, nsp_loss


def backward_step(optimizer, model, lm_loss, nsp_loss, args):
def backward_step(optimizer, model, lm_loss, nsp_loss, args, timers):
    """Backward step."""

    # Total loss.
@@ -252,9 +252,11 @@ def backward_step(optimizer, model, lm_loss, nsp_loss, args):
    reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
    torch.distributed.all_reduce(reduced_losses.data)
    reduced_losses.data = reduced_losses.data / args.world_size
    if not USE_TORCH_DDP:
    if args.DDP_impl == 'local':
        timers('allreduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('allreduce').stop()
    lm_loss_reduced = reduced_losses[0]
    nsp_loss_reduced = reduced_losses[1]

@@ -285,7 +287,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
    # Calculate gradients, reduce across processes, and clip.
    timers('backward').start()
    lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss,
                                                      nsp_loss, args)
                                                      nsp_loss, args, timers)
    timers('backward').stop()

    # Update parameters.
@@ -338,8 +340,12 @@ def train(model, optimizer, lr_scheduler,

        # Logging.

        if args.DDP_impl == 'torch':
            timers_to_log = ['forward', 'backward', 'optimizer',
                            'batch generator', 'data loader']
        else:
            timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
                             'batch generator', 'data loader']

        learning_rate = optimizer.param_groups[0]['lr']

@@ -425,7 +431,7 @@ def evaluate(data_iterator, model, args, timers, verbose = False):
            lm_loss, nsp_loss = forward_step(data_iterator, model,
                                             args, timers)
            # Reduce across processes.
            if isinstance(model, DDP):
            if isinstance(model, args.DDP_type):
                reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
                torch.distributed.all_reduce(reduced_losses.data)
                reduced_losses.data = reduced_losses.data/args.world_size