Commit 29a69547 authored by mshoeybi, committed by Deepak Narayanan

Some bugfixes

parent 39181113
+1 −1
@@ -408,7 +408,7 @@ def _build_sample_idx(sizes, doc_idx, seq_length,
    return sample_idx


-def def _build_shuffle_idx(num_samples, total_size, np_rng):
+def _build_shuffle_idx(num_samples, total_size, np_rng):
    """Build the range [0, size) and shuffle."""
    print(' > building shuffle index with split [0, {}) and [{}, {}) '
          '...'.format(num_samples, num_samples, total_size), flush=True)
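This first change removes a duplicated `def` keyword: `def def _build_shuffle_idx(...)` is a SyntaxError, so the module could not even be imported. For context, below is a minimal sketch (hypothetical name, assuming NumPy; not the repository's exact code) of the split shuffle the print statement describes, where [0, num_samples) and [num_samples, total_size) are shuffled independently and concatenated.

    import numpy as np

    def build_shuffle_idx_sketch(num_samples, total_size, np_rng):
        # Shuffle the first split [0, num_samples) in place.
        first = np.arange(num_samples, dtype=np.int64)
        np_rng.shuffle(first)
        # Shuffle the remainder [num_samples, total_size) in place.
        last = np.arange(num_samples, total_size, dtype=np.int64)
        np_rng.shuffle(last)
        # The result is a permutation of [0, total_size) that never
        # moves an index across the num_samples boundary.
        return np.concatenate((first, last))

    shuffle_idx = build_shuffle_idx_sketch(5, 8, np.random.RandomState(1234))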
+5 −3
@@ -717,13 +717,14 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    add_to_logging('optimizer')
    add_to_logging('batch generator')

+    batch_size = args.micro_batch_size * args.data_parallel_size * \
+        get_num_microbatches()
+
    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
        writer.add_scalar('learning_rate-samples', learning_rate,
                          args.consumed_train_samples)
-        batch_size = args.micro_batch_size * args.data_parallel_size * \
-            get_num_microbatches()
        writer.add_scalar('batch_size-iterations', batch_size, iteration)
        writer.add_scalar('batch_size-samples', batch_size,
                          args.consumed_train_samples)
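This hunk hoists the `batch_size` computation out of the TensorBoard-only branch. The hunk below adds `batch_size` to the plain-text log string as well, and that code runs even when `writer` is `None` or the rank is not 0, so leaving the computation inside `if writer and torch.distributed.get_rank() == 0:` would raise a NameError there. The formula itself is micro batch size times data-parallel size times number of microbatches; a small sketch with made-up numbers:

    # Hypothetical values, only to illustrate the global batch size formula.
    micro_batch_size = 4     # samples per microbatch on one model replica
    data_parallel_size = 8   # number of data-parallel replicas
    num_microbatches = 16    # gradient-accumulation steps per iteration
    global_batch_size = micro_batch_size * data_parallel_size * num_microbatches
    assert global_batch_size == 512  # samples consumed per training iteration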
@@ -748,11 +749,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                              elapsed_time / args.log_interval, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
-        log_string += ' consumed samples {:12d} |'.format(
+        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
+        log_string += ' global batch size: {:6d} |'.format(batch_size)
        num_iterations = max(
            1, args.log_interval - total_loss_dict[skipped_iters_key])
        for key in total_loss_dict:
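This second hunk fixes the consumed-samples message (the colon after 'consumed samples' was missing) and adds a 'global batch size' entry; that new entry is what forces `batch_size` into scope outside the writer branch. A sketch of the resulting line, with made-up values:

    # Made-up values, only to show the shape of the assembled log line.
    log_string = ' iteration {:8d}/{:8d} |'.format(100, 500000)
    log_string += ' consumed samples: {:12d} |'.format(51200)
    log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(1054.3)
    log_string += ' learning rate: {:.3E} |'.format(1.5e-4)
    log_string += ' global batch size: {:6d} |'.format(512)
    print(log_string)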