megatron/data/samplers.py (+15 −2)

@@ -80,10 +80,20 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
-    sampler."""
+    sampler.
+
+    The `interleave` argument specifies how to distribute a batch. A value
+    of True combined with the above random sampler is equivalent to pytorch's
+    torch.utils.data.distributed.DistributedSampler.
+
+    For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2,
+    specifying True will result in the following samples for each gpu:
+        GPU0: [0,2,4,6]  GPU1: [1,3,5,7]
+    specifying False will result in the following samples:
+        GPU0: [0,1,2,3]  GPU1: [4,5,6,7]"""

     def __init__(self, sampler, batch_size, drop_last, rank=-1,
-                 world_size=2, wrap_last=False):
+                 world_size=2, wrap_last=False, interleave=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
                                                       drop_last)
         if rank == -1:
@@ -95,6 +105,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
         self.wrap_around = 0
         self.wrap_last = wrap_last
         self.start_iter = 0
+        self.interleave = interleave

     def __iter__(self):
         batch = []
@@ -130,6 +141,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler):

     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
+        if self.interleave:
+            return batch[self.rank:self.batch_size:self.world_size]
         start = self.rank * self.batch_size // self.world_size
         end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
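To illustrate the slicing behaviour the new `interleave` flag selects, here is a small standalone sketch (not part of the patch; the batch values, `world_size`, and the per-rank loop are example inputs taken from the docstring) that reproduces the two `_batch` code paths:

    # Standalone illustration of the two _batch() slicing modes in this patch.
    batch = [0, 1, 2, 3, 4, 5, 6, 7]
    batch_size = len(batch)
    world_size = 2

    for rank in range(world_size):
        # interleave=True: strided slice, matching the sample assignment of
        # torch.utils.data.distributed.DistributedSampler
        interleaved = batch[rank:batch_size:world_size]
        # interleave=False: contiguous slice (the original behaviour)
        start = rank * batch_size // world_size
        end = (rank + 1) * batch_size // world_size
        contiguous = batch[start:end]
        print(f"GPU{rank}: interleaved={interleaved} contiguous={contiguous}")

    # GPU0: interleaved=[0, 2, 4, 6] contiguous=[0, 1, 2, 3]
    # GPU1: interleaved=[1, 3, 5, 7] contiguous=[4, 5, 6, 7]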