Commit 6df2bc1e authored by Wang, Xiao's avatar Wang, Xiao
Browse files

pretrain_crusher.slurm first push

parent 1064dc48
Loading
Loading
Loading
Loading
+3 −15
Original line number Diff line number Diff line
@@ -526,9 +526,6 @@ def _add_initialization_args(parser):
    group.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
    group.add_argument('--data-parallel-random-init', action='store_true',
                       help='Enable random initialization of params '
                       'across data parallel ranks')
    group.add_argument('--init-method-std', type=float, default=0.02,
                       help='Standard deviation of the zero mean normal '
                       'distribution used for weight initialization.')
@@ -663,7 +660,7 @@ def _add_distributed_args(parser):
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo', 'mpi'],
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
@@ -843,20 +840,11 @@ def _add_vit_args(parser):

    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classificaiton task')
    group.add_argument('--img-h', type=int, default=224,
                       help='Image height for vision classification task')
    group.add_argument('--img-w', type=int, default=224,
                       help='Image height for vision classification task')
    group.add_argument('--img-dim', type=int, default=224,
                       help='Image size for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension used in vit')
    group.add_argument('--classes-fraction', type=float, default=1.0,
                       help='training with fraction of classes.')
    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
                       help='training with fraction of data per class.')
    group.add_argument('--no-data-sharding', action='store_false',
                       help='Disable data sharding.',
                       dest='data_sharding')

    return parser
+16 −58
Original line number Diff line number Diff line
@@ -65,8 +65,6 @@ def check_checkpoint_args(checkpoint_args):
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')
    if get_checkpoint_version() < 3.0:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
@@ -74,6 +72,7 @@ def check_checkpoint_args(checkpoint_args):
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exists."""
    dirname = os.path.dirname(filename)
@@ -141,32 +140,6 @@ def read_metadata(tracker_filename):
    return max_iter, release


def get_rng_state():
    """ collect rng state across data parallel ranks """
    args = get_args()
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}

    rng_state_list = None
    if torch.distributed.is_initialized() and \
            mpu.get_data_parallel_world_size() > 1 and \
            args.data_parallel_random_init:
        rng_state_list = \
            [None for i in range(mpu.get_data_parallel_world_size())]
        torch.distributed.all_gather_object(
            rng_state_list,
            rng_state,
            group=mpu.get_data_parallel_group())
    else:
        rng_state_list = [rng_state]

    return rng_state_list


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()
@@ -177,9 +150,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # collect rng state across data parallel ranks
    rng_state = get_rng_state()

    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
@@ -203,7 +173,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):

        # RNG states.
        if not args.no_save_rng:
            state_dict["rng_state"] = rng_state
            state_dict['random_rng_state'] = random.getstate()
            state_dict['np_rng_state'] = np.random.get_state()
            state_dict['torch_rng_state'] = torch.get_rng_state()
            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
            state_dict['rng_tracker_states'] \
                = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
@@ -406,23 +381,6 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:

                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                # Check for empty states array
                if not rng_state['rng_tracker_states']:
                    raise KeyError
                mpu.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:  # backward compatability
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
+13 −55
Original line number Diff line number Diff line
@@ -16,10 +16,8 @@
"""Dataloaders."""


import random
import torch
import numpy as np
from torch.utils.data import Dataset
import random
from megatron import get_args
from megatron import mpu

@@ -41,13 +39,11 @@ def build_pretraining_data_loader(dataset, consumed_samples):
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size(),
            data_sharding=args.data_sharding)
            data_parallel_size=mpu.get_data_parallel_world_size())
    else:
        raise Exception('{} dataloader type is not supported.'.format(
                args.dataloader_type))
@@ -107,40 +103,16 @@ class MegatronPretrainingSampler:
            yield batch[start_idx:end_idx]


class RandomSeedDataset(Dataset):

    def __init__(self, dataset):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = args.seed
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        return self.dataset[idx]


class MegatronPretrainingRandomSampler:

    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, data_sharding):
    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.dataset = dataset
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.data_sharding = data_sharding
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
@@ -164,11 +136,7 @@ class MegatronPretrainingRandomSampler:
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        if isinstance(self.dataset, RandomSeedDataset):
            self.dataset.set_epoch(self.epoch)

        # data sharding and random sampling
        if self.data_sharding:
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
                       * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
@@ -178,16 +146,6 @@ class MegatronPretrainingRandomSampler:
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
        else:
            full_bucket_size = (self.total_samples // self.micro_batch_size) \
                                * self.micro_batch_size
            full_bucket_offset = current_epoch_samples
            g = torch.Generator()
            g.manual_seed(self.epoch)
            idx_range_total = \
                torch.randperm(full_bucket_size, generator=g).tolist()
            idx_range_active = idx_range_total[full_bucket_offset:]
            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]

        batch = []
        # Last batch if not complete will be dropped.
+196 KiB

File added.

No diff preview for this file type.

+31 −52
Original line number Diff line number Diff line
@@ -13,67 +13,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import torch
import torchvision.transforms as T
from torchvision import datasets
from megatron import get_args
from megatron.data.image_folder import ImageFolder
from torchvision import datasets, transforms
from megatron.data.autoaugment import ImageNetPolicy
from megatron.data.data_samplers import RandomSeedDataset

class ClassificationTransform():
    def __init__(self, image_size, train=True):
        args = get_args()
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        if train:
            self.transform = T.Compose([
                T.RandomResizedCrop(image_size),
                T.RandomHorizontalFlip(),
                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
                ImageNetPolicy(),
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.Resize(image_size),
                T.CenterCrop(image_size),
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                T.ConvertImageDtype(self.data_type)
            ])

    def __call__(self, input):
        output = self.transform(input)
        return output



def build_train_valid_datasets(data_path, image_size=224):
    args = get_args()
    train_transform = ClassificationTransform(image_size)
    val_transform = ClassificationTransform(image_size, train=False)
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):

    # training dataset
    train_data_path = data_path[0]
    train_data = ImageFolder(
        root=train_data_path,
        transform=train_transform,
        classes_fraction=args.classes_fraction,
        data_per_class_fraction=args.data_per_class_fraction
    train_data_path = os.path.join(data_path[0], "train")
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    process = [
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
    ]
    if color_jitter:
        process += [
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
            )
        ]
    fp16_t = transforms.ConvertImageDtype(torch.half)
    process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
    transform_train = transforms.Compose(process)
    train_data = datasets.ImageFolder(
        root=train_data_path, transform=transform_train
    )
    train_data = RandomSeedDataset(train_data)

    # validation dataset
    val_data_path = data_path[1]
    val_data = ImageFolder(
        root=val_data_path,
        transform=val_transform
    val_data_path = os.path.join(data_path[0], "val")
    transform_val = transforms.Compose(
        [
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
            fp16_t
        ]
    )
    val_data = datasets.ImageFolder(
        root=val_data_path, transform=transform_val
    )
    val_data = RandomSeedDataset(val_data)

    return train_data, val_data
Loading