megatron/arguments.py  (+37 −4)

@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not'\
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size

@@ -116,6 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
         print('setting global batch size to {}'.format(
             args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when ' \
+            'using interleaved schedule'
+    else:
+        args.virtual_pipeline_model_parallel_size = None

     # Parameters dtype.
     args.params_dtype = torch.float

@@ -203,6 +215,22 @@ def parse_args(extra_args_provider=None, defaults={},
         'for distribute-checkpointed-activations to work you '\
         'need to enable checkpoint-activations'

+    # custom kernel constraints check
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
+        print('WARNING: constraints for invoking optimized'
+              ' fused softmax kernel are not met. We default back to unfused'
+              ' kernel invocations.')
+
     # Load scaled_masked_softmax_fusion_kernels
     if args.masked_softmax_fusion:
         fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()

@@ -478,9 +506,9 @@ def _add_checkpointing_args(parser):
                        help='Output directory to save checkpoints to.')
     group.add_argument('--save-interval', type=int, default=None,
                        help='Number of iterations between checkpoint saves.')
-    group.add_argument('--no-save-optim', action='store_true',
+    group.add_argument('--no-save-optim', action='store_true', default=None,
                        help='Do not save current optimizer.')
-    group.add_argument('--no-save-rng', action='store_true',
+    group.add_argument('--no-save-rng', action='store_true', default=None,
                        help='Do not save current rng state.')
     group.add_argument('--load', type=str, default=None,
                        help='Directory containing a model checkpoint.')

@@ -541,6 +569,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')

@@ -548,6 +578,9 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
+                       help='Use scatter/gather to optimize communication of tensors in pipeline',
+                       dest='scatter_gather_tensors_in_pipeline')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
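The new --num-layers-per-virtual-pipeline-stage argument determines how many model chunks each pipeline rank owns under the interleaved schedule: every rank holds num_layers // pipeline_model_parallel_size layers, and the virtual pipeline size is the number of chunks those layers are split into. A minimal standalone sketch of the same arithmetic, using a hypothetical configuration (24 layers, pipeline size 4, 2 layers per virtual stage) chosen only for illustration:

def virtual_pipeline_size(num_layers, pipeline_model_parallel_size,
                          num_layers_per_virtual_pipeline_stage):
    # Mirrors the check and computation added in parse_args above.
    assert num_layers % num_layers_per_virtual_pipeline_stage == 0, \
        'number of layers is not divisible by number of layers per virtual ' \
        'pipeline stage'
    layers_per_pipeline_rank = num_layers // pipeline_model_parallel_size
    return layers_per_pipeline_rank // num_layers_per_virtual_pipeline_stage

# Hypothetical example: 24 layers on 4 pipeline stages, 2 layers per chunk
# -> each rank holds 6 layers split into 3 model chunks.
print(virtual_pipeline_size(24, 4, 2))  # 3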
megatron/checkpointing.py  (+23 −12)

@@ -21,12 +21,12 @@
 import sys

 import numpy as np

 import torch
 from torch.nn.parallel import DistributedDataParallel as torchDDP

 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches, utils)

 _CHECKPOINT_VERSION = None

@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
-        state_dict['model'] = model.state_dict_for_save_checkpoint()
+        if len(model) == 1:
+            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+        else:
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

         # Optimizer stuff.
         if not args.no_save_optim:

@@ -238,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)
+
     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -324,7 +328,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
-    model.load_state_dict(state_dict['model'], strict=strict)
+    if len(model) == 1:
+        model[0].load_state_dict(state_dict['model'], strict=strict)
+    else:
+        for i in range(len(model)):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering if needed
     checkpoint_version = get_checkpoint_version()

@@ -352,12 +361,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             np.random.set_state(state_dict['np_rng_state'])
             torch.set_rng_state(state_dict['torch_rng_state'])
             torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            # Check for empty states array
+            if not state_dict['rng_tracker_states']:
+                raise KeyError
             mpu.get_cuda_rng_tracker().set_states(
                 state_dict['rng_tracker_states'])
         except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}. '
+            print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
-                         'attempting to load the optimizer state, '
+                         'attempting to load the rng state, '
                          'exiting ...'.format(checkpoint_name))
             sys.exit()

@@ -376,8 +388,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_context_model=False,
     args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     load_path = args.load if from_realm_chkpt else args.ict_load
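With the interleaved schedule the checkpointing code now receives a list of model chunks rather than a single module, so the old isinstance(model, torchDDP) unwrapping is replaced by utils.unwrap_model, and each chunk is saved and loaded under its own 'model%d' key after selecting the matching virtual pipeline rank. The helper's body is not part of this diff; a minimal sketch of what it might look like, assuming it simply strips DDP-style wrappers from a module or from every element of a list:

from torch.nn.parallel import DistributedDataParallel as torchDDP

def unwrap_model(model, module_instances=(torchDDP,)):
    # Accept either a single module or a list of model chunks.
    return_list = isinstance(model, list)
    model_list = model if return_list else [model]
    unwrapped = []
    for module in model_list:
        # Peel off wrapper layers (e.g. torchDDP) until the bare module remains.
        while isinstance(module, module_instances):
            module = module.module
        unwrapped.append(module)
    return unwrapped if return_list else unwrapped[0]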
megatron/initialize.py  (+2 −1)

@@ -133,7 +133,8 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size)
+                                          args.pipeline_model_parallel_size,
+                                          args.virtual_pipeline_model_parallel_size)


 def _init_autoresume():
megatron/model/fused_softmax.py  (+10 −5)

@@ -116,15 +116,20 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     def forward(self, input, mask):
         # [b, np, sq, sk]
+        assert input.dim() == 4
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
-        assert input.dim() == 4
+        attn_batch_size = data_size[0] * data_size[1]

-        # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-            query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
+        # invoke custom kernel
+        if self.input_in_fp16 and mask is not None and \
+            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0

             if self.attn_mask_type == AttnMaskType.causal:
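The forward pass now gates the fused kernel on the same dimension constraints checked in arguments.py: key sequence length in (16, 2048], query sequence length divisible by 4, and batch-times-heads divisible by 4. A standalone restatement of that predicate (not the module's actual API), with an illustrative call:

def use_fused_softmax_kernel(input_in_fp16, mask_present, fusion_enabled,
                             query_seq_len, key_seq_len, attn_batch_size):
    # Same constraints as forward() above: the warp-based optimization and the
    # upper-triangular (causal) optimization need these dimension bounds.
    custom_kernel_constraint = 16 < key_seq_len <= 2048 and \
        query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
    return input_in_fp16 and mask_present and \
        custom_kernel_constraint and fusion_enabled

# Illustrative values: micro batch 8 with 2 attention heads per tensor-parallel
# partition gives attn_batch_size = 16, which satisfies the divisibility check.
print(use_fused_softmax_kernel(True, True, True,
                               query_seq_len=1024, key_seq_len=1024,
                               attn_batch_size=16))  # True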
megatron/model/module.py  (+2 −2)

@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):
     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage(ignore_virtual=True):
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')