megatron/arguments.py +9 −0

@@ -80,6 +80,12 @@ def parse_args(extra_args_provider=None, defaults={},
                   args.world_size, args.data_parallel_size,
                   args.tensor_model_parallel_size,
                   args.pipeline_model_parallel_size), flush=True)
+    if args.pipeline_model_parallel_size > 1:
+        if args.pipeline_model_parallel_split_rank is not None:
+            assert args.pipeline_model_parallel_split_rank < \
+                    args.pipeline_model_parallel_size, 'split rank needs'\
+                    ' to be less than pipeline model parallel size ({})'.format(
+                             args.pipeline_model_parallel_size)
 
     # Deprecated arguments
     assert args.batch_size is None, '--batch-size argument is no longer ' \

@@ -601,6 +607,9 @@ def _add_distributed_args(parser):
                        help='Degree of tensor model parallelism.')
     group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                        help='Degree of pipeline model parallelism.')
+    group.add_argument('--pipeline-model-parallel-split-rank',
+                       type=int, default=None,
+                       help='Rank where encoder and decoder should be split.')
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')

megatron/initialize.py +2 −1

@@ -190,7 +190,8 @@ def _initialize_distributed():
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                           args.pipeline_model_parallel_size,
-                                          args.virtual_pipeline_model_parallel_size)
+                                          args.virtual_pipeline_model_parallel_size,
+                                          args.pipeline_model_parallel_split_rank)
 
 
 def _init_autoresume():

megatron/model/__init__.py +1 −0

@@ -21,3 +21,4 @@ from .gpt_model import GPTModel
 from .t5_model import T5Model
 from .language_model import get_language_model
 from .module import Float16Module
+from .enums import ModelType

megatron/model/enums.py +4 −0

@@ -15,6 +15,10 @@ import enum
 
+class ModelType(enum.Enum):
+    encoder_or_decoder = 1
+    encoder_and_decoder = 2
+
 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2

megatron/model/language_model.py +109 −56

@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
-                       scaled_init_method=None, add_decoder=False,
+                       scaled_init_method=None, add_encoder=True,
+                       add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal,
                        pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""

@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
         scaled_init_method,
         encoder_attn_mask_type,
         num_tokentypes=num_tokentypes,
+        add_encoder=add_encoder,
         add_decoder=add_decoder,
         decoder_attn_mask_type=decoder_attn_mask_type,
         add_pooler=add_pooler,

@@ -159,6 +161,16 @@ class Embedding(MegatronModule):
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
 
+    def zero_parameters(self):
+        """Zero out all parameters in embedding."""
+        self.word_embeddings.weight.data.fill_(0)
+        self.word_embeddings.weight.shared = True
+        self.position_embeddings.weight.data.fill_(0)
+        self.position_embeddings.weight.shared = True
+        if self.num_tokentypes > 0:
+            self.tokentype_embeddings.weight.data.fill_(0)
+            self.tokentype_embeddings.weight.shared = True
+
     def add_tokentype_embeddings(self, num_tokentypes):
         """Add token-type embedding. This function is provided so we can add
         token-type embeddings in case the pretrained model does not have it.

@@ -273,6 +285,7 @@ class TransformerLanguageModel(MegatronModule):
                  output_layer_init_method,
                  encoder_attn_mask_type,
                  num_tokentypes=0,
+                 add_encoder=True,
                  add_decoder=False,
                  decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False,

@@ -286,10 +299,12 @@ class TransformerLanguageModel(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.add_encoder = add_encoder
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
+        self.encoder_hidden_state = None
 
         # Embeddings.
         if self.pre_process:

@@ -302,6 +317,9 @@ class TransformerLanguageModel(MegatronModule):
             self._embedding_key = 'embedding'
 
         # Transformer.
+        # Encoder (usually set to True, False if part of an encoder-decoder
+        # architecture and in encoder-only stage).
+        if self.add_encoder:
             self.encoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,

@@ -310,17 +328,26 @@ class TransformerLanguageModel(MegatronModule):
                 post_process=self.post_process
             )
             self._encoder_key = 'encoder'
+        else:
+            self.encoder = None
 
-        # Decoder
+        # Decoder (usually set to False, True if part of an encoder-decoder
+        # architecture and in decoder-only stage).
         if self.add_decoder:
+            # Temporary assertion until we verify correctness of pipeline parallelism
+            # implementation of T5.
             assert args.pipeline_model_parallel_size == 1, \
                 'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
-                self_attn_mask_type=self.decoder_attn_mask_type)
+                self_attn_mask_type=self.decoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process)
             self._decoder_key = 'decoder'
+        else:
+            self.decoder = None
 
         if self.post_process:
             # Pooler.

@@ -330,7 +357,25 @@ class TransformerLanguageModel(MegatronModule):
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
-        self.encoder.set_input_tensor(input_tensor)
+        if self.add_encoder and self.add_decoder:
+            assert len(input_tensor) == 1, \
+                'input_tensor should only be length 1 for stage with both encoder and decoder'
+            self.encoder.set_input_tensor(input_tensor[0])
+        elif self.add_encoder:
+            assert len(input_tensor) == 1, \
+                'input_tensor should only be length 1 for stage with only encoder'
+            self.encoder.set_input_tensor(input_tensor[0])
+        elif self.add_decoder:
+            if len(input_tensor) == 2:
+                self.decoder.set_input_tensor(input_tensor[0])
+                self.encoder_hidden_state = input_tensor[1]
+            elif len(input_tensor) == 1:
+                self.decoder.set_input_tensor(None)
+                self.encoder_hidden_state = input_tensor[0]
+            else:
+                raise Exception('input_tensor must have either length 1 or 2')
+        else:
+            raise Exception('Stage must have at least either encoder or decoder')
 
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,

@@ -340,21 +385,23 @@ class TransformerLanguageModel(MegatronModule):
                 pooling_sequence_index=0, enc_hidden_states=None,
                 output_enc_hidden=False):
 
-        # Embeddings.
+        # Encoder embedding.
         if self.pre_process:
-            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
+            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                               tokentype_ids=tokentype_ids)
-            encoder_input = embedding_output
         else:
             encoder_input = None
 
-        # encoder.
+        # Run encoder.
         if enc_hidden_states is None:
+            if self.encoder is not None:
                 encoder_output = self.encoder(
                     encoder_input,
                     enc_attn_mask,
                     set_inference_key_value_memory=set_inference_key_value_memory,
                     inference_max_sequence_len=inference_max_sequence_len)
+            else:
+                encoder_output = self.encoder_hidden_state
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)

@@ -372,12 +419,16 @@ class TransformerLanguageModel(MegatronModule):
         else:
             return encoder_output
 
-        # Decoder Embedding
-        dec_embedding_output = self.embedding(dec_input_ids,
+        # Decoder embedding.
+        if self.pre_process:
+            decoder_input = self.embedding(dec_input_ids,
                                            dec_position_ids)
-        # decoder
+        else:
+            decoder_input = None
+
+        # Run decoder.
         decoder_output = self.decoder(
-            dec_embedding_output,
+            decoder_input,
             dec_attn_mask,
             encoder_output=encoder_output,
             enc_dec_attn_mask=enc_dec_attn_mask,

@@ -398,6 +449,7 @@ class TransformerLanguageModel(MegatronModule):
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
+        if self.add_encoder:
             state_dict_[self._encoder_key] \
                 = self.encoder.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)

@@ -429,19 +481,20 @@ class TransformerLanguageModel(MegatronModule):
             self.embedding.load_state_dict(state_dict_, strict=strict)
 
         # Encoder.
+        if self.add_encoder:
             if self._encoder_key in state_dict:
                 state_dict_ = state_dict[self._encoder_key]
-            # for backward compatibility.
+            # For backward compatibility.
             elif 'transformer' in state_dict:
                 state_dict_ = state_dict['transformer']
             else:
-                # for backward compatibility.
+                # For backward compatibility.
                 state_dict_ = {}
                 for key in state_dict.keys():
                     if 'transformer.' in key:
                         state_dict_[key.split('transformer.')[1]] = state_dict[key]
 
-            # for backward compatibility.
+            # For backward compatibility.
             state_dict_self_attention = {}
             for key in state_dict_.keys():
                 if '.attention.' in key:

@@ -453,14 +506,14 @@ class TransformerLanguageModel(MegatronModule):
             self.encoder.load_state_dict(state_dict_, strict=strict)
 
+        # Pooler.
         if self.post_process:
-            # pooler
             if self.add_pooler:
                 assert 'pooler' in state_dict, \
                     'could not find data for pooler in the checkpoint'
                 self.pooler.load_state_dict(state_dict[self._pooler_key],
                                             strict=strict)
 
-        # decoder
+        # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
                 'could not find data for pooler in the checkpoint'
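Not part of the diff itself, but for orientation: the new --pipeline-model-parallel-split-rank argument partitions the pipeline stages of an encoder-and-decoder model between its two halves, and the only constraint the diff enforces is that the split rank be smaller than the pipeline-model-parallel size. A minimal sketch of that mapping, assuming the reading that ranks below the split rank hold encoder layers and the remaining ranks hold decoder layers (the helper name stage_role is hypothetical, not Megatron API):

def stage_role(pipeline_rank, pipeline_size, split_rank):
    """Which sub-model a pipeline stage would build for an encoder_and_decoder model."""
    # Mirrors the validation added in megatron/arguments.py.
    assert split_rank is None or split_rank < pipeline_size, \
        'split rank needs to be less than pipeline model parallel size'
    if pipeline_size == 1 or split_rank is None:
        return 'encoder_and_decoder'   # a single stage holds both halves
    if pipeline_rank < split_rank:
        return 'encoder_only'          # add_encoder=True, add_decoder=False
    return 'decoder_only'              # add_encoder=False, add_decoder=True


if __name__ == '__main__':
    # With 4 pipeline stages and --pipeline-model-parallel-split-rank 2,
    # ranks 0-1 would build the encoder and ranks 2-3 the decoder.
    print([stage_role(rank, 4, 2) for rank in range(4)])

The split rank itself is simply threaded through to mpu.initialize_model_parallel in megatron/initialize.py, as the diff above shows.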
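The reworked set_input_tensor() also defines a small contract for decoder-only stages: they may receive either two tensors or one. A stand-in sketch of that branching, with comments giving one reading of the two cases (DecoderStageStub is illustrative only, not a Megatron class):

class DecoderStageStub:
    """Stand-in for a decoder-only pipeline stage."""

    def __init__(self):
        self.decoder_input = None
        self.encoder_hidden_state = None

    def set_input_tensor(self, input_tensor):
        if len(input_tensor) == 2:
            # A later decoder stage gets the previous decoder stage's
            # activations plus the encoder's final hidden state.
            self.decoder_input = input_tensor[0]
            self.encoder_hidden_state = input_tensor[1]
        elif len(input_tensor) == 1:
            # The first decoder stage embeds its own tokens, so it only
            # needs the encoder hidden state for cross-attention.
            self.decoder_input = None
            self.encoder_hidden_state = input_tensor[0]
        else:
            raise Exception('input_tensor must have either length 1 or 2')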