megatron/model/__init__.py  +2 −10

@@ -16,15 +16,7 @@
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
 from .distributed import *

-from .bert_model import (BertModel,
-                         BertModelFirstStage,
-                         BertModelIntermediateStage,
-                         BertModelLastStage)
-from .gpt_model import (GPTModel,
-                        GPTModelFirstStage,
-                        GPTModelIntermediateStage,
-                        GPTModelLastStage)
+from .bert_model import BertModel
+from .gpt_model import GPTModel
 from .language_model import get_language_model
 from .module import Float16Module


megatron/model/bert_model.py  +34 −84

@@ -121,17 +121,23 @@ def post_language_model_processing(lm_output, pooled_output,
     return lm_loss, binary_logits


-class BertModelBase(MegatronModule):
+class BertModel(MegatronModule):
     """Bert Language model."""

-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=2,
+                 add_binary_head=True,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(BertModel, self).__init__()
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process

         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,

@@ -142,10 +148,12 @@ class BertModelBase(MegatronModule):
             add_pooler=self.add_binary_head,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         self.initialize_word_embeddings(init_method_normal)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.lm_head = BertLMHead(
                 self.word_embeddings_weight().size(0),
                 args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)

@@ -156,26 +164,29 @@ class BertModelBase(MegatronModule):
                 init_method)
             self._binary_head_key = 'binary_head'

+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, bert_model_input, attention_mask,
                 tokentype_ids=None, lm_labels=None):

         extended_attention_mask = bert_extended_attention_mask(attention_mask)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = bert_model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [bert_model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        input_ids = bert_model_input
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids)
+
+        if self.post_process and self.add_binary_head:
             lm_output, pooled_output = lm_output
         else:
             pooled_output = None

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             return post_language_model_processing(lm_output, pooled_output,
                                                   self.lm_head, self.binary_head,
                                                   lm_labels,

@@ -194,15 +205,15 @@ class BertModelBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._lm_head_key] \
                 = self.lm_head.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
             state_dict_[self._binary_head_key] \
                 = self.binary_head.state_dict(destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_

@@ -212,74 +223,13 @@ class BertModelBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.lm_head.load_state_dict(
                 state_dict[self._lm_head_key], strict=strict)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
             self.binary_head.load_state_dict(
                 state_dict[self._binary_head_key], strict=strict)
         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
-
-
-class BertModel(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None, lm_labels=None):
-        return super(BertModel, self).forward(
-            input_ids, attention_mask,
-            tokentype_ids=tokentype_ids, lm_labels=lm_labels)
-
-
-class BertModelFirstStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(BertModelFirstStage, self).forward(
-            input_ids, attention_mask, tokentype_ids=tokentype_ids)
-
-
-class BertModelIntermediateStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(BertModelIntermediateStage, self).forward(
-            hidden_state, attention_mask)
-
-
-class BertModelLastStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask, lm_labels=None):
-        return super(BertModelLastStage, self).forward(
-            hidden_state, attention_mask, lm_labels=lm_labels)
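
With the stage-specific wrapper classes removed, every pipeline rank now builds the same BertModel and only the pre_process/post_process flags differ by rank. A minimal sketch of what a model provider could look like under this scheme (illustrative only; the provider name is hypothetical and not part of this diff):

# Minimal sketch: one BertModel class serves first, intermediate, and last
# pipeline stages; the flags come from this rank's position in the pipeline.
from megatron import mpu
from megatron.model import BertModel

def provide_bert_model():
    # Hypothetical provider; constructor arguments match the merged
    # BertModel.__init__ shown in the diff above.
    return BertModel(num_tokentypes=2,
                     add_binary_head=True,
                     parallel_output=True,
                     pre_process=mpu.is_pipeline_first_stage(),
                     post_process=mpu.is_pipeline_last_stage())
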
megatron/model/classification.py  +29 −71

@@ -28,13 +28,19 @@ from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule


-class ClassificationBase(MegatronModule):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationBase, self).__init__(share_word_embeddings=False)
+class Classification(MegatronModule):
+
+    def __init__(self,
+                 num_classes,
+                 num_tokentypes=2,
+                 pre_process=True,
+                 post_process=True):
+        super(Classification, self).__init__(share_word_embeddings=False)
         args = get_args()

         self.num_classes = num_classes
+        self.pre_process = pre_process
+        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(

@@ -43,31 +49,35 @@ class ClassificationBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.classification_head = get_linear_layer(args.hidden_size,
                                                         self.num_classes,
                                                         init_method)
             self._classification_head_key = 'classification_head'

+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, model_input, attention_mask, tokentype_ids=None):

         extended_attention_mask = bert_extended_attention_mask(attention_mask)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        input_ids = model_input
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids)
+
+        if self.post_process:
             _, pooled_output = lm_output
             classification_output = self.classification_dropout(pooled_output)
             classification_logits = self.classification_head(classification_output)

@@ -87,7 +97,7 @@ class ClassificationBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._classification_head_key] \
                 = self.classification_head.state_dict(
                     destination, prefix, keep_vars)

@@ -98,7 +108,7 @@ class ClassificationBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self._classification_head_key in state_dict:
                 self.classification_head.load_state_dict(
                     state_dict[self._classification_head_key], strict=strict)

@@ -106,55 +116,3 @@ class ClassificationBase(MegatronModule):
             print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                             'initializing to random'.format(
                                 self._classification_head_key))
-
-
-class Classification(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(Classification, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(Classification, self).forward(
-            input_ids, attention_mask, tokentype_ids=tokentype_ids)
-
-
-class ClassificationFirstStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationFirstStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(ClassificationFirstStage, self).forward(
-            input_ids, attention_mask, tokentype_ids=tokentype_ids)
-
-
-class ClassificationIntermediateStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationIntermediateStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationIntermediateStage, self).forward(
-            hidden_state, attention_mask)
-
-
-class ClassificationLastStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationLastStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationLastStage, self).forward(
-            hidden_state, attention_mask)
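
The set_input_tensor hook added to each of these classes is what replaces the old per-stage forward signatures: instead of calling forward with a hidden_state argument on non-first stages, the pipeline schedule stashes the activation received from the previous stage on the model, and the embedding path is skipped when pre_process is False. A hedged sketch of that calling pattern (the schedule itself lives in the training loop, not in this diff, and the helper name below is hypothetical):

# Illustrative only: how a non-first pipeline stage would be driven.
def run_forward_on_non_first_stage(model, recv_from_prev_stage, *forward_args):
    # The tensor received from the previous pipeline stage becomes the
    # transformer input; with pre_process=False the embedding lookup is
    # skipped, so the token/position arguments in forward_args are
    # effectively unused on this stage.
    model.set_input_tensor(recv_from_prev_stage)
    return model(*forward_args)
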
megatron/model/gpt_model.py  +25 −86

@@ -57,14 +57,20 @@ def post_language_model_processing(lm_output, labels, logit_weights,
     return loss


-class GPTModelBase(MegatronModule):
+class GPTModel(MegatronModule):
     """GPT-2 Language model."""

-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(GPTModel, self).__init__()
         args = get_args()

         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

         self.language_model, self._language_model_key = get_language_model(

@@ -73,24 +79,27 @@ class GPTModelBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                          args.num_layers))
+                                                          args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         self.initialize_word_embeddings(init_method_normal)

-    def forward(self, gpt_model_input, attention_mask, labels=None,
+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
+
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

-        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = gpt_model_input
-            args = [input_ids, position_ids, attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [gpt_model_input, attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            attention_mask,
+            layer_past=layer_past,
+            get_key_value=get_key_value)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),

@@ -109,7 +118,7 @@ class GPTModelBase(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_

@@ -118,79 +127,9 @@ class GPTModelBase(MegatronModule):
         """Customized load."""

         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
             state_dict = state_dict[self._language_model_key]
         self.language_model.load_state_dict(state_dict, strict=strict)
-
-
-class GPTModel(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModel, self).__init__(
-            num_tokentypes=num_tokentypes, parallel_output=parallel_output)
-
-    def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModel, self).forward(
-            (input_ids, position_ids), attention_mask,
-            labels=labels, tokentype_ids=tokentype_ids,
-            layer_past=layer_past, get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
-
-
-class GPTModelFirstStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(GPTModelFirstStage, self).forward(
-            (input_ids, position_ids), attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past, get_key_value=get_key_value)
-
-
-class GPTModelIntermediateStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(GPTModelIntermediateStage, self).forward(
-            hidden_state, attention_mask,
-            layer_past=layer_past, get_key_value=get_key_value)
-
-
-class GPTModelLastStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes, parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask, labels=None,
-                layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModelLastStage, self).forward(
-            hidden_state, attention_mask,
-            labels=labels, layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)


megatron/model/language_model.py  +41 −162
(diff collapsed in the preview; changes not shown)
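
The same pattern applies to GPT: the packed gpt_model_input tuple is gone, forward takes input_ids and position_ids explicitly on every stage, and labels are only turned into a loss on ranks constructed with post_process=True. A minimal sketch of stage construction and a training-style call (illustrative names, not code from this diff):

# Minimal sketch: the merged GPTModel is built identically on every stage.
from megatron import mpu
from megatron.model import GPTModel

def provide_gpt_model():
    return GPTModel(num_tokentypes=0,
                    parallel_output=True,
                    pre_process=mpu.is_pipeline_first_stage(),
                    post_process=mpu.is_pipeline_last_stage())

def gpt_forward_step(model, tokens, position_ids, attention_mask, labels):
    # On the last stage (post_process=True) this returns the loss from
    # post_language_model_processing; on earlier stages it returns the hidden
    # states that the pipeline schedule sends on to the next stage.
    return model(tokens, position_ids, attention_mask, labels=labels)
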