Commit a7169297 authored by Vijay Korthikanti

Addressing review comments

parent 58edb19a
+3 −1
@@ -362,7 +362,9 @@ def _add_training_args(parser):
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
 
+    group.add_argument('--dataloader_type', type=str, default='single',
+                       choices=['single', 'cyclic'],
+                       help='Single pass vs multiple pass data loader')
     return parser
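
Note: this flag is what selects between the two batch samplers added in the data-sampler changes below; 'single' remains the default, so existing launch scripts keep their current single-pass behavior. A hypothetical invocation (the script name and the elided flags are placeholders, not part of this diff):

    python pretrain_gpt.py <existing training args> --dataloader_type cyclic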


+145 −0
@@ -17,12 +17,12 @@
 
 
 import torch
-
+import random
 from megatron import get_args
 from megatron import mpu
 
 
-def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False):
+def build_pretraining_data_loader(dataset, consumed_samples):
     """Build dataloader given an input dataset."""
 
     if dataset is None:
@@ -30,13 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False
     args = get_args()
 
     # Megatron sampler
-    batch_sampler = MegatronPretrainingSampler(
-        total_samples=len(dataset),
-        consumed_samples=consumed_samples,
-        micro_batch_size=args.micro_batch_size,
-        data_parallel_rank=mpu.get_data_parallel_rank(),
-        data_parallel_size=mpu.get_data_parallel_world_size(),
-        random_sample=random_sample)
+    if args.dataloader_type == 'single':
+        batch_sampler = MegatronPretrainingSampler(
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size())
+    elif args.dataloader_type == 'cyclic':
+        batch_sampler = MegatronPretrainingRandomSampler(
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size())
+    else:
+        raise Exception('{} dataloader type is not supported.'.format(
+                args.dataloader_type))
 
     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
@@ -44,11 +54,10 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False
                                        num_workers=args.num_workers,
                                        pin_memory=True)
 
 
 class MegatronPretrainingSampler:
 
     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size, random_sample=False):
+                 data_parallel_rank, data_parallel_size):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
@@ -56,14 +65,50 @@ class MegatronPretrainingSampler:
         self.data_parallel_rank = data_parallel_rank
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
-        self.random_sample = random_sample
 
         # Sanity checks.
         assert self.total_samples > 0, \
             'no sample to consume: {}'.format(self.total_samples)
-        #assert self.consumed_samples < self.total_samples, \
-        #    'no samples left to consume: {}, {}'.format(self.consumed_samples,
-        #                                                self.total_samples)
+        assert self.consumed_samples < self.total_samples, \
+            'no samples left to consume: {}, {}'.format(self.consumed_samples,
+                                                        self.total_samples)
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
         assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)
+
+    def __len__(self):
+        return self.total_samples
+
+    def __iter__(self):
+        batch = []
+        # Last batch if not complete will be dropped.
+        for idx in range(self.consumed_samples, self.total_samples):
+            batch.append(idx)
+            if len(batch) == self.micro_batch_times_data_parallel_size:
+                start_idx = self.data_parallel_rank * self.micro_batch_size
+                end_idx = start_idx + self.micro_batch_size
+                yield batch[start_idx:end_idx]
+                batch = []
+
+
+class MegatronPretrainingRandomSampler:
+
+    def __init__(self, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size):
+        # Keep a copy of input params for later use.
+        self.total_samples = total_samples
+        self.consumed_samples = consumed_samples
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.data_parallel_size = data_parallel_size
+        self.micro_batch_times_data_parallel_size = \
+            self.micro_batch_size * data_parallel_size
+
+        # Sanity checks.
+        assert self.total_samples > 0, \
+            'no sample to consume: {}'.format(self.total_samples)
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
@@ -76,23 +121,25 @@ class MegatronPretrainingSampler:
     def __iter__(self):
         self.epoch = self.consumed_samples // self.total_samples
         current_epoch_samples = self.consumed_samples % self.total_samples
-        if self.random_sample:
-            g = torch.Generator()
-            g.manual_seed(self.epoch)
-            idx_range_total = \
-                torch.randperm(self.total_samples, generator=g).tolist()
-            idx_range = idx_range_total[current_epoch_samples:]
-        else:
-            idx_range = range(current_epoch_samples, self.total_samples)
+        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
+
+        # data sharding and random sampling
+        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
+                       * self.micro_batch_size
+        bucket_offset = current_epoch_samples // self.data_parallel_size
+        start_idx = self.data_parallel_rank * bucket_size
+
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        random_idx = torch.randperm(bucket_size, generator=g).tolist()
+        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
 
         batch = []
         # Last batch if not complete will be dropped.
         for idx in idx_range:
             batch.append(idx)
-            if len(batch) == self.micro_batch_times_data_parallel_size:
-                self.consumed_samples += len(batch)
-                start_idx = self.data_parallel_rank * self.micro_batch_size
-                end_idx = start_idx + self.micro_batch_size
-                yield batch[start_idx:end_idx]
+            if len(batch) == self.micro_batch_size:
+                self.consumed_samples += self.micro_batch_times_data_parallel_size
+                yield batch
                 batch = []
-        self.consumed_samples += len(batch)
+        self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size
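
Note on the sharding math above: the cyclic sampler gives each data-parallel rank a private, contiguous bucket of the index space, bucket_size = (total_samples // (micro_batch_size * dp_size)) * micro_batch_size, shuffles only within that bucket using an epoch-seeded generator (same seed on every rank, so restarts are reproducible), and uses bucket_offset to skip the part of the permutation already consumed when resuming mid-epoch. The final consumed_samples update credits the dropped remainder so the next epoch starts on a full-global-batch boundary. A minimal standalone sketch of the same index arithmetic (the function name and the driver at the bottom are illustrative, not the Megatron API):

    import torch

    def cyclic_shard_indices(total_samples, consumed_samples,
                             micro_batch_size, dp_rank, dp_size):
        # Mirrors the arithmetic in MegatronPretrainingRandomSampler.__iter__.
        mbs_times_dp = micro_batch_size * dp_size
        epoch = consumed_samples // total_samples
        current_epoch_samples = consumed_samples % total_samples
        assert current_epoch_samples % mbs_times_dp == 0

        # Each rank draws from its own contiguous bucket of indices.
        bucket_size = (total_samples // mbs_times_dp) * micro_batch_size
        bucket_offset = current_epoch_samples // dp_size
        start_idx = dp_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(epoch)  # identical seed on all ranks; buckets stay disjoint
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        return [start_idx + x for x in random_idx[bucket_offset:]]

    # 8 samples, micro batch size 2, 2 ranks: rank 0 shuffles indices 0..3,
    # rank 1 shuffles 4..7, so no sample is seen twice across ranks per epoch.
    print(cyclic_shard_indices(8, 0, 2, 0, 2))  # e.g. [1, 3, 0, 2]
    print(cyclic_shard_indices(8, 0, 2, 1, 2))  # e.g. [5, 7, 4, 6]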
+4 −1
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import torch
 from torchvision import datasets, transforms
 from megatron.data.autoaugment import ImageNetPolicy

@@ -32,7 +33,8 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
             brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
         )
     ]
-    process += [ImageNetPolicy(), transforms.ToTensor(), normalize]
+    fp16_t = transforms.ConvertImageDtype(torch.half)
+    process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
     transform_train = transforms.Compose(process)
     train_data = datasets.ImageFolder(
         root=train_data_path, transform=transform_train
@@ -46,6 +48,7 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
             transforms.CenterCrop(crop_size),
             transforms.ToTensor(),
             normalize,
+            fp16_t
         ]
     )
     val_data = datasets.ImageFolder(
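
Note: appending the dtype cast as the last transform means ToTensor and Normalize still run in float32 and only the final tensor is cast, so dataloader workers hand half-precision inputs directly to an fp16 model instead of casting on the device every step; ConvertImageDtype operates on tensors, hence its position after ToTensor. A minimal sketch of the resulting eval pipeline (the resize/crop sizes and normalization stats here are illustrative):

    import torch
    from torchvision import transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    fp16_t = transforms.ConvertImageDtype(torch.half)

    # float32 through ToTensor/Normalize; cast to fp16 only at the end.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
        fp16_t,
    ])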
+2 −2
@@ -122,7 +122,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert input.dim() == 4
 
         # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and \
+        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
            query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
 
             scale = self.scale if self.scale is not None else 1.0
@@ -142,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
 
             if self.scale is not None:
                 input = input * self.scale
-            mask_output = self.mask_func(input, mask) if mask else input
+            mask_output = self.mask_func(input, mask) if mask is not None else input
             probs = torch.nn.Softmax(dim=-1)(mask_output)
 
             if self.input_in_fp16 and self.softmax_in_fp32:
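
Note on the two mask-check fixes: truth-testing a tensor with `if mask` invokes Tensor.__bool__, which raises a RuntimeError whenever the mask has more than one element (and tests by value otherwise), so an identity check is the only safe way to detect a missing mask; the fused-kernel branch gets the same guard, presumably because the custom kernel requires a mask input. A small repro of the failure mode:

    import torch

    mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)

    try:
        if mask:  # Tensor.__bool__ is ambiguous for multi-element tensors
            pass
    except RuntimeError as err:
        print(err)  # "Boolean value of Tensor with more than one element is ambiguous"

    if mask is not None:  # robust presence check, as in the fix above
        print('mask supplied')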
+1 −1
@@ -120,7 +120,7 @@ def twod_interpolate_position_embeddings_hook(
 
 
 class VitModel(MegatronModule):
-    """Bert Language model."""
+    """Vision Transformer Model."""
 
     def __init__(self, num_classes, finetune=False):
         super(VitModel, self).__init__()