Loading pretrain_bert.py +23 −17 Original line number Diff line number Diff line Loading @@ -15,9 +15,6 @@ """Pretrain BERT""" # Flag to use Pytorch ddp which uses overlapping communication and computation. USE_TORCH_DDP = False from datetime import datetime import os import random Loading @@ -33,10 +30,7 @@ from learning_rates import AnnealingLR from model import BertModel from model import get_params_for_weight_decay_optimization from model import gpt2_get_params_for_weight_decay_optimization if USE_TORCH_DDP: from torch.nn.parallel.distributed import DistributedDataParallel as DDP else: from model import DistributedDataParallel as DDP from model import DistributedDataParallel as LocalDDP import mpu from apex.optimizers import FusedAdam as Adam from utils import Timers Loading Loading @@ -78,12 +72,18 @@ def get_model(args): _module.float() # Wrap model for distributed training. if USE_TORCH_DDP: if args.DDP_impl == 'torch': i = torch.cuda.current_device() model = DDP(model, device_ids=[i], output_device=i, args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel model = args.DDP_type(model, device_ids=[i], output_device=i, process_group=mpu.get_data_parallel_group()) elif args.DDP_impl == 'local': args.DDP_type = LocalDDP model = args.DDP_type(model) else: model = DDP(model) print_rank_0('Unknown DDP implementation specified: {}. ' 'Exiting.'.format(args.DDP_impl)) exit() return model Loading @@ -92,7 +92,7 @@ def get_optimizer(model, args): """Set up the optimizer.""" # Build parameter groups (weight decay and non-decay). while isinstance(model, (DDP, FP16_Module)): while isinstance(model, (args.DDP_type, FP16_Module)): model = model.module layers = model.model.bert.encoder.layer pooler = model.model.bert.pooler Loading Loading @@ -232,7 +232,7 @@ def forward_step(data_iterator, model, args, timers): return lm_loss, nsp_loss def backward_step(optimizer, model, lm_loss, nsp_loss, args): def backward_step(optimizer, model, lm_loss, nsp_loss, args, timers): """Backward step.""" # Total loss. Loading @@ -252,9 +252,11 @@ def backward_step(optimizer, model, lm_loss, nsp_loss, args): reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1))) torch.distributed.all_reduce(reduced_losses.data) reduced_losses.data = reduced_losses.data / args.world_size if not USE_TORCH_DDP: if args.DDP_impl == 'local': timers('allreduce').start() model.allreduce_params(reduce_after=False, fp32_allreduce=args.fp32_allreduce) timers('allreduce').stop() lm_loss_reduced = reduced_losses[0] nsp_loss_reduced = reduced_losses[1] Loading Loading @@ -285,7 +287,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, # Calculate gradients, reduce across processes, and clip. timers('backward').start() lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss, nsp_loss, args) nsp_loss, args, timers) timers('backward').stop() # Update parameters. Loading Loading @@ -338,8 +340,12 @@ def train(model, optimizer, lr_scheduler, # Logging. if args.DDP_impl == 'torch': timers_to_log = ['forward', 'backward', 'optimizer', 'batch generator', 'data loader'] else: timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer', 'batch generator', 'data loader'] learning_rate = optimizer.param_groups[0]['lr'] Loading Loading @@ -425,7 +431,7 @@ def evaluate(data_iterator, model, args, timers, verbose = False): lm_loss, nsp_loss = forward_step(data_iterator, model, args, timers) # Reduce across processes. if isinstance(model, DDP): if isinstance(model, args.DDP_type): reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1))) torch.distributed.all_reduce(reduced_losses.data) reduced_losses.data = reduced_losses.data/args.world_size Loading Loading
pretrain_bert.py +23 −17 Original line number Diff line number Diff line Loading @@ -15,9 +15,6 @@ """Pretrain BERT""" # Flag to use Pytorch ddp which uses overlapping communication and computation. USE_TORCH_DDP = False from datetime import datetime import os import random Loading @@ -33,10 +30,7 @@ from learning_rates import AnnealingLR from model import BertModel from model import get_params_for_weight_decay_optimization from model import gpt2_get_params_for_weight_decay_optimization if USE_TORCH_DDP: from torch.nn.parallel.distributed import DistributedDataParallel as DDP else: from model import DistributedDataParallel as DDP from model import DistributedDataParallel as LocalDDP import mpu from apex.optimizers import FusedAdam as Adam from utils import Timers Loading Loading @@ -78,12 +72,18 @@ def get_model(args): _module.float() # Wrap model for distributed training. if USE_TORCH_DDP: if args.DDP_impl == 'torch': i = torch.cuda.current_device() model = DDP(model, device_ids=[i], output_device=i, args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel model = args.DDP_type(model, device_ids=[i], output_device=i, process_group=mpu.get_data_parallel_group()) elif args.DDP_impl == 'local': args.DDP_type = LocalDDP model = args.DDP_type(model) else: model = DDP(model) print_rank_0('Unknown DDP implementation specified: {}. ' 'Exiting.'.format(args.DDP_impl)) exit() return model Loading @@ -92,7 +92,7 @@ def get_optimizer(model, args): """Set up the optimizer.""" # Build parameter groups (weight decay and non-decay). while isinstance(model, (DDP, FP16_Module)): while isinstance(model, (args.DDP_type, FP16_Module)): model = model.module layers = model.model.bert.encoder.layer pooler = model.model.bert.pooler Loading Loading @@ -232,7 +232,7 @@ def forward_step(data_iterator, model, args, timers): return lm_loss, nsp_loss def backward_step(optimizer, model, lm_loss, nsp_loss, args): def backward_step(optimizer, model, lm_loss, nsp_loss, args, timers): """Backward step.""" # Total loss. Loading @@ -252,9 +252,11 @@ def backward_step(optimizer, model, lm_loss, nsp_loss, args): reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1))) torch.distributed.all_reduce(reduced_losses.data) reduced_losses.data = reduced_losses.data / args.world_size if not USE_TORCH_DDP: if args.DDP_impl == 'local': timers('allreduce').start() model.allreduce_params(reduce_after=False, fp32_allreduce=args.fp32_allreduce) timers('allreduce').stop() lm_loss_reduced = reduced_losses[0] nsp_loss_reduced = reduced_losses[1] Loading Loading @@ -285,7 +287,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, # Calculate gradients, reduce across processes, and clip. timers('backward').start() lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss, nsp_loss, args) nsp_loss, args, timers) timers('backward').stop() # Update parameters. Loading Loading @@ -338,8 +340,12 @@ def train(model, optimizer, lr_scheduler, # Logging. if args.DDP_impl == 'torch': timers_to_log = ['forward', 'backward', 'optimizer', 'batch generator', 'data loader'] else: timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer', 'batch generator', 'data loader'] learning_rate = optimizer.param_groups[0]['lr'] Loading Loading @@ -425,7 +431,7 @@ def evaluate(data_iterator, model, args, timers, verbose = False): lm_loss, nsp_loss = forward_step(data_iterator, model, args, timers) # Reduce across processes. if isinstance(model, DDP): if isinstance(model, args.DDP_type): reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1))) torch.distributed.all_reduce(reduced_losses.data) reduced_losses.data = reduced_losses.data/args.world_size Loading