megatron/arguments.py (+4 −1)

```diff
@@ -124,6 +124,9 @@ def parse_args(extra_args_provider=None, defaults={},
     print('using {} for parameters ...'.format(args.params_dtype),
           flush=True)
 
+    if args.dataloader_type is None:
+        args.dataloader_type = 'single'
+
     # Consumed tokens.
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
@@ -365,7 +368,7 @@ def _add_training_args(parser):
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
-    group.add_argument('--dataloader_type', type=str, default='single',
+    group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
     return parser
```
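For context, here is a minimal sketch (toy code, not the Megatron sources) of the deferred-default pattern this change adopts: the flag now defaults to `None` at parse time, and the concrete default is filled in afterwards inside `parse_args`, so downstream code can distinguish "user left it unset" from an explicit `'single'`. Note that argparse converts dashes to underscores when naming the attribute, so the renamed `--dataloader-type` flag is still read back as `args.dataloader_type`.

```python
import argparse

parser = argparse.ArgumentParser()
# argparse stores '--dataloader-type' under the dest 'dataloader_type',
# so the dash rename does not change how the value is read back.
parser.add_argument('--dataloader-type', type=str, default=None,
                    choices=['single', 'cyclic'],
                    help='Single pass vs multiple pass data loader')

args = parser.parse_args([])           # no flag given on the command line
assert args.dataloader_type is None    # "unset" is observable at this point
if args.dataloader_type is None:       # resolve the real default late
    args.dataloader_type = 'single'
print(args.dataloader_type)            # -> single
```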
megatron/data/data_samplers.py (+6 −4)

```diff
@@ -105,6 +105,8 @@ class MegatronPretrainingRandomSampler:
         self.data_parallel_size = data_parallel_size
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.last_batch_size = \
+            self.total_samples % self.micro_batch_times_data_parallel_size
 
         # Sanity checks.
         assert self.total_samples > 0, \
@@ -119,8 +121,9 @@ class MegatronPretrainingRandomSampler:
         return self.total_samples
 
     def __iter__(self):
-        self.epoch = self.consumed_samples // self.total_samples
-        current_epoch_samples = self.consumed_samples % self.total_samples
+        active_total_samples = self.total_samples - self.last_batch_size
+        self.epoch = self.consumed_samples // active_total_samples
+        current_epoch_samples = self.consumed_samples % active_total_samples
         assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
 
         # data sharding and random sampling
@@ -142,4 +145,3 @@ class MegatronPretrainingRandomSampler:
                 self.consumed_samples += self.micro_batch_times_data_parallel_size
                 yield batch
                 batch = []
-        self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size
```
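To see why the epoch arithmetic changes, here is a standalone numeric check with toy numbers (not the Megatron code itself; `total_samples = 10` and `micro_batch_size = 4` are made up for illustration). The sampler drops the final incomplete batch, so each pass actually yields only `total_samples - last_batch_size` samples; dividing `consumed_samples` by the full `total_samples` therefore never rolls over to the next epoch. The old code compensated by padding `consumed_samples` at the end of `__iter__` (the line this diff removes), which goes wrong when training resumes from a checkpoint that tracked only the samples actually consumed.

```python
total_samples = 10                 # toy dataset size
micro_batch_size = 4
data_parallel_size = 1
mb_times_dp = micro_batch_size * data_parallel_size

last_batch_size = total_samples % mb_times_dp             # 2 dropped per pass
active_total_samples = total_samples - last_batch_size    # 8 actually yielded
bucket_size = (total_samples // mb_times_dp) * micro_batch_size  # 8 indices/rank

# consumed_samples counts only samples that were actually yielded,
# which is how the training loop and checkpoints track it.
for consumed in range(0, 24, mb_times_dp):
    old_epoch, old_off = consumed // total_samples, consumed % total_samples
    new_epoch, new_off = (consumed // active_total_samples,
                          consumed % active_total_samples)
    # Indices still available in the current epoch's shuffled bucket:
    old_left = max(bucket_size - old_off // data_parallel_size, 0)
    new_left = bucket_size - new_off // data_parallel_size
    print(f'consumed={consumed:2d}  old: epoch={old_epoch} left={old_left}  '
          f'new: epoch={new_epoch} left={new_left}')
```

At `consumed=8` the old bookkeeping is still in epoch 0 with an exhausted bucket (0 indices left, so the iterator would yield nothing), while the `active_total_samples` math rolls cleanly into epoch 1 with a fresh shuffle.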