Commit 3ee811be authored by Mohammad Shoeybi

Merge branch 'batch_sampler_update' into 'master'

batch sampler interleave flag

See merge request ADLR/megatron-lm!81
parents ad299490 5cb0c807
@@ -80,10 +80,20 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
-    sampler."""
+    sampler.
+
+    The `interleave` argument specifies how to distribute a batch. A value
+    of True combined with the above random sampler is equivalent to PyTorch's
+    torch.utils.data.distributed.DistributedSampler.
+
+    For the batch [0,1,2,3,4,5,6,7] and a data-parallel world size of 2,
+    specifying True will result in the following samples for each GPU:
+        GPU0: [0,2,4,6] GPU1: [1,3,5,7]
+    specifying False will result in the following samples:
+        GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
 
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
-                 world_size=2, wrap_last=False):
+                 world_size=2, wrap_last=False, interleave=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
                                                       drop_last)
         if rank == -1:
@@ -95,6 +105,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
         self.wrap_around = 0
         self.wrap_last = wrap_last
         self.start_iter = 0
+        self.interleave = interleave
 
     def __iter__(self):
         batch = []
@@ -130,6 +141,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
 
     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
+        if self.interleave:
+            return batch[self.rank:self.batch_size:self.world_size]
         start = self.rank * self.batch_size // self.world_size
         end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
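
To make the two distribution modes concrete, the following minimal sketch (not part of the commit) mirrors the slice expressions from _batch and checks them against the example in the new docstring. The helper name extract_worker_batch is hypothetical.

def extract_worker_batch(batch, rank, world_size, batch_size, interleave):
    """Return the sub-batch for `rank`, mirroring _batch above."""
    if interleave:
        # Round-robin: rank r takes elements r, r + world_size, r + 2*world_size, ...
        return batch[rank:batch_size:world_size]
    # Contiguous: rank r takes its own contiguous share of the batch.
    start = rank * batch_size // world_size
    end = (rank + 1) * batch_size // world_size
    return batch[start:end]

batch = [0, 1, 2, 3, 4, 5, 6, 7]
# Interleaved split across 2 data-parallel ranks.
assert extract_worker_batch(batch, 0, 2, 8, True) == [0, 2, 4, 6]
assert extract_worker_batch(batch, 1, 2, 8, True) == [1, 3, 5, 7]
# Contiguous split across 2 data-parallel ranks.
assert extract_worker_batch(batch, 0, 2, 8, False) == [0, 1, 2, 3]
assert extract_worker_batch(batch, 1, 2, 8, False) == [4, 5, 6, 7]

With interleave=True, each rank takes every world_size-th element, the same round-robin assignment torch.utils.data.distributed.DistributedSampler uses; with False, each rank receives a contiguous shard of the global batch.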
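
A hedged usage sketch of the new flag, assuming DistributedBatchSampler from the file changed above is importable; the dataset and sampler setup is illustrative only.

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(1024))
sampler = RandomSampler(dataset)
# rank is passed explicitly so the sketch runs without initializing
# torch.distributed (rank=-1 would query the process group instead).
batch_sampler = DistributedBatchSampler(sampler, batch_size=8, drop_last=True,
                                        rank=0, world_size=2, interleave=True)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
# Each step yields this rank's 4-sample share of the global batch of 8.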