Commit cfe35fd3 authored by Jared Casper

Merge branch 'cross_entropy_mem_opt' into 'master'

Option to perform language model cross entropy calculations in fp16

See merge request ADLR/megatron-lm!86
parents 5897a790 197c132e
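
Before this change, the forward_step functions in the pretraining scripts upcast the LM logits to fp32 (.contiguous().float()) before calling mpu.vocab_parallel_cross_entropy. Since the logits carry the vocabulary as their last dimension, that upcast materializes a second copy of one of the largest activation tensors in the network; the branch name cross_entropy_mem_opt and the new --fp16-lm-cross-entropy flag target that copy by moving the unreduced loss into the models and, when the flag is set, keeping it in half precision. The flag requires fp16 mode, so a typical invocation would pass --fp16 --fp16-lm-cross-entropy to the pretraining script. A back-of-envelope sketch of the tensor sizes involved, using illustrative (assumed) batch, sequence, and vocabulary sizes:

# Rough size of the unreduced cross-entropy input; the batch, sequence and
# vocabulary sizes below are illustrative, not taken from this commit.
batch, seq, vocab = 8, 1024, 50304

elems = batch * seq * vocab
fp16_bytes = 2 * elems   # logits as produced by the fp16 LM head
fp32_bytes = 4 * elems   # additional copy allocated by logits.float()

print(f"fp16 logits:      {fp16_bytes / 2**30:.2f} GiB")   # ~0.77 GiB
print(f"fp32 upcast copy: {fp32_bytes / 2**30:.2f} GiB")   # ~1.53 GiB
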
+7 −0
@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
        if args.num_unique_layers < args.num_layers:
            assert args.DDP_impl == 'local', \
                'torch-DDP does not work with parameters sharing.'
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'

    _print_args(args)
    return args
@@ -294,6 +297,10 @@ def _add_mixed_precision_args(parser):
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--min-scale', type=float, default=1,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')


    return parser

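The flag is a plain store_true argument paired with a mixed-precision guard in parse_args: computing the LM loss in fp16 only makes sense when the model itself runs in fp16. A minimal, self-contained sketch of the same pattern (the parser below is a stand-in, not Megatron's parse_args):

# Stand-alone sketch of the flag-plus-guard pattern added above;
# this parser is a stand-in, not Megatron's parse_args().
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('mixed precision')
group.add_argument('--fp16', action='store_true',
                   help='Run the model in fp16 mode.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                   help='Move the cross entropy unreduced loss calculation '
                        'for lm head to fp16.')

args = parser.parse_args(['--fp16', '--fp16-lm-cross-entropy'])
if args.fp16_lm_cross_entropy:
    # Same mixed-precision check that parse_args now performs.
    assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'
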
+15 −2
@@ -18,6 +18,7 @@
import torch

from megatron import get_args
from megatron import mpu
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
@@ -114,6 +115,7 @@ class BertModel(MegatronModule):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
@@ -138,7 +140,8 @@ class BertModel(MegatronModule):
                                                init_method)
            self._binary_head_key = 'binary_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
    def forward(self, input_ids, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
@@ -161,11 +164,21 @@ class BertModel(MegatronModule):
        lm_logits = self.lm_head(
            lm_output, self.language_model.embedding.word_embeddings.weight)

        binary_logits = None
        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)

        if lm_labels is None:
            return lm_logits, binary_logits
        else:
            if self.fp16_lm_cross_entropy:
                assert lm_logits.dtype == torch.half
                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
            else:
                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                           lm_labels)
            return lm_loss, binary_logits

        return lm_logits, None

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
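
With lm_labels supplied, BertModel.forward now returns the unreduced vocab-parallel loss (together with binary_logits) instead of the raw lm_logits, and the fp32 upcast happens only when --fp16-lm-cross-entropy is off; GPT2Model below adopts the same branch. The helper below is a hypothetical stand-alone restatement of that logic so it can run without model-parallel setup: xent is a plain-PyTorch stand-in for mpu.vocab_parallel_cross_entropy, and the toy shapes are assumptions.

import torch
import torch.nn.functional as F

def xent(logits, labels):
    # Unreduced per-token cross entropy with the same [batch, seq] output shape
    # that mpu.vocab_parallel_cross_entropy returns; this stub is not model-parallel.
    return F.cross_entropy(logits.view(-1, logits.size(-1)),
                           labels.view(-1), reduction='none').view_as(labels)

def lm_loss_or_logits(lm_logits, lm_labels, fp16_lm_cross_entropy):
    if lm_labels is None:
        return lm_logits                                # inference: hand back the logits
    if fp16_lm_cross_entropy:
        # --fp16-lm-cross-entropy: loss on the half-precision logits,
        # no fp32 copy of the [batch, seq, vocab] tensor.
        assert lm_logits.dtype == torch.half
        return xent(lm_logits, lm_labels)
    return xent(lm_logits.float(), lm_labels)           # default: upcast first, as before

logits = torch.randn(2, 8, 32)                          # toy [batch, seq, vocab] in fp32
labels = torch.randint(0, 32, (2, 8))
print(lm_loss_or_logits(logits, labels, False).shape)   # torch.Size([2, 8])
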
+13 −2
@@ -18,6 +18,7 @@
import torch

from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule

from .language_model import parallel_lm_logits
@@ -39,6 +40,7 @@ class GPT2Model(MegatronModule):
        args = get_args()

        self.parallel_output = parallel_output
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=gpt2_attention_mask_func,
@@ -48,7 +50,7 @@ class GPT2Model(MegatronModule):
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

    def forward(self, input_ids, position_ids, attention_mask,
    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):

@@ -75,7 +77,16 @@ class GPT2Model(MegatronModule):
        if get_key_value:
            output = [output, presents]

        if labels is None:
            return output
        else:
            if self.fp16_lm_cross_entropy:
                assert output.dtype == torch.half
                loss = mpu.vocab_parallel_cross_entropy(output, labels)
            else:
                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
            return loss


    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
+7 −6
@@ -68,6 +68,7 @@ def get_batch(data_iterator):

def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
@@ -76,15 +77,15 @@ def forward_step(data_iterator, model):
        = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
    # Forward model. lm_labels
    lm_loss_, sop_logits = model(tokens, padding_mask,
                                 tokentype_ids=types,
                                 lm_labels=lm_labels)

    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                               sentence_order.view(-1).contiguous(),
    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                               sentence_order.view(-1),
                               ignore_index=-1)

    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                lm_labels.contiguous())
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

+3 −3
@@ -72,6 +72,7 @@ def get_batch(data_iterator):

def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
@@ -81,9 +82,8 @@ def forward_step(data_iterator, model):
    timers('batch generator').stop()

    # Forward model.
    output = model(tokens, position_ids, attention_mask)
    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                              labels)
    losses = model(tokens, position_ids, attention_mask, labels=labels)
    
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
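
With the cross entropy moved into the models, forward_step in both pretraining scripts keeps only the reduction: a masked mean of the per-token losses over the tokens selected by loss_mask, as in the last two lines above. A small self-contained sketch with assumed toy shapes:

# Masked-mean reduction kept in forward_step, with toy shapes.
import torch

losses = torch.rand(2, 8)                      # unreduced per-token loss, [batch, seq]
loss_mask = (torch.rand(2, 8) > 0.2).float()   # 1.0 where the token counts, 0.0 elsewhere

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss)                                    # scalar masked mean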