Loading megatron/arguments.py +3 −1 Original line number Diff line number Diff line Loading @@ -362,7 +362,9 @@ def _add_training_args(parser): group.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='Optimizer function') group.add_argument('--dataloader_type', type=str, default='single', choices=['single', 'cyclic'], help='Single pass vs multiple pass data loader') return parser Loading megatron/data/data_loaders.py→megatron/data/data_samplers.py +145 −0 Original line number Diff line number Diff line Loading @@ -17,12 +17,12 @@ import torch import random from megatron import get_args from megatron import mpu def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False): def build_pretraining_data_loader(dataset, consumed_samples): """Buld dataloader given an input dataset.""" if dataset is None: Loading @@ -30,13 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False args = get_args() # Megatron sampler if args.dataloader_type == 'single': batch_sampler = MegatronPretrainingSampler( total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size(), random_sample=random_sample) data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'cyclic': batch_sampler = MegatronPretrainingRandomSampler( total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size()) else: raise Exception('{} dataloader type is not supported.'.format( args.dataloader_type)) # Torch dataloader. return torch.utils.data.DataLoader(dataset, Loading @@ -44,11 +54,10 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False num_workers=args.num_workers, pin_memory=True) class MegatronPretrainingSampler: def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, random_sample=False): data_parallel_rank, data_parallel_size): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples Loading @@ -56,14 +65,50 @@ class MegatronPretrainingSampler: self.data_parallel_rank = data_parallel_rank self.micro_batch_times_data_parallel_size = \ self.micro_batch_size * data_parallel_size self.random_sample = random_sample # Sanity checks. assert self.total_samples > 0, \ 'no sample to consume: {}'.format(self.total_samples) #assert self.consumed_samples < self.total_samples, \ # 'no samples left to consume: {}, {}'.format(self.consumed_samples, # self.total_samples) assert self.consumed_samples < self.total_samples, \ 'no samples left to consume: {}, {}'.format(self.consumed_samples, self.total_samples) assert self.micro_batch_size > 0 assert data_parallel_size > 0 assert self.data_parallel_rank < data_parallel_size, \ 'data_parallel_rank should be smaller than data size: {}, ' \ '{}'.format(self.data_parallel_rank, data_parallel_size) def __len__(self): return self.total_samples def __iter__(self): batch = [] # Last batch if not complete will be dropped. for idx in range(self.consumed_samples, self.total_samples): batch.append(idx) if len(batch) == self.micro_batch_times_data_parallel_size: start_idx = self.data_parallel_rank * self.micro_batch_size end_idx = start_idx + self.micro_batch_size yield batch[start_idx:end_idx] batch = [] class MegatronPretrainingRandomSampler: def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.data_parallel_size = data_parallel_size self.micro_batch_times_data_parallel_size = \ self.micro_batch_size * data_parallel_size # Sanity checks. assert self.total_samples > 0, \ 'no sample to consume: {}'.format(self.total_samples) assert self.micro_batch_size > 0 assert data_parallel_size > 0 assert self.data_parallel_rank < data_parallel_size, \ Loading @@ -76,23 +121,25 @@ class MegatronPretrainingSampler: def __iter__(self): self.epoch = self.consumed_samples // self.total_samples current_epoch_samples = self.consumed_samples % self.total_samples if self.random_sample: assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 # data sharding and random sampling bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ * self.micro_batch_size bucket_offset = current_epoch_samples // self.data_parallel_size start_idx = self.data_parallel_rank * bucket_size g = torch.Generator() g.manual_seed(self.epoch) idx_range_total = \ torch.randperm(self.total_samples, generator=g).tolist() idx_range = idx_range_total[current_epoch_samples:] else: idx_range = range(current_epoch_samples, self.total_samples) random_idx = torch.randperm(bucket_size, generator=g).tolist() idx_range = [start_idx + x for x in random_idx[bucket_offset:]] batch = [] # Last batch if not complete will be dropped. for idx in idx_range: batch.append(idx) if len(batch) == self.micro_batch_times_data_parallel_size: self.consumed_samples += len(batch) start_idx = self.data_parallel_rank * self.micro_batch_size end_idx = start_idx + self.micro_batch_size yield batch[start_idx:end_idx] if len(batch) == self.micro_batch_size: self.consumed_samples += self.micro_batch_times_data_parallel_size yield batch batch = [] self.consumed_samples += len(batch) self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size megatron/data/vit_dataset.py +4 −1 Original line number Diff line number Diff line Loading @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os import torch from torchvision import datasets, transforms from megatron.data.autoaugment import ImageNetPolicy Loading @@ -32,7 +33,8 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True): brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1 ) ] process += [ImageNetPolicy(), transforms.ToTensor(), normalize] fp16_t = transforms.ConvertImageDtype(torch.half) process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t] transform_train = transforms.Compose(process) train_data = datasets.ImageFolder( root=train_data_path, transform=transform_train Loading @@ -46,6 +48,7 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True): transforms.CenterCrop(crop_size), transforms.ToTensor(), normalize, fp16_t ] ) val_data = datasets.ImageFolder( Loading megatron/model/fused_softmax.py +2 −2 Original line number Diff line number Diff line Loading @@ -122,7 +122,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): assert input.dim() == 4 # invoke custom kernel if self.input_in_fp16 and key_seq_len <= 2048 and \ if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \ query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion: scale = self.scale if self.scale is not None else 1.0 Loading @@ -142,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): if self.scale is not None: input = input * self.scale mask_output = self.mask_func(input, mask) if mask else input mask_output = self.mask_func(input, mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1)(mask_output) if self.input_in_fp16 and self.softmax_in_fp32: Loading megatron/model/vit_model.py +1 −1 Original line number Diff line number Diff line Loading @@ -120,7 +120,7 @@ def twod_interpolate_position_embeddings_hook( class VitModel(MegatronModule): """Bert Language model.""" """Vision Transformer Model.""" def __init__(self, num_classes, finetune=False): super(VitModel, self).__init__() Loading Loading
megatron/arguments.py +3 −1 Original line number Diff line number Diff line Loading @@ -362,7 +362,9 @@ def _add_training_args(parser): group.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='Optimizer function') group.add_argument('--dataloader_type', type=str, default='single', choices=['single', 'cyclic'], help='Single pass vs multiple pass data loader') return parser Loading
megatron/data/data_loaders.py→megatron/data/data_samplers.py +145 −0 Original line number Diff line number Diff line Loading @@ -17,12 +17,12 @@ import torch import random from megatron import get_args from megatron import mpu def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False): def build_pretraining_data_loader(dataset, consumed_samples): """Buld dataloader given an input dataset.""" if dataset is None: Loading @@ -30,13 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False args = get_args() # Megatron sampler if args.dataloader_type == 'single': batch_sampler = MegatronPretrainingSampler( total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size(), random_sample=random_sample) data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'cyclic': batch_sampler = MegatronPretrainingRandomSampler( total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size()) else: raise Exception('{} dataloader type is not supported.'.format( args.dataloader_type)) # Torch dataloader. return torch.utils.data.DataLoader(dataset, Loading @@ -44,11 +54,10 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False num_workers=args.num_workers, pin_memory=True) class MegatronPretrainingSampler: def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, random_sample=False): data_parallel_rank, data_parallel_size): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples Loading @@ -56,14 +65,50 @@ class MegatronPretrainingSampler: self.data_parallel_rank = data_parallel_rank self.micro_batch_times_data_parallel_size = \ self.micro_batch_size * data_parallel_size self.random_sample = random_sample # Sanity checks. assert self.total_samples > 0, \ 'no sample to consume: {}'.format(self.total_samples) #assert self.consumed_samples < self.total_samples, \ # 'no samples left to consume: {}, {}'.format(self.consumed_samples, # self.total_samples) assert self.consumed_samples < self.total_samples, \ 'no samples left to consume: {}, {}'.format(self.consumed_samples, self.total_samples) assert self.micro_batch_size > 0 assert data_parallel_size > 0 assert self.data_parallel_rank < data_parallel_size, \ 'data_parallel_rank should be smaller than data size: {}, ' \ '{}'.format(self.data_parallel_rank, data_parallel_size) def __len__(self): return self.total_samples def __iter__(self): batch = [] # Last batch if not complete will be dropped. for idx in range(self.consumed_samples, self.total_samples): batch.append(idx) if len(batch) == self.micro_batch_times_data_parallel_size: start_idx = self.data_parallel_rank * self.micro_batch_size end_idx = start_idx + self.micro_batch_size yield batch[start_idx:end_idx] batch = [] class MegatronPretrainingRandomSampler: def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.data_parallel_size = data_parallel_size self.micro_batch_times_data_parallel_size = \ self.micro_batch_size * data_parallel_size # Sanity checks. assert self.total_samples > 0, \ 'no sample to consume: {}'.format(self.total_samples) assert self.micro_batch_size > 0 assert data_parallel_size > 0 assert self.data_parallel_rank < data_parallel_size, \ Loading @@ -76,23 +121,25 @@ class MegatronPretrainingSampler: def __iter__(self): self.epoch = self.consumed_samples // self.total_samples current_epoch_samples = self.consumed_samples % self.total_samples if self.random_sample: assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 # data sharding and random sampling bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ * self.micro_batch_size bucket_offset = current_epoch_samples // self.data_parallel_size start_idx = self.data_parallel_rank * bucket_size g = torch.Generator() g.manual_seed(self.epoch) idx_range_total = \ torch.randperm(self.total_samples, generator=g).tolist() idx_range = idx_range_total[current_epoch_samples:] else: idx_range = range(current_epoch_samples, self.total_samples) random_idx = torch.randperm(bucket_size, generator=g).tolist() idx_range = [start_idx + x for x in random_idx[bucket_offset:]] batch = [] # Last batch if not complete will be dropped. for idx in idx_range: batch.append(idx) if len(batch) == self.micro_batch_times_data_parallel_size: self.consumed_samples += len(batch) start_idx = self.data_parallel_rank * self.micro_batch_size end_idx = start_idx + self.micro_batch_size yield batch[start_idx:end_idx] if len(batch) == self.micro_batch_size: self.consumed_samples += self.micro_batch_times_data_parallel_size yield batch batch = [] self.consumed_samples += len(batch) self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size
megatron/data/vit_dataset.py +4 −1 Original line number Diff line number Diff line Loading @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os import torch from torchvision import datasets, transforms from megatron.data.autoaugment import ImageNetPolicy Loading @@ -32,7 +33,8 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True): brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1 ) ] process += [ImageNetPolicy(), transforms.ToTensor(), normalize] fp16_t = transforms.ConvertImageDtype(torch.half) process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t] transform_train = transforms.Compose(process) train_data = datasets.ImageFolder( root=train_data_path, transform=transform_train Loading @@ -46,6 +48,7 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True): transforms.CenterCrop(crop_size), transforms.ToTensor(), normalize, fp16_t ] ) val_data = datasets.ImageFolder( Loading
megatron/model/fused_softmax.py +2 −2 Original line number Diff line number Diff line Loading @@ -122,7 +122,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): assert input.dim() == 4 # invoke custom kernel if self.input_in_fp16 and key_seq_len <= 2048 and \ if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \ query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion: scale = self.scale if self.scale is not None else 1.0 Loading @@ -142,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): if self.scale is not None: input = input * self.scale mask_output = self.mask_func(input, mask) if mask else input mask_output = self.mask_func(input, mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1)(mask_output) if self.input_in_fp16 and self.softmax_in_fp32: Loading
megatron/model/vit_model.py +1 −1 Original line number Diff line number Diff line Loading @@ -120,7 +120,7 @@ def twod_interpolate_position_embeddings_hook( class VitModel(MegatronModule): """Bert Language model.""" """Vision Transformer Model.""" def __init__(self, num_classes, finetune=False): super(VitModel, self).__init__() Loading