Commit 197c132e authored by mohammad

addressed Jared's comments

parent 78022005
megatron/arguments.py: +3 −0
@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
         if args.num_unique_layers < args.num_layers:
             assert args.DDP_impl == 'local', \
                 'torch-DDP does not work with parameters sharing.'
+    # Mixed precision checks.
+    if args.fp16_lm_cross_entropy:
+        assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
 
     _print_args(args)
     return args
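
As a sanity check on the new guard, here is a minimal standalone sketch of what it enforces. The argparse setup is illustrative only (Megatron's real parse_args() defines far more options); just the two flag names come from the hunk above.

import argparse

# Illustrative parser; the real parse_args() defines many more options.
parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--fp16-lm-cross-entropy', action='store_true')

# With both flags set the check passes; dropping --fp16 would trip the
# assertion, since the half-precision loss path presupposes an fp16 model.
args = parser.parse_args(['--fp16', '--fp16-lm-cross-entropy'])
if args.fp16_lm_cross_entropy:
    assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'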
megatron/model/bert_model.py: +7 −1
@@ -115,6 +115,7 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         args = get_args()
 
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
@@ -170,7 +171,12 @@ class BertModel(MegatronModule):
         if lm_labels is None:
             return lm_logits, binary_logits
         else:
-            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
+            if self.fp16_lm_cross_entropy:
+                assert lm_logits.dtype == torch.half
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            else:
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                           lm_labels)
             return lm_loss, binary_logits
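
The assert pins down the new contract: with fp16_lm_cross_entropy set, the logits must already be half precision and go into the loss unconverted, skipping the fp32 upcast of a [batch, seq, vocab] tensor. A rough single-device sketch of the two branches, with plain F.cross_entropy standing in for mpu.vocab_parallel_cross_entropy (which additionally handles the model-parallel vocab split); all shapes are made up.

import torch
import torch.nn.functional as F

vocab = 32
logits = torch.randn(4, vocab)          # toy logits; real ones are [b, s, vocab]
labels = torch.randint(0, vocab, (4,))

# else-branch above: upcast to fp32 before computing the loss.
loss_fp32 = F.cross_entropy(logits.float(), labels)

# fp16 branch: loss computed directly on half tensors. Half-precision
# kernels generally need a GPU, so this part only runs when one exists.
if torch.cuda.is_available():
    loss_fp16 = F.cross_entropy(logits.half().cuda(), labels.cuda())
    print(loss_fp32.item(), loss_fp16.item())   # close, but not bit-identical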


megatron/model/gpt2_model.py: +6 −1
@@ -40,6 +40,7 @@ class GPT2Model(MegatronModule):
         args = get_args()
 
         self.parallel_output = parallel_output
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
 
         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=gpt2_attention_mask_func,
@@ -79,7 +80,11 @@ class GPT2Model(MegatronModule):
         if labels is None:
             return output
         else:
-            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+            if self.fp16_lm_cross_entropy:
+                assert output.dtype == torch.half
+                loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            else:
+                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
             return loss
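
GPT2Model now follows the same calling convention as BertModel: forward returns logits when labels is None and unreduced per-token losses otherwise. A toy module sketching that interface (ToyLM and all sizes are invented; ordinary cross entropy replaces Megatron's vocab-parallel version):

import torch
import torch.nn.functional as F

class ToyLM(torch.nn.Module):
    """Invented stand-in for GPT2Model's logits-or-losses forward."""
    def __init__(self, vocab=32, hidden=16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, hidden)
        self.head = torch.nn.Linear(hidden, vocab)

    def forward(self, tokens, labels=None):
        output = self.head(self.embed(tokens))      # [b, s, vocab]
        if labels is None:
            return output
        # Loss branch mirrors the hunk above, minus model parallelism.
        losses = F.cross_entropy(output.float().view(-1, output.size(-1)),
                                 labels.view(-1), reduction='none')
        return losses.view_as(labels)

model = ToyLM()
tokens = torch.randint(0, 32, (2, 8))
print(model(tokens).shape)            # logits: torch.Size([2, 8, 32])
print(model(tokens, tokens).shape)    # per-token losses: torch.Size([2, 8])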


pretrain_bert.py: +5 −9
@@ -78,16 +78,12 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model.
-    if args.fp16_lm_cross_entropy:
-        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
-                                     lm_labels=lm_labels)
-    else:
-        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
-        lm_loss_ = mpu.vocab_parallel_cross_entropy(
-            lm_logits.contiguous().float(), lm_labels.contiguous())
+    lm_loss_, sop_logits = model(tokens, padding_mask,
+                                 tokentype_ids=types,
+                                 lm_labels=lm_labels)
 
-    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
-                               sentence_order.view(-1).contiguous(),
+    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                               sentence_order.view(-1),
                                ignore_index=-1)
 
     lm_loss = torch.sum(
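
With the fp16 branching moved into the model, forward_step simply passes lm_labels through. On the SOP side, the change just drops .contiguous() calls the loss does not need. A toy illustration of that loss, mainly of what ignore_index=-1 does (all values invented):

import torch
import torch.nn.functional as F

sop_logits = torch.randn(3, 2)             # binary sentence-order logits
sentence_order = torch.tensor([0, 1, -1])  # third sample carries no label

# ignore_index=-1 excludes the unlabeled sample from the average.
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                           sentence_order.view(-1),
                           ignore_index=-1)
print(sop_loss.item())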
pretrain_gpt2.py: +2 −7
@@ -82,12 +82,7 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model.
-    if args.fp16_lm_cross_entropy:
-        losses = model(tokens, position_ids, attention_mask, labels=labels)
-    else:
-        output = model(tokens, position_ids, attention_mask)
-        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-                                                  labels)
-
+    losses = model(tokens, position_ids, attention_mask, labels=labels)
+
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
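
The per-token losses coming back from the model are then averaged only over positions the loss mask keeps. A worked example of that last line with made-up numbers:

import torch

losses = torch.tensor([[0.5, 1.0, 2.0, 4.0]])   # per-token losses (made up)
loss_mask = torch.tensor([[1., 1., 1., 0.]])    # last position is padding

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss.item())   # (0.5 + 1.0 + 2.0) / 3 ≈ 1.1667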