Commit 85589322 authored by Vijay Korthikanti

dataloader_type argument fix + randomsampler fix

parent e6c7b05e
megatron/arguments.py  +4 −1
@@ -124,6 +124,9 @@ def parse_args(extra_args_provider=None, defaults={},
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)
 
+    if args.dataloader_type is None:
+        args.dataloader_type = 'single'
+
     # Consumed tokens.
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
@@ -365,7 +368,7 @@ def _add_training_args(parser):
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
-    group.add_argument('--dataloader_type', type=str, default='single',
+    group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
     return parser
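Note: this hunk pairs the renamed flag with a deferred default. The argparse default moves to None so the guard added in parse_args can distinguish "flag not given" from an explicit choice before falling back to 'single'. A minimal standalone sketch of the same pattern (toy parser, not Megatron's full argument setup):

import argparse

parser = argparse.ArgumentParser()
# Keep the default at None so an unset flag is distinguishable from
# an explicit `--dataloader-type single`.
parser.add_argument('--dataloader-type', type=str, default=None,
                    choices=['single', 'cyclic'],
                    help='Single pass vs multiple pass data loader')
args = parser.parse_args([])

# Post-parse fix-up, mirroring the parse_args hunk above.
if args.dataloader_type is None:
    args.dataloader_type = 'single'

print(args.dataloader_type)  # prints: single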
megatron/data/data_samplers.py  +6 −4
@@ -105,6 +105,8 @@ class MegatronPretrainingRandomSampler:
         self.data_parallel_size = data_parallel_size
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.last_batch_size = \
+            self.total_samples % self.micro_batch_times_data_parallel_size
 
         # Sanity checks.
         assert self.total_samples > 0, \
@@ -119,8 +121,9 @@ class MegatronPretrainingRandomSampler:
         return self.total_samples
 
     def __iter__(self):
-        self.epoch = self.consumed_samples // self.total_samples
-        current_epoch_samples = self.consumed_samples % self.total_samples
+        active_total_samples = self.total_samples - self.last_batch_size
+        self.epoch = self.consumed_samples // active_total_samples
+        current_epoch_samples = self.consumed_samples % active_total_samples
         assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
 
         # data sharding and random sampling
@@ -142,4 +145,3 @@ class MegatronPretrainingRandomSampler:
                 self.consumed_samples += self.micro_batch_times_data_parallel_size
                 yield batch
                 batch = []
-        self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size
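Note: the sampler fix makes the epoch bookkeeping match what __iter__ actually yields. Only full batches of micro_batch_size * data_parallel_size samples are emitted, so the trailing last_batch_size samples of each pass are never consumed; dividing by total_samples (and bumping consumed_samples by the remainder at the end) desynchronized the arithmetic and could trip the divisibility assert. A back-of-the-envelope check of the corrected arithmetic, with made-up sizes (not from the commit):

# Hypothetical sizes for illustration only.
total_samples = 10
micro_batch_times_data_parallel_size = 4
last_batch_size = total_samples % micro_batch_times_data_parallel_size  # 2

# Each pass yields only full batches, i.e. 8 of the 10 samples.
active_total_samples = total_samples - last_batch_size                  # 8

consumed_samples = 16  # after two full passes
epoch = consumed_samples // active_total_samples                        # 2
current_epoch_samples = consumed_samples % active_total_samples         # 0
assert current_epoch_samples % micro_batch_times_data_parallel_size == 0

# The old arithmetic against total_samples would give
# 16 % 10 == 6, and 6 % 4 != 0, tripping the assert.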