Commit e270f68a authored by Jared Casper

Addressed MR comments, mostly adding comments to code.

parent 64a83fb5
megatron/model/bert_model.py +3 −2
@@ -165,6 +165,7 @@ class BertModel(MegatronModule):
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, bert_model_input, attention_mask,
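Every hunk in this commit documents the same pattern: under pipeline-model parallelism, only the first stage computes embeddings, so each later stage receives its activations from the previous stage, and the training schedule hands them to the model via set_input_tensor before running forward. A minimal sketch of that calling sequence, assuming a recv_forward helper that returns the activations sent by the previous stage (the helper and driver function are illustrative, not Megatron's actual schedule API):

def run_pipeline_stage(model, batch, is_first_stage, recv_forward):
    """Illustrative driver for one pipeline stage (not Megatron's schedule)."""
    if is_first_stage:
        # First stage: forward() starts from token IDs and embeddings,
        # so there is no upstream tensor to stash.
        model.set_input_tensor(None)
    else:
        # Later stages: stash the activations received from the previous
        # stage; forward() will use them in place of embedding output.
        model.set_input_tensor(recv_forward())
    return model(*batch)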
megatron/model/classification.py +3 −2
@@ -62,6 +62,7 @@ class Classification(MegatronModule):
            self._classification_head_key = 'classification_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):
megatron/model/gpt_model.py +1 −0
@@ -86,6 +86,7 @@ class GPTModel(MegatronModule):
        self.initialize_word_embeddings(init_method_normal)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
megatron/model/language_model.py +1 −0
@@ -329,6 +329,7 @@ class TransformerLanguageModel(MegatronModule):
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""
        self.encoder.set_input_tensor(input_tensor)

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
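The docstrings all point to megatron.model.transformer.set_input_tensor(), where this delegation chain ends: the transformer stashes the tensor and substitutes it for the embedding output on any stage that is not first in the pipeline. A hedged sketch of that end of the chain (the pre_process flag and input_tensor attribute are assumptions for illustration, not quotes from the implementation):

class TransformerSketch:
    """Illustrative stand-in for the transformer in megatron.model.transformer."""

    def __init__(self, pre_process):
        # pre_process is taken to mean "this stage computes embeddings".
        self.pre_process = pre_process
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Stash activations received from the previous pipeline stage.
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask):
        if not self.pre_process:
            # Not the first stage: start from the stashed tensor rather
            # than the hidden_states computed on this stage.
            hidden_states = self.input_tensor
        # ... run the transformer layers on hidden_states ...
        return hidden_states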
megatron/model/multiple_choice.py +1 −0
@@ -59,6 +59,7 @@ class MultipleChoice(MegatronModule):
            self._multichoice_head_key = 'multichoice_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):