Commit acfe848e authored by mohammad

added fp16 cross entropy loss option for gpt2

parent 2ede8235
megatron/arguments.py  +4 −0
@@ -294,6 +294,10 @@ def _add_mixed_precision_args(parser):
                        help='Window over which to raise/lower dynamic scale.')
     group.add_argument('--min-scale', type=float, default=1,
                        help='Minimum loss scale for dynamic loss scale.')
+    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
+                       help='Move the cross entropy unreduced loss calculation '
+                       'for lm head to fp16.')
+

     return parser

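Note that --fp16-lm-cross-entropy is a plain store_true flag, so it defaults to False and the existing fp32 cross entropy path stays the default. A minimal standalone sketch of that behavior, using bare argparse rather than Megatron's actual parser plumbing (the group name below is illustrative):

    import argparse

    # Minimal sketch with bare argparse; Megatron's real parser setup in
    # _add_mixed_precision_args() is not reproduced here.
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('mixed precision')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss '
                            'calculation for lm head to fp16.')

    # Without the flag the attribute is False; passing it flips it to True.
    assert parser.parse_args([]).fp16_lm_cross_entropy is False
    assert parser.parse_args(
        ['--fp16-lm-cross-entropy']).fp16_lm_cross_entropy is True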
megatron/model/gpt2_model.py  +7 −11
@@ -18,6 +18,7 @@
 import torch

 from megatron import get_args
+from megatron import mpu
 from megatron.module import MegatronModule

 from .language_model import parallel_lm_logits
@@ -25,9 +26,6 @@ from .language_model import get_language_model
 from .utils import init_method_normal
 from .utils import scaled_init_method_normal

-from megatron.utils import report_memory
-from megatron import mpu
-

 def gpt2_attention_mask_func(attention_scores, ltor_mask):
     attention_scores.masked_fill_(ltor_mask, -10000.0)
@@ -51,7 +49,7 @@ class GPT2Model(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

-    def forward(self, input_ids, position_ids, attention_mask, labels,
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

@@ -78,14 +76,12 @@ class GPT2Model(MegatronModule):
         if get_key_value:
             output = [output, presents]

-        #report_memory('AAA')
-
-        losses = mpu.vocab_parallel_cross_entropy(output, labels)
-
-        #report_memory('BBB')
-
-        #return output
-        return losses
+        if labels is None:
+            return output
+        else:
+            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            return loss

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
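With labels now optional, forward() serves two callers: labels=None returns the lm-head logits (inference), while passing labels returns the unreduced per-token loss computed directly on the fp16 logits, since there is no .float() cast before mpu.vocab_parallel_cross_entropy. A toy single-process sketch of that contract, with made-up shapes and plain F.cross_entropy standing in for the vocab-parallel version:

    import torch
    import torch.nn.functional as F

    # Toy stand-in for the new GPT2Model.forward() contract; shapes are
    # illustrative, and F.cross_entropy replaces the model-parallel
    # mpu.vocab_parallel_cross_entropy used by the real model.
    def forward(logits, labels=None):
        if labels is None:
            return logits  # inference path: raw lm-head logits
        # training path: unreduced per-token loss in the logits' own dtype
        return F.cross_entropy(logits.transpose(1, 2), labels,
                               reduction='none')

    logits = torch.randn(2, 8, 100)         # [batch, seq, vocab]
    labels = torch.randint(0, 100, (2, 8))  # [batch, seq]
    assert forward(logits).shape == (2, 8, 100)
    assert forward(logits, labels).shape == (2, 8)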
megatron/training.py  +0 −1
@@ -379,7 +379,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
                                           report_memory_flag)
-        #report_memory_flag = True

         # Autoresume
         if args.adlr_autoresume and \
pretrain_gpt2.py  +9 −7
@@ -27,7 +27,7 @@ from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import reduce_losses
-from megatron.utils import report_memory


 def model_provider():
     """Build the model."""
@@ -72,6 +72,7 @@ def get_batch(data_iterator):

 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()

     # Get the batch.
@@ -81,12 +82,13 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    losses = model(tokens, position_ids, attention_mask, labels)
-    #report_memory('CCC')
-    #exit()
-    #losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-    #                                          labels)
-    #report_memory('DDD')
+    if args.fp16_lm_cross_entropy:
+        losses = model(tokens, position_ids, attention_mask, labels=labels)
+    else:
+        output = model(tokens, position_ids, attention_mask)
+        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
+                                                  labels)

     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
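Both branches produce unreduced per-token losses; the final two lines then average them over only the positions where loss_mask is nonzero, so padded tokens do not dilute the loss. A self-contained sketch of that reduction with made-up sizes:

    import torch

    # Sketch of the masked reduction at the end of forward_step():
    # average the per-token losses over unmasked positions only.
    losses = torch.rand(4, 16)        # [batch, seq] unreduced losses
    loss_mask = torch.ones(4, 16)
    loss_mask[:, 12:] = 0.0           # pretend the last 4 tokens are padding

    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    print(loss)  # scalar: mean over the 4 * 12 = 48 unmasked tokens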