megatron/checkpointing.py +5 −16

```diff
@@ -21,12 +21,12 @@
 import sys
 
 import numpy as np
 
 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP
 
 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)
 
 _CHECKPOINT_VERSION = None
@@ -111,12 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()
 
     # Only rank zero of the data parallel writes to the disk.
-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)
 
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
@@ -220,12 +215,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)
 
-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)
 
     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)
@@ -389,8 +379,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     args = get_args()
 
-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)
 
     load_path = args.load if from_realm_chkpt else args.ict_load
```
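The three hand-rolled unwrapping loops removed above collapse into one call to the new `utils.unwrap_model` (added in megatron/utils.py below), which accepts either a single module or a list of pipeline model chunks and returns the same shape it was given. A minimal sketch of that contract, using a stand-in wrapper class rather than Megatron's `torchDDP`:

```python
class DDPStandIn:
    """Stand-in for torchDDP: holds the real model as .module."""
    def __init__(self, module):
        self.module = module

def unwrap_model(model, module_instances=(DDPStandIn,)):
    # Mirror the helper's list-in/list-out behavior.
    return_list = isinstance(model, list)
    modules = model if return_list else [model]
    unwrapped = []
    for module in modules:
        # Peel nested wrappers until the bare model is reached.
        while isinstance(module, module_instances):
            module = module.module
        unwrapped.append(module)
    return unwrapped if return_list else unwrapped[0]

chunks = [DDPStandIn("chunk0"), DDPStandIn("chunk1")]  # e.g. virtual pipeline chunks
assert unwrap_model(chunks) == ["chunk0", "chunk1"]    # list in, list out
assert unwrap_model(DDPStandIn("model")) == "model"    # single in, single out
```

Preserving the input's shape is what lets `save_checkpoint`/`load_checkpoint` (which receive a list of model chunks) and `load_ict_checkpoint` (which receives a single module) share one helper. Note also the `while`: unlike the removed `if isinstance` checks, the helper keeps unwrapping through nested wrappers.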
megatron/schedules.py +22 −22

```diff
@@ -16,14 +16,10 @@
 import torch
 
 from megatron import get_args
+from megatron import get_num_microbatches
 from megatron import get_timers
 from megatron import mpu
-from megatron import get_num_microbatches
-from megatron.p2p_communication import recv_forward, recv_backward
-from megatron.p2p_communication import send_forward, send_backward
-from megatron.p2p_communication import send_forward_recv_backward, send_backward_recv_forward
-from megatron.p2p_communication import send_forward_recv_forward, send_backward_recv_backward
-from megatron.p2p_communication import send_forward_backward_recv_forward_backward
+from megatron import p2p_communication
 
 
 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
@@ -154,7 +150,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(recv_forward(timers, use_ring_exchange=True))
+    input_tensors[0].append(p2p_communication.recv_forward(timers, use_ring_exchange=True))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
         next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -173,13 +169,14 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 recv_next = False
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                p2p_communication.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
                     timers=timers)
             output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
         else:
-            input_tensor = send_forward_recv_forward(output_tensor, recv_prev, timers)
+            input_tensor = \
+                p2p_communication.send_forward_recv_forward(output_tensor, recv_prev, timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
 
     # Run 1F1B in steady state.
@@ -238,7 +235,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         # Communicate tensors.
         input_tensor, output_tensor_grad = \
-            send_forward_backward_recv_forward_backward(
+            p2p_communication.send_forward_backward_recv_forward_backward(
                 output_tensor, input_tensor_grad,
                 recv_prev=recv_prev, recv_next=recv_next,
                 timers=timers)
 
@@ -253,7 +250,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers, use_ring_exchange=True))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -264,7 +261,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             if k == (num_microbatches - 1):
                 recv_next = False
             output_tensor_grads[next_backward_model_chunk_id].append(
-                send_backward_recv_backward(input_tensor_grad, recv_next, timers))
+                p2p_communication.send_backward_recv_backward(
+                    input_tensor_grad, recv_next, timers))
 
     return losses_reduced
@@ -294,7 +292,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         # Barrier before first receive to measure forward stall.
@@ -302,7 +300,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             timers('forward-pipeline-stall').start()
            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
             timers('forward-pipeline-stall').stop()
-        send_forward(output_tensor, timers)
+        p2p_communication.send_forward(output_tensor, timers)
 
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
@@ -317,7 +315,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
     if num_microbatches_remaining > 0:
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)
 
     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):
@@ -326,9 +324,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
-            send_forward(output_tensor, timers)
+            p2p_communication.send_forward(output_tensor, timers)
         else:
-            output_tensor_grad = send_forward_recv_backward(output_tensor, timers)
+            output_tensor_grad = \
+                p2p_communication.send_forward_recv_backward(output_tensor, timers)
 
         # Add input_tensor and output_tensor to end of list, then pop from the
         # start of the list for backward pass.
@@ -337,7 +336,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         if forward_only:
             if not last_iteration:
-                input_tensor = recv_forward(timers)
+                input_tensor = p2p_communication.recv_forward(timers)
         else:
             input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
@@ -347,9 +346,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             if last_iteration:
                 input_tensor = None
-                send_backward(input_tensor_grad, timers)
+                p2p_communication.send_backward(input_tensor_grad, timers)
             else:
-                input_tensor = send_backward_recv_forward(input_tensor_grad, timers)
+                input_tensor = \
+                    p2p_communication.send_backward_recv_forward(input_tensor_grad, timers)
 
     # Run cooldown backward passes.
     if not forward_only:
@@ -357,12 +357,12 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
 
-            output_tensor_grad = recv_backward(timers)
+            output_tensor_grad = p2p_communication.recv_backward(timers)
 
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)
 
-            send_backward(input_tensor_grad, timers)
+            p2p_communication.send_backward(input_tensor_grad, timers)
 
     return losses_reduced
```
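Every p2p call site above now goes through the module name instead of a bare imported function. Besides shrinking the import block from five lines to one, a qualified call looks up the attribute at call time, so a test (or an experiment) can swap a function on `megatron.p2p_communication` and the schedules pick it up without re-importing. A self-contained sketch of that property; `p2p` here is a throwaway stand-in module, not Megatron's actual API:

```python
import types

# Throwaway stand-in for a communication module.
p2p = types.ModuleType("p2p")
p2p.recv_forward = lambda timers: "tensor-from-previous-stage"

def warmup_step():
    # The attribute lookup happens at call time, so replacing
    # p2p.recv_forward (e.g. with unittest.mock.patch.object) is
    # visible here without re-importing this module.
    return p2p.recv_forward(timers=None)

print(warmup_step())                      # tensor-from-previous-stage
p2p.recv_forward = lambda timers: "stub"  # what a test might install
print(warmup_step())                      # stub
```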
megatron/training.py +5 −5

```diff
@@ -46,6 +46,7 @@
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
 from megatron.schedules import forward_backward_no_pipelining
@@ -288,9 +289,8 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)
 
-    unwrapped_model = model
-    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-        unwrapped_model = unwrapped_model.module
+    unwrapped_model = unwrap_model(model,
+                                   (torchDDP, LocalDDP, FP16Module))
 
     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)
@@ -370,8 +370,8 @@ def train_step(forward_step_func, data_iterator,
             unwrapped_model = model[0]
         elif mpu.is_pipeline_last_stage(ignore_virtual=True):
             unwrapped_model = model[-1]
-        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-            unwrapped_model = unwrapped_model.module
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, FP16Module))
 
         if unwrapped_model.share_word_embeddings:
             word_embeddings_weight = unwrapped_model.word_embeddings_weight()
```
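One detail to keep in mind when reading these call sites against the helper's definition in megatron/utils.py below: training.py passes an explicit tuple `(torchDDP, LocalDDP, FP16Module)`, while the helper's default `module_instances=(torchDDP)` is a parenthesized expression, not a one-element tuple. It still works because `isinstance` accepts either a single class or a tuple of classes:

```python
class Wrapper:
    pass

print((Wrapper) is Wrapper)               # True: parentheses alone make no tuple
print(isinstance(Wrapper(), Wrapper))     # True: a bare class is accepted
print(isinstance(Wrapper(), (Wrapper,)))  # True: so is a tuple of classes
```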
megatron/utils.py +18 −1

```diff
@@ -18,6 +18,7 @@
 import sys
 
 import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
@@ -26,11 +27,25 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
-from megatron.checkpointing import save_checkpoint
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 
 
+def unwrap_model(model, module_instances=(torchDDP)):
+    return_list = True
+    if not isinstance(model, list):
+        model = [model]
+        return_list = False
+    unwrapped_model = []
+    for model_module in model:
+        while isinstance(model_module, module_instances):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    if not return_list:
+        return unwrapped_model[0]
+    return unwrapped_model
+
+
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters """
     # Remove duplicate params.
@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):
 def check_adlr_autoresume_termination(iteration, model,
                                       optimizer, lr_scheduler):
     """Check for autoresume signal and exit if it is received."""
+    from megatron.checkpointing import save_checkpoint
+
     args = get_args()
     autoresume = get_adlr_autoresume()
     # Add barrier to ensure consistnecy.
```
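Moving `from megatron.checkpointing import save_checkpoint` from module level into `check_adlr_autoresume_termination` is what lets checkpointing.py import `utils` at the top of this change without an import cycle: utils.py no longer needs checkpointing.py at import time, only when the autoresume path actually runs. A runnable sketch of the pattern; the module names `mini_utils`/`mini_ckpt` are illustrative stand-ins, not Megatron's:

```python
import os
import sys
import tempfile
import textwrap

d = tempfile.mkdtemp()
with open(os.path.join(d, "mini_ckpt.py"), "w") as f:
    f.write(textwrap.dedent("""\
        # Mirrors checkpointing.py: imports the utils module at top level.
        from mini_utils import unwrap_model

        def save_checkpoint(model):
            return "saved " + repr(unwrap_model(model))
        """))
with open(os.path.join(d, "mini_utils.py"), "w") as f:
    f.write(textwrap.dedent("""\
        # Mirrors utils.py: no top-level import of mini_ckpt, so importing
        # either module first succeeds.
        def unwrap_model(model):
            return getattr(model, "module", model)

        def check_autoresume_termination(model):
            from mini_ckpt import save_checkpoint  # deferred, as in this PR
            return save_checkpoint(model)
        """))
sys.path.insert(0, d)

import mini_utils
print(mini_utils.check_autoresume_termination(object()))  # saved <object ...>
```

Had the `save_checkpoint` import stayed at the top of utils.py, importing checkpointing first would re-enter utils while checkpointing is only partially initialized and fail with an ImportError before `save_checkpoint` is defined.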