Commit 78022005 authored by mohammad

added fp16 lm cross entropy to bert

parent 22e3c7e6
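In short, the BERT model can now take lm_labels in its forward pass and compute the vocab-parallel cross entropy directly on the half-precision logits, instead of the pretraining script first casting the logits to fp32. A rough back-of-envelope on why that matters for memory, with assumed BERT-ish sizes (illustrative numbers only, not taken from this commit):

# Size of the [batch, seq, vocab] logits tensor that feeds the LM loss.
# All sizes below are assumptions for illustration only.
batch, seq, vocab = 32, 512, 30522
elems = batch * seq * vocab                     # ~5.0e8 logit elements
print(elems * 2 / 2**30, 'GiB in fp16')         # ~0.93 GiB
print(elems * 4 / 2**30, 'GiB after .float()')  # ~1.86 GiB extra copy avoided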
+9 −2
@@ -18,6 +18,7 @@
import torch

from megatron import get_args
+from megatron import mpu
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
@@ -138,7 +139,8 @@ class BertModel(MegatronModule):
                                                init_method)
            self._binary_head_key = 'binary_head'

-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
+    def forward(self, input_ids, attention_mask,
+                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
@@ -161,11 +163,16 @@ class BertModel(MegatronModule):
        lm_logits = self.lm_head(
            lm_output, self.language_model.embedding.word_embeddings.weight)

+        binary_logits = None
        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)

+        if lm_labels is None:
+            return lm_logits, binary_logits
+        else:
+            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            return lm_loss, binary_logits

-        return lm_logits, None

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
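For readers outside the diff, a minimal single-GPU sketch of the new forward control flow. mpu.vocab_parallel_cross_entropy is replaced here by a plain torch.nn.functional.cross_entropy stand-in and all names and shapes are illustrative; this is not the model code itself:

import torch
import torch.nn.functional as F

def bert_forward_sketch(lm_logits, binary_logits, lm_labels=None):
    # Mirrors the new control flow: without labels, return the logits;
    # with labels, return the per-token LM loss instead, so the caller
    # never has to materialize an fp32 copy of the full logits tensor.
    if lm_labels is None:
        return lm_logits, binary_logits
    # Stand-in for mpu.vocab_parallel_cross_entropy (which also accepts
    # fp16 logits); this single-GPU stand-in casts to float for numerics.
    lm_loss = F.cross_entropy(
        lm_logits.reshape(-1, lm_logits.size(-1)).float(),
        lm_labels.reshape(-1),
        reduction='none').view_as(lm_labels)
    return lm_loss, binary_logits

# Example call with made-up shapes: [batch=2, seq=4, vocab=8] logits.
logits = torch.randn(2, 4, 8, dtype=torch.float16)
labels = torch.randint(0, 8, (2, 4))
loss, _ = bert_forward_sketch(logits, None, lm_labels=labels)
print(loss.shape)  # torch.Size([2, 4]): per-token loss, reduced later by the caller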
+9 −4
@@ -68,6 +68,7 @@ def get_batch(data_iterator):

def forward_step(data_iterator, model):
    """Forward step."""
+    args = get_args()
    timers = get_timers()

    # Get the batch.
@@ -76,15 +77,19 @@ def forward_step(data_iterator, model):
        = get_batch(data_iterator)
    timers('batch generator').stop()

-    # Forward model.
+    # Forward model. Pass lm_labels when fp16 LM cross entropy is enabled.
+    if args.fp16_lm_cross_entropy:
+        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
+                                     lm_labels=lm_labels)
+    else:
+        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
+        lm_loss_ = mpu.vocab_parallel_cross_entropy(
+            lm_logits.contiguous().float(), lm_labels.contiguous())

    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                               sentence_order.view(-1).contiguous(),
                               ignore_index=-1)

-    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
-                                                lm_labels.contiguous())
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
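Both branches end in the same masked averaging over the per-token losses (the last two lines above). A self-contained sketch of that reduction with made-up values, to make the loss_mask semantics explicit; the helper name and example numbers are illustrative, the formula is the one from the diff:

import torch

def masked_lm_loss(lm_loss_, loss_mask):
    # lm_loss_:  per-token LM loss, shape [batch, seq]
    # loss_mask: 1.0 at the masked-LM positions that should count, 0.0 elsewhere
    # Average only over the positions selected by the mask.
    return torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

# Made-up example: three masked positions out of four tokens.
per_token = torch.tensor([[2.0, 0.5], [1.5, 3.0]])
mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
print(masked_lm_loss(per_token, mask))  # (2.0 + 1.5 + 3.0) / 3 = 2.1667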