Commit 664cd28b authored by mohammad's avatar mohammad
Browse files

Fixed loss averaging when all values but one are skipped

parent 79888e16
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -315,7 +315,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
@@ -369,8 +370,9 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
        for key in total_loss_dict:
            if key not in [skipped_iters_key, got_nan_key]:
                avg = total_loss_dict[key].item() / float(num_iterations)
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = 0.0
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(