Loading megatron/arguments.py +11 −2 Original line number Diff line number Diff line Loading @@ -835,11 +835,20 @@ def _add_vit_args(parser): group.add_argument('--num-classes', type=int, default=1000, help='num of classes in vision classificaiton task') group.add_argument('--img-dim', type=int, default=224, help='Image size for vision classification task') group.add_argument('--img-h', type=int, default=224, help='Image height for vision classification task') group.add_argument('--img-w', type=int, default=224, help='Image height for vision classification task') group.add_argument('--num-channels', type=int, default=3, help='Number of channels in input image data') group.add_argument('--patch-dim', type=int, default=16, help='patch dimension used in vit') group.add_argument('--classes-fraction', type=float, default=1.0, help='training with fraction of classes.') group.add_argument('--data-per-class-fraction', type=float, default=1.0, help='training with fraction of data per class.') group.add_argument('--no-data-sharding', action='store_false', help='Disable data sharding.', dest='data_sharding') return parser megatron/data/data_samplers.py +55 −13 Original line number Diff line number Diff line Loading @@ -16,8 +16,10 @@ """Dataloaders.""" import torch import random import torch import numpy as np from torch.utils.data import Dataset from megatron import get_args from megatron import mpu Loading @@ -39,11 +41,13 @@ def build_pretraining_data_loader(dataset, consumed_samples): data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'cyclic': batch_sampler = MegatronPretrainingRandomSampler( dataset, total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size()) data_parallel_size=mpu.get_data_parallel_world_size(), data_sharding=args.data_sharding) else: raise Exception('{} dataloader type is 
not supported.'.format( args.dataloader_type)) Loading Loading @@ -103,16 +107,40 @@ class MegatronPretrainingSampler: yield batch[start_idx:end_idx] class RandomSeedDataset(Dataset): def __init__(self, dataset): args = get_args() self.base_seed = args.seed self.curr_seed = args.seed self.dataset = dataset def __len__(self): return len(self.dataset) def set_epoch(self, epoch): self.curr_seed = self.base_seed + epoch def __getitem__(self, idx): seed = idx + self.curr_seed torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) return self.dataset[idx] class MegatronPretrainingRandomSampler: def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, data_sharding): # Keep a copy of input params for later use. self.dataset = dataset self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.data_parallel_size = data_parallel_size self.data_sharding = data_sharding self.micro_batch_times_data_parallel_size = \ self.micro_batch_size * data_parallel_size self.last_batch_size = \ Loading @@ -136,7 +164,11 @@ class MegatronPretrainingRandomSampler: current_epoch_samples = self.consumed_samples % active_total_samples assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 if isinstance(dataset, RandomSeedDataset): self.dataset.set_epoch(self.epoch) # data sharding and random sampling if self.data_sharding: bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ * self.micro_batch_size bucket_offset = current_epoch_samples // self.data_parallel_size Loading @@ -146,6 +178,16 @@ class MegatronPretrainingRandomSampler: g.manual_seed(self.epoch) random_idx = torch.randperm(bucket_size, generator=g).tolist() idx_range = [start_idx + x for x in 
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Return True when *filename* ends with one of *extensions*.

    The comparison is case-insensitive on the filename side; the
    extensions themselves are expected to be lowercase.

    Args:
        filename: path (or bare name) of a file.
        extensions: tuple of lowercase extensions, e.g. ``('.jpg', '.png')``.
    """
    lowered = filename.lower()
    return lowered.endswith(extensions)


def is_image_file(filename: str) -> bool:
    """Return True when *filename* carries one of the known image
    extensions listed in the module-level ``IMG_EXTENSIONS`` tuple."""
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)
class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged one folder
    per class::

        root/class_x/xxx.ext
        root/class_x/[...]/xxz.ext
        root/class_y/123.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string], optional): Allowed extensions; mutually
            exclusive with ``is_valid_file``.
        transform (callable, optional): Transform applied to each sample.
        target_transform (callable, optional): Transform applied to each
            target.
        classes_fraction (float): Keep only this leading fraction of the
            class folders found under ``root``.
        data_per_class_fraction (float): Keep only this leading fraction of
            the files found for each class.
        is_valid_file (callable, optional): Predicate deciding whether a
            file path is a valid sample (used to skip corrupt files);
            mutually exclusive with ``extensions``.

    Attributes:
        classes (list): Class names, sorted alphabetically.
        class_to_idx (dict): Maps class name -> class index.
        samples (list): List of (sample path, class_index) tuples.
        targets (list): class_index value for every sample.

    Raises:
        RuntimeError: if no samples are found under ``root``.
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            classes_fraction=1.0,
            data_per_class_fraction=1.0,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        self.classes_fraction = classes_fraction
        self.data_per_class_fraction = data_per_class_fraction
        classes, class_to_idx = self._find_classes(self.root)
        samples = self.make_dataset(self.root,
                                    class_to_idx,
                                    self.data_per_class_fraction,
                                    extensions,
                                    is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(
                    ",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions
        # Cached length, used as the retry budget in __getitem__.
        self.total = len(samples)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
            directory: str,
            class_to_idx: Dict[str, int],
            data_per_class_fraction: float,
            extensions: Optional[Tuple[str, ...]] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        # Delegate to the module-level helper; kept as a staticmethod so
        # subclasses can override the sample discovery policy.
        return make_dataset(directory,
                            class_to_idx,
                            data_per_class_fraction,
                            extensions=extensions,
                            is_valid_file=is_valid_file)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders directly under *dir*.

        Only the first ``classes_fraction`` of the folders (in scan
        order) are kept; the surviving names are then sorted and given
        consecutive indices.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to
            (dir), and class_to_idx is a dictionary.
        """
        all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(sample, target)`` for *index*.

        If the loader fails (e.g. a corrupt file), retry with a random
        index, up to ``self.total`` attempts.

        BUG FIX: the original version fell through with ``sample`` and
        ``target`` unbound (raising a confusing NameError) when every
        attempt failed; we now raise a descriptive RuntimeError instead.
        """
        curr_index = index
        for _ in range(self.total):
            try:
                path, target = self.samples[curr_index]
                sample = self.loader(path)
                break
            except Exception:
                # Unreadable sample: pick another index at random and retry.
                curr_index = np.random.randint(0, self.total)
        else:
            raise RuntimeError(
                "Could not load any sample after {} attempts "
                "(first failing index: {})".format(self.total, index))
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        return len(self.samples)
ImageFolder(DatasetFolder): """A generic data loader where the images are arranged in this way: :: root/dog/xxx.png root/dog/xxy.png root/dog/[...]/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/[...]/asd932_.png Args: root (string): Root directory path. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. is_valid_file (callable, optional): A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files) Attributes: classes (list): List of the class names sorted alphabetically. class_to_idx (dict): Dict with items (class_name, class_index). imgs (list): List of (image path, class_index) tuples """ def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, classes_fraction=1.0, data_per_class_fraction=1.0, loader: Callable[[str], Any] = default_loader, is_valid_file: Optional[Callable[[str], bool]] = None, ): super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, transform=transform, target_transform=target_transform, classes_fraction=classes_fraction, data_per_class_fraction=data_per_class_fraction, is_valid_file=is_valid_file) self.imgs = self.samples megatron/data/vit_dataset.py +52 −31 Original line number Diff line number Diff line Loading @@ -13,46 +13,67 @@ # See the License for the specific language governing permissions and # limitations under the License. 
class ClassificationTransform():
    """Image transform pipeline for vision classification.

    Builds a torchvision ``Compose``: random crop / flip / color jitter /
    AutoAugment for training, resize + center crop for evaluation; both
    normalize with ImageNet statistics and cast to the half-precision
    dtype selected by the fp16/bf16 args.

    Args:
        image_size: target crop size passed to the torchvision ops.
        train: build the augmenting train pipeline when True, the
            deterministic eval pipeline otherwise.
    """

    def __init__(self, image_size, train=True):
        args = get_args()
        # Pipeline ends with a dtype cast, so one of the two must be on.
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        # BUG FIX: the train branch referenced self.mean_std without it
        # ever being assigned (AttributeError); define it explicitly with
        # the same ImageNet statistics the eval branch already used.
        self.mean_std = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        if train:
            self.transform = T.Compose([
                T.RandomResizedCrop(image_size),
                T.RandomHorizontalFlip(),
                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
                ImageNetPolicy(),
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.Resize(image_size),
                T.CenterCrop(image_size),
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])

    def __call__(self, input):
        """Apply the configured transform pipeline to *input*."""
        output = self.transform(input)
        return output
root=train_data_path, transform=train_transform, classes_fraction=args.classes_fraction, data_per_class_fraction=args.data_per_class_fraction ) train_data = RandomSeedDataset(train_data) # validation dataset val_data_path = os.path.join(data_path[0], "val") transform_val = transforms.Compose( [ transforms.Resize(crop_size), transforms.CenterCrop(crop_size), transforms.ToTensor(), normalize, fp16_t ] ) val_data = datasets.ImageFolder( root=val_data_path, transform=transform_val val_data_path = data_path[1] val_data = ImageFolder( root=val_data_path, transform=val_transform ) val_data = RandomSeedDataset(val_data) return train_data, val_data megatron/fused_kernels/__init__.py +6 −0 Original line number Diff line number Diff line Loading @@ -78,6 +78,12 @@ def load(args): scaled_masked_softmax_cuda = _cpp_extention_load_helper( "scaled_masked_softmax_cuda", sources, extra_cuda_flags) # Softmax sources=[srcpath / 'scaled_softmax.cpp', srcpath / 'scaled_softmax_cuda.cu'] scaled_softmax_cuda = _cpp_extention_load_helper( "scaled_softmax_cuda", sources, extra_cuda_flags) # ================================= # Mixed precision fused layer norm. # ================================= Loading Loading
megatron/arguments.py +11 −2 Original line number Diff line number Diff line Loading @@ -835,11 +835,20 @@ def _add_vit_args(parser): group.add_argument('--num-classes', type=int, default=1000, help='num of classes in vision classificaiton task') group.add_argument('--img-dim', type=int, default=224, help='Image size for vision classification task') group.add_argument('--img-h', type=int, default=224, help='Image height for vision classification task') group.add_argument('--img-w', type=int, default=224, help='Image height for vision classification task') group.add_argument('--num-channels', type=int, default=3, help='Number of channels in input image data') group.add_argument('--patch-dim', type=int, default=16, help='patch dimension used in vit') group.add_argument('--classes-fraction', type=float, default=1.0, help='training with fraction of classes.') group.add_argument('--data-per-class-fraction', type=float, default=1.0, help='training with fraction of data per class.') group.add_argument('--no-data-sharding', action='store_false', help='Disable data sharding.', dest='data_sharding') return parser
class RandomSeedDataset(Dataset):
    """Wrapper that reseeds the torch / ``random`` / NumPy RNGs with
    ``curr_seed + idx`` before fetching each sample, so any randomness in
    the wrapped dataset's ``__getitem__`` is a deterministic function of
    (seed, epoch, index).

    Call :meth:`set_epoch` at the start of every epoch to vary the
    per-sample seeds between epochs.
    """

    def __init__(self, dataset):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = args.seed
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        # Offset the seed by the epoch so augmentations differ per epoch.
        self.curr_seed = self.base_seed + epoch

    def __getitem__(self, idx):
        sample_seed = self.curr_seed + idx
        # Seed all three RNG sources the wrapped dataset might draw from.
        random.seed(sample_seed)
        np.random.seed(sample_seed)
        torch.manual_seed(sample_seed)
        return self.dataset[idx]
self.dataset = dataset self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.data_parallel_size = data_parallel_size self.data_sharding = data_sharding self.micro_batch_times_data_parallel_size = \ self.micro_batch_size * data_parallel_size self.last_batch_size = \ Loading @@ -136,7 +164,11 @@ class MegatronPretrainingRandomSampler: current_epoch_samples = self.consumed_samples % active_total_samples assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 if isinstance(dataset, RandomSeedDataset): self.dataset.set_epoch(self.epoch) # data sharding and random sampling if self.data_sharding: bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ * self.micro_batch_size bucket_offset = current_epoch_samples // self.data_parallel_size Loading @@ -146,6 +178,16 @@ class MegatronPretrainingRandomSampler: g.manual_seed(self.epoch) random_idx = torch.randperm(bucket_size, generator=g).tolist() idx_range = [start_idx + x for x in random_idx[bucket_offset:]] else: full_bucket_size = (self.total_samples // self.micro_batch_size) \ * self.micro_batch_size full_bucket_offset = current_epoch_samples g = torch.Generator() g.manual_seed(self.epoch) idx_range_total = \ torch.randperm(full_bucket_size, generator=g).tolist() idx_range_active = idx_range_total[full_bucket_offset:] idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size] batch = [] # Last batch if not complete will be dropped. Loading
def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        data_per_class_fraction: float,
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generate a list of ``(path_to_sample, class_index)`` pairs.

    Args:
        directory: root dataset directory (one subfolder per class).
        class_to_idx: maps class name -> class index; only these classes
            are scanned.
        data_per_class_fraction: keep only this leading fraction of each
            class's files (sorted walk order).
        extensions: allowed file extensions; exactly one of
            ``extensions`` / ``is_valid_file`` must be supplied.
        is_valid_file: predicate deciding whether a path is a valid
            sample file.

    Raises:
        ValueError: when ``extensions`` and ``is_valid_file`` are both
            given or both omitted.
    """
    directory = os.path.expanduser(directory)

    have_extensions = extensions is not None
    have_predicate = is_valid_file is not None
    # Exactly one selection mechanism must be active.
    if have_extensions == have_predicate:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if have_extensions:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances: List[Tuple[str, int]] = []
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        local_instances = []
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    local_instances.append((path, class_index))
        keep = int(len(local_instances) * data_per_class_fraction)
        instances.extend(local_instances[:keep])
    return instances
is_valid_file (callable, optional): A function that takes path of a file and check if the file is a valid file (used to check of corrupt files) both extensions and is_valid_file should not be passed. Attributes: classes (list): List of the class names sorted alphabetically. class_to_idx (dict): Dict with items (class_name, class_index). samples (list): List of (sample path, class_index) tuples targets (list): The class_index value for each image in the dataset """ def __init__( self, root: str, loader: Callable[[str], Any], extensions: Optional[Tuple[str, ...]] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, classes_fraction=1.0, data_per_class_fraction=1.0, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> None: super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) self.classes_fraction = classes_fraction self.data_per_class_fraction = data_per_class_fraction classes, class_to_idx = self._find_classes(self.root) samples = self.make_dataset(self.root, class_to_idx, self.data_per_class_fraction, extensions, is_valid_file) if len(samples) == 0: msg = "Found 0 files in subfolders of: {}\n".format(self.root) if extensions is not None: msg += "Supported extensions are: {}".format(",".join(extensions)) raise RuntimeError(msg) self.loader = loader self.extensions = extensions self.total = len(samples) self.classes = classes self.class_to_idx = class_to_idx self.samples = samples self.targets = [s[1] for s in samples] @staticmethod def make_dataset( directory: str, class_to_idx: Dict[str, int], data_per_class_fraction: float, extensions: Optional[Tuple[str, ...]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> List[Tuple[str, int]]: return make_dataset(directory, class_to_idx, data_per_class_fraction, extensions=extensions, is_valid_file=is_valid_file) def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: """ Finds the class folders in a 
class ImageFolder(DatasetFolder):
    """An image data loader for ``root/<class>/<image>`` layouts::

        root/dog/xxx.png
        root/dog/[...]/xxz.png
        root/cat/123.png
        root/cat/[...]/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): Transform applied to each PIL
            image, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Transform applied to each
            target.
        classes_fraction (float): Fraction of class folders to keep.
        data_per_class_fraction (float): Fraction of each class's images
            to keep.
        loader (callable, optional): Loads an image given its path.
        is_valid_file (callable, optional): Predicate for valid image
            files (used to skip corrupt files); mutually exclusive with
            the default extension whitelist.

    Attributes:
        classes (list): Class names, sorted alphabetically.
        class_to_idx (dict): Maps class name -> class index.
        imgs (list): List of (image path, class_index) tuples.
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            classes_fraction=1.0,
            data_per_class_fraction=1.0,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Use the extension whitelist only when no custom validity
        # predicate was supplied (DatasetFolder forbids passing both).
        extensions = IMG_EXTENSIONS if is_valid_file is None else None
        super(ImageFolder, self).__init__(
            root,
            loader,
            extensions,
            transform=transform,
            target_transform=target_transform,
            classes_fraction=classes_fraction,
            data_per_class_fraction=data_per_class_fraction,
            is_valid_file=is_valid_file,
        )
        # Alias kept so code reading .imgs keeps working.
        self.imgs = self.samples
def build_train_valid_datasets(data_path, image_size=224):
    """Build the (train, valid) image-classification datasets.

    Args:
        data_path: sequence of paths; ``data_path[0]`` is the training
            image folder, ``data_path[1]`` the validation image folder.
        image_size: crop size forwarded to the transform pipelines.

    Returns:
        ``(train_data, val_data)``, each an ``ImageFolder`` wrapped in a
        ``RandomSeedDataset`` for per-sample deterministic augmentation
        seeding.
    """
    args = get_args()
    train_transform = ClassificationTransform(image_size)
    val_transform = ClassificationTransform(image_size, train=False)

    # Training dataset: honours the class/data subsampling fractions.
    train_data = ImageFolder(
        root=data_path[0],
        transform=train_transform,
        classes_fraction=args.classes_fraction,
        data_per_class_fraction=args.data_per_class_fraction,
    )
    train_data = RandomSeedDataset(train_data)

    # Validation dataset: always scans the full folder.
    val_data = ImageFolder(
        root=data_path[1],
        transform=val_transform,
    )
    val_data = RandomSeedDataset(val_data)

    return train_data, val_data
megatron/fused_kernels/__init__.py +6 −0 Original line number Diff line number Diff line Loading @@ -78,6 +78,12 @@ def load(args): scaled_masked_softmax_cuda = _cpp_extention_load_helper( "scaled_masked_softmax_cuda", sources, extra_cuda_flags) # Softmax sources=[srcpath / 'scaled_softmax.cpp', srcpath / 'scaled_softmax_cuda.cu'] scaled_softmax_cuda = _cpp_extention_load_helper( "scaled_softmax_cuda", sources, extra_cuda_flags) # ================================= # Mixed precision fused layer norm. # ================================= Loading