megatron/model/bert_model.py  +9 −2

@@ -18,6 +18,7 @@
 import torch
 
 from megatron import get_args
+from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm

@@ -138,7 +139,8 @@ class BertModel(MegatronModule):
             init_method)
         self._binary_head_key = 'binary_head'
 
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None,
+                lm_labels=None):
 
         extended_attention_mask = bert_extended_attention_mask(
             attention_mask, next(self.language_model.parameters()).dtype)

@@ -161,11 +163,16 @@ class BertModel(MegatronModule):
         lm_logits = self.lm_head(
             lm_output, self.language_model.embedding.word_embeddings.weight)
 
+        binary_logits = None
         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
 
-        return lm_logits, None
+        if lm_labels is None:
+            return lm_logits, binary_logits
+        else:
+            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            return lm_loss, binary_logits
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
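Read together, the forward change gives BertModel two return conventions: with lm_labels=None it still hands back (lm_logits, binary_logits), while with labels it returns the per-token loss from mpu.vocab_parallel_cross_entropy applied directly to the (possibly fp16) logits, so the caller never has to materialize a float32 copy of the vocabulary-sized logits. The toy module below is a minimal single-GPU sketch of that control flow, not Megatron code: the Linear heads and F.cross_entropy with reduction='none' stand in for BertLMHead and mpu.vocab_parallel_cross_entropy, and all names and shapes are invented for illustration.

import torch
import torch.nn.functional as F


class ToyBertForPretraining(torch.nn.Module):
    """Toy stand-in for BertModel with the same dual return convention."""

    def __init__(self, hidden, vocab, add_binary_head=True):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab)    # stand-in for BertLMHead
        self.add_binary_head = add_binary_head
        self.binary_head = torch.nn.Linear(hidden, 2)    # NSP/SOP head

    def forward(self, hidden_states, pooled_output, lm_labels=None):
        lm_logits = self.lm_head(hidden_states)              # [b, s, vocab]

        binary_logits = None
        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)  # [b, 2]

        if lm_labels is None:
            # Old convention: return the logits, caller computes the loss.
            return lm_logits, binary_logits
        # New convention: per-token loss computed inside the model;
        # F.cross_entropy(reduction='none') stands in for
        # mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels).
        lm_loss = F.cross_entropy(
            lm_logits.view(-1, lm_logits.size(-1)),
            lm_labels.view(-1),
            reduction='none').view_as(lm_labels)
        return lm_loss, binary_logits


model = ToyBertForPretraining(hidden=8, vocab=32)
hidden_states = torch.randn(2, 5, 8)
pooled_output = torch.randn(2, 8)
lm_labels = torch.randint(0, 32, (2, 5))

lm_logits, sop_logits = model(hidden_states, pooled_output)            # logits path
lm_loss_, sop_logits = model(hidden_states, pooled_output, lm_labels)  # loss path
print(lm_logits.shape, lm_loss_.shape)  # torch.Size([2, 5, 32]) torch.Size([2, 5])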
pretrain_bert.py  +9 −4

@@ -68,6 +68,7 @@ def get_batch(data_iterator):
 
 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()
 
     # Get the batch.

@@ -76,15 +77,19 @@ def forward_step(data_iterator, model):
         = get_batch(data_iterator)
     timers('batch generator').stop()
 
-    # Forward model.
-    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
+    # Forward model. lm_labels
+    if args.fp16_lm_cross_entropy:
+        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
+                                     lm_labels=lm_labels)
+    else:
+        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
+        lm_loss_ = mpu.vocab_parallel_cross_entropy(
+            lm_logits.contiguous().float(), lm_labels.contiguous())
 
     sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                                sentence_order.view(-1).contiguous(),
                                ignore_index=-1)
 
-    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
-                                                lm_labels.contiguous())
 
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
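Both branches of forward_step end up with lm_loss_, an unreduced per-token loss with the same shape as lm_labels: when args.fp16_lm_cross_entropy is set it comes back from the model (presumably still in half precision, as the option name suggests), otherwise it is computed here from the float-cast logits. How that option is registered (for example as a --fp16-lm-cross-entropy command-line flag) is outside this diff. Either way the final masked average is unchanged; the toy numbers below just illustrate that last formula, with invented values and shapes.

import torch

# Per-token LM loss and the MLM loss mask, shape [batch, seq]; numbers invented.
lm_loss_ = torch.tensor([[2.0, 0.5, 1.5],
                         [1.0, 3.0, 0.0]])
loss_mask = torch.tensor([[1.0, 1.0, 0.0],
                          [1.0, 0.0, 0.0]])

# Same formula as the last lines of the hunk: sum the loss over masked
# positions and divide by how many positions were masked.
lm_loss = torch.sum(
    lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(lm_loss)  # tensor(1.1667)  ->  (2.0 + 0.5 + 1.0) / 3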