megatron/data/samplers.py (+0 −13)

@@ -15,8 +15,6 @@
 """Batch samplers that work with either random or sequential data samplers."""

-import numpy as np
 import torch
 from torch.utils import data

@@ -100,7 +98,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):

     def __iter__(self):
         batch = []
-        last_batch = None
         i = 0
         for idx in self.data_iterator(self.sampler, wrap_around=False):
             batch.append(idx)
@@ -110,7 +107,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
                     yield tbatch
                     self.start_iter = 0
                 i += 1
-                last_batch = np.array(list(tbatch))
                 batch = []
         batch_len = len(batch)
         if batch_len > 0 and not self.drop_last:
@@ -118,15 +114,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
                 self.sampler.wrap_around -= (self.batch_size)
                 self.wrap_around += (len(batch))
                 self.wrap_around %= self.batch_size
-                if isinstance(self.sampler, TransposedSampler):
-                    for i, idx in enumerate(self.data_iterator(
-                            self.sampler, wrap_around=True)):
-                        if i == 0:
-                            continue
-                        batch.append(idx)
-                        new_batch_len = len(batch)
-                        if len(batch) == self.batch_size:
-                            break
             yield self._batch(batch)
         if self.wrap_last:
             self.sampler.wrap_around += self.batch_size
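The diff removes only dead bookkeeping from DistributedBatchSampler.__iter__: the numpy import, the unused last_batch tracking, and the TransposedSampler wrap-around branch whose results (batch, new_batch_len) were never used. The yielding logic itself is unchanged: accumulate indices into a full global batch, yield this rank's shard via self._batch(...), and handle the trailing partial batch when drop_last is False. For readers unfamiliar with that pattern, below is a minimal, self-contained sketch of the same idea. The class name, constructor arguments, and the contiguous-slice _batch are assumptions made for illustration, not the module's actual API, and the sketch omits the wrap_around/start_iter machinery shown in the diff.

import torch
from torch.utils.data import SequentialSampler, TensorDataset

class SimpleDistributedBatchSampler:
    """Illustrative only: yield this rank's slice of every full global batch."""

    def __init__(self, sampler, batch_size, rank, world_size, drop_last=True):
        assert batch_size % world_size == 0, "global batch must split evenly across ranks"
        self.sampler = sampler
        self.batch_size = batch_size      # global batch size across all ranks
        self.rank = rank
        self.world_size = world_size
        self.drop_last = drop_last

    def _batch(self, batch):
        # Slice the global batch into this rank's contiguous shard.
        per_rank = self.batch_size // self.world_size
        start = self.rank * per_rank
        return batch[start:start + per_rank]

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield self._batch(batch)
                batch = []
        if batch and not self.drop_last:
            # Incomplete trailing batch: the real sampler can wrap around the
            # dataset to pad it; here we simply emit whatever this rank gets.
            yield self._batch(batch)

# Example: 10 samples, global batch size 4, 2 data-parallel ranks.
dataset = TensorDataset(torch.arange(10))
sampler = SequentialSampler(dataset)
for rank in range(2):
    batches = list(SimpleDistributedBatchSampler(sampler, batch_size=4,
                                                 rank=rank, world_size=2,
                                                 drop_last=False))
    print(rank, batches)
# rank 0: [[0, 1], [4, 5], [8, 9]]   rank 1: [[2, 3], [6, 7], []]

The empty final shard on rank 1 is exactly the case the real sampler's wrap_around bookkeeping exists to avoid: instead of starving some ranks, it borrows indices from the start of the epoch to fill out the last global batch.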