Commit 60078cf3 authored by mshoeybi's avatar mshoeybi
Browse files

minor changes from github issues

parent f69f699f
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -674,7 +674,7 @@ def get_samples_mapping(indexed_dataset,
        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building sapmles index mapping for {} ...'.format(
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        # First compile and then import.
        from megatron.data import helpers
@@ -688,7 +688,7 @@ def get_samples_mapping(indexed_dataset,
            seed,
            verbose,
            2 if binary_head else 1)
        print_rank_0(' > done building sapmles index maping')
        print_rank_0(' > done building samples index maping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
+2 −2
Original line number Diff line number Diff line
@@ -541,7 +541,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
        if writer:
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
@@ -748,7 +748,7 @@ def evaluate_and_print_results(prefix, forward_step_func,
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
        if writer:
            writer.add_scalar('{} validation'.format(key),
                              total_loss_dict[key].item(),
                              iteration)