Loading megatron/model/classification.py +19 −16 Original line number Diff line number Diff line Loading @@ -17,7 +17,7 @@ import torch from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_last from megatron import mpu from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.language_model import get_language_model Loading Loading @@ -45,6 +45,7 @@ class ClassificationBase(PipelinedMegatronModule): args.num_layers)) # Multi-choice head. if mpu.is_pipeline_last_stage(): self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) self.classification_head = get_linear_layer(args.hidden_size, self.num_classes, Loading Loading @@ -85,6 +86,7 @@ class ClassificationBase(PipelinedMegatronModule): state_dict_[self._language_model_key] \ = self.language_model.state_dict_for_save_checkpoint( destination, prefix, keep_vars) if mpu.is_pipeline_last_stage(): state_dict_[self._classification_head_key] \ = self.classification_head.state_dict( destination, prefix, keep_vars) Loading @@ -95,11 +97,12 @@ class ClassificationBase(PipelinedMegatronModule): self.language_model.load_state_dict( state_dict[self._language_model_key], strict=strict) if mpu.is_pipeline_last_stage(): if self._classification_head_key in state_dict: self.classification_head.load_state_dict( state_dict[self._classification_head_key], strict=strict) else: print_rank_0('***WARNING*** could not find {} in the checkpoint, ' print_rank_last('***WARNING*** could not find {} in the checkpoint, ' 'initializing to random'.format( self._classification_head_key)) Loading megatron/model/multiple_choice.py +18 −15 Original line number Diff line number Diff line Loading @@ -17,7 +17,7 @@ import torch from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_last from megatron import mpu from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.language_model import get_language_model Loading @@ -44,6 +44,7 @@ class MultipleChoiceBase(PipelinedMegatronModule): args.num_layers)) # Multi-choice head. if mpu.is_pipeline_last_stage(): self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) self.multichoice_head = get_linear_layer(args.hidden_size, 1, init_method) Loading Loading @@ -97,6 +98,7 @@ class MultipleChoiceBase(PipelinedMegatronModule): state_dict_[self._language_model_key] \ = self.language_model.state_dict_for_save_checkpoint( destination, prefix, keep_vars) if mpu.is_pipeline_last_stage(): state_dict_[self._multichoice_head_key] \ = self.multichoice_head.state_dict( destination, prefix, keep_vars) Loading @@ -107,11 +109,12 @@ class MultipleChoiceBase(PipelinedMegatronModule): self.language_model.load_state_dict( state_dict[self._language_model_key], strict=strict) if mpu.is_pipeline_last_stage(): if self._multichoice_head_key in state_dict: self.multichoice_head.load_state_dict( state_dict[self._multichoice_head_key], strict=strict) else: print_rank_0('***WARNING*** could not find {} in the checkpoint, ' print_rank_last('***WARNING*** could not find {} in the checkpoint, ' 'initializing to random'.format( self._multichoice_head_key)) Loading Loading
megatron/model/classification.py +19 −16 Original line number Diff line number Diff line Loading @@ -17,7 +17,7 @@ import torch from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_last from megatron import mpu from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.language_model import get_language_model Loading Loading @@ -45,6 +45,7 @@ class ClassificationBase(PipelinedMegatronModule): args.num_layers)) # Multi-choice head. if mpu.is_pipeline_last_stage(): self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) self.classification_head = get_linear_layer(args.hidden_size, self.num_classes, Loading Loading @@ -85,6 +86,7 @@ class ClassificationBase(PipelinedMegatronModule): state_dict_[self._language_model_key] \ = self.language_model.state_dict_for_save_checkpoint( destination, prefix, keep_vars) if mpu.is_pipeline_last_stage(): state_dict_[self._classification_head_key] \ = self.classification_head.state_dict( destination, prefix, keep_vars) Loading @@ -95,11 +97,12 @@ class ClassificationBase(PipelinedMegatronModule): self.language_model.load_state_dict( state_dict[self._language_model_key], strict=strict) if mpu.is_pipeline_last_stage(): if self._classification_head_key in state_dict: self.classification_head.load_state_dict( state_dict[self._classification_head_key], strict=strict) else: print_rank_0('***WARNING*** could not find {} in the checkpoint, ' print_rank_last('***WARNING*** could not find {} in the checkpoint, ' 'initializing to random'.format( self._classification_head_key)) Loading
megatron/model/multiple_choice.py +18 −15 Original line number Diff line number Diff line Loading @@ -17,7 +17,7 @@ import torch from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_last from megatron import mpu from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.language_model import get_language_model Loading @@ -44,6 +44,7 @@ class MultipleChoiceBase(PipelinedMegatronModule): args.num_layers)) # Multi-choice head. if mpu.is_pipeline_last_stage(): self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) self.multichoice_head = get_linear_layer(args.hidden_size, 1, init_method) Loading Loading @@ -97,6 +98,7 @@ class MultipleChoiceBase(PipelinedMegatronModule): state_dict_[self._language_model_key] \ = self.language_model.state_dict_for_save_checkpoint( destination, prefix, keep_vars) if mpu.is_pipeline_last_stage(): state_dict_[self._multichoice_head_key] \ = self.multichoice_head.state_dict( destination, prefix, keep_vars) Loading @@ -107,11 +109,12 @@ class MultipleChoiceBase(PipelinedMegatronModule): self.language_model.load_state_dict( state_dict[self._language_model_key], strict=strict) if mpu.is_pipeline_last_stage(): if self._multichoice_head_key in state_dict: self.multichoice_head.load_state_dict( state_dict[self._multichoice_head_key], strict=strict) else: print_rank_0('***WARNING*** could not find {} in the checkpoint, ' print_rank_last('***WARNING*** could not find {} in the checkpoint, ' 'initializing to random'.format( self._multichoice_head_key)) Loading