Commit b4b0d739 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Merge branch 'fix_logging' into 'main'

Fixed loss averaging when all but one iteration's value is skipped

See merge request ADLR/megatron-lm!164
parents 79888e16 664cd28b
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -315,7 +315,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
@@ -369,8 +370,9 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
        for key in total_loss_dict:
            if key not in [skipped_iters_key, got_nan_key]:
                avg = total_loss_dict[key].item() / float(num_iterations)
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = 0.0
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(