megatron/model/bert_model.py  (+3 −2)

@@ -165,6 +165,7 @@ class BertModel(MegatronModule):
         self._binary_head_key = 'binary_head'
 
     def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, bert_model_input, attention_mask,
megatron/model/classification.py  (+3 −2)

@@ -62,6 +62,7 @@ class Classification(MegatronModule):
         self._classification_head_key = 'classification_head'
 
     def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, model_input, attention_mask, tokentype_ids=None):
megatron/model/gpt_model.py  (+1 −0)

@@ -86,6 +86,7 @@ class GPTModel(MegatronModule):
         self.initialize_word_embeddings(init_method_normal)
 
     def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
megatron/model/language_model.py  (+1 −0)

@@ -329,6 +329,7 @@ class TransformerLanguageModel(MegatronModule):
         self._pooler_key = 'pooler'
 
     def set_input_tensor(self, input_tensor):
+        """ See megatron.model.transformer.set_input_tensor()"""
         self.encoder.set_input_tensor(input_tensor)
 
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
megatron/model/multiple_choice.py  (+1 −0)

@@ -59,6 +59,7 @@ class MultipleChoice(MegatronModule):
         self._multichoice_head_key = 'multichoice_head'
 
     def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, model_input, attention_mask, tokentype_ids=None):
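Taken together, the diffs document one delegation chain: each top-level model forwards set_input_tensor() to its inner language_model, and TransformerLanguageModel forwards it to its encoder, so every new docstring points at the canonical explanation in megatron.model.transformer (where, with pipeline parallelism, the tensor received from the previous stage replaces forward()'s normal input). The sketch below illustrates that chain only; the *Stub classes, the string stand-in for a tensor, and the comments are illustrative assumptions, not the actual Megatron-LM implementation.

# Minimal sketch of the set_input_tensor delegation shown in the diffs
# (hypothetical stand-ins for the classes in megatron/model/).

class ParallelTransformerStub:
    """Stand-in for megatron.model.transformer.ParallelTransformer."""

    def __init__(self):
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # With pipeline parallelism, this stage's input arrives via
        # communication from the previous stage rather than as a forward()
        # argument, so it is stashed here and consumed inside forward().
        self.input_tensor = input_tensor


class TransformerLanguageModelStub:
    """Stand-in for TransformerLanguageModel: forwards to its encoder."""

    def __init__(self):
        self.encoder = ParallelTransformerStub()

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.encoder.set_input_tensor(input_tensor)


class GPTModelStub:
    """Stand-in for GPTModel/BertModel/etc.: forwards to language_model."""

    def __init__(self):
        self.language_model = TransformerLanguageModelStub()

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)


if __name__ == "__main__":
    model = GPTModelStub()
    # In real training this would be the tensor communicated from the
    # previous pipeline stage; a string suffices to show the plumbing.
    model.set_input_tensor("tensor-from-previous-pipeline-stage")
    assert model.language_model.encoder.input_tensor is not None

One docstring per wrapper keeps the behavior undocumented in only one place: the transformer-level set_input_tensor() carries the full explanation, and every caller simply points there instead of duplicating it.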