Commit 9a0808c9 authored by mohammad

Merge branch 'megatron_sampler' into blendable_dataset

parents d3bb1a06 f0a445fa
+3 −0
@@ -72,6 +72,9 @@ def parse_args(extra_args_provider=None, defaults={},
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # Consumed samples.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Set input defaults.
    for key in defaults:
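
These two counters are meant to advance by one global batch per optimization step. A minimal sketch of that bookkeeping (the loop and train_step below are illustrative, not part of this commit):

    from megatron import get_args
    from megatron import mpu

    args = get_args()
    world_size = mpu.get_data_parallel_world_size()
    global_batch_size = args.batch_size * world_size

    while iteration < args.train_iters:
        train_step(...)  # hypothetical step function, stands in for the real loop body
        iteration += 1
        # Every step consumes one global batch across all data-parallel ranks.
        args.consumed_train_samples += global_batch_size
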
+13 −2
@@ -89,7 +89,8 @@ def get_checkpoint_tracker_filename(checkpoints_path):
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                    consumed_train_samples=None, consumed_valid_samples=None):
    """Save a model checkpoint."""
    args = get_args()

@@ -103,6 +104,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 2.0
        state_dict['iteration'] = iteration
        if consumed_train_samples:
            state_dict['consumed_train_samples'] = consumed_train_samples
        if consumed_valid_samples:
            state_dict['consumed_valid_samples'] = consumed_valid_samples
        state_dict['model'] = model.state_dict_for_save_checkpoint()

        # Optimizer stuff.
@@ -214,6 +219,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                                 checkpoint_name))
                sys.exit()

    if 'consumed_train_samples' in state_dict:
        assert args.consumed_train_samples == 0
        args.consumed_train_samples = state_dict['consumed_train_samples']
    if 'consumed_valid_samples' in state_dict:
        assert args.consumed_valid_samples == 0
        args.consumed_valid_samples = state_dict['consumed_valid_samples']

    # Check arguments.
    if 'args' in state_dict:
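
Putting the two changes together, the intended round trip looks like this. A hedged sketch that reuses only names appearing in this diff:

    # At save time, pass the live counters through the new keyword arguments.
    save_checkpoint(iteration, model, optimizer, lr_scheduler,
                    consumed_train_samples=args.consumed_train_samples,
                    consumed_valid_samples=args.consumed_valid_samples)

    # On restart, the asserts above require args.consumed_* to still be 0,
    # i.e. a fresh process; load_checkpoint then restores both counters
    # from the checkpoint's state dict.
    iteration = load_checkpoint(model, optimizer, lr_scheduler)
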
+95 −0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloaders."""


import torch

from megatron import get_args
from megatron import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
    """Buld dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    global_batch_size = args.batch_size * world_size

    # Megatron sampler
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        global_batch_size=global_batch_size,
        rank=mpu.get_data_parallel_rank(),
        world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)


class MegatronPretrainingSampler:


    def __init__(self, total_samples, consumed_samples,
                 global_batch_size, rank, world_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.global_batch_size = global_batch_size
        self.rank = rank

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.global_batch_size > 0, \
            'Unexpected global batch size: {}'.format(self.global_batch_size)
        assert world_size > 0,\
            'non zero world size is expected: {}'.format(world_size)
        assert self.rank < world_size,\
            'rank should be smaller than world size: {}, {}'.format(
                self.rank, world_size)

        # Batch size per rank.
        assert self.global_batch_size % world_size == 0,\
            'global batch size must be divisible by world size: {}, {}'.format(
                self.global_batch_size, world_size)
        self.batch_size_per_rank = self.global_batch_size // world_size


    def __len__(self):
        return self.total_samples


    def __iter__(self):
        batch = []
        # If the last batch is not complete, it will be dropped.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.global_batch_size:
                start_idx = self.rank * self.batch_size_per_rank
                end_idx = start_idx + self.batch_size_per_rank
                yield batch[start_idx:end_idx]
                batch = []
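
To make the sampler's semantics concrete, a small standalone example (the sizes are made up; only MegatronPretrainingSampler comes from this commit). Each global batch is cut into contiguous per-rank slices, and consumed_samples shifts the starting index so a resumed run skips exactly the samples already seen:

    sampler = MegatronPretrainingSampler(total_samples=8, consumed_samples=0,
                                         global_batch_size=4, rank=0, world_size=2)
    print(list(sampler))  # [[0, 1], [4, 5]]: rank 0 takes the first half of each global batch

    # Resuming with consumed_samples=4 skips the first global batch entirely.
    resumed = MegatronPretrainingSampler(total_samples=8, consumed_samples=4,
                                         global_batch_size=4, rank=0, world_size=2)
    print(list(resumed))  # [[4, 5]]
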

megatron/data/samplers.py

deleted 100644 → 0
+0 −148
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Batch samplers that work with either random or sequential data samplers."""

import torch
from torch.utils import data


class RandomSampler(data.sampler.Sampler):
    """Based off of pytorch RandomSampler and DistributedSampler. Essentially
    a RandomSampler, but this class lets the user set an epoch like
    DistributedSampler Samples elements randomly. If without replacement, then
    sample from a shuffled dataset. If with replacement, then user can
    specify ``num_samples`` to draw.
    Arguments:
        data_source (Dataset): dataset to sample from
        num_samples (int): number of samples to draw, default=len(dataset)
        replacement (bool): samples are drawn with replacement if ``True``,
        default=False
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.epoch = -1

        if self._num_samples is not None and replacement is False:
            raise ValueError("With replacement=False, num_samples should not "
                             "be specified, since a random permute will be "
                             "performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(
                                 self.num_samples))
        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        g = torch.Generator()
        if self.epoch >= 0:
            g.manual_seed(self.epoch)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,),
                                      dtype=torch.int64, generator=g).tolist())
        return iter(torch.randperm(n, generator=g).tolist())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


class DistributedBatchSampler(data.sampler.BatchSampler):
    """Similar to normal implementation of distributed sampler, except
    implementation is at the batch sampler level, instead of just the
    sampler level. This allows wrapping of arbitrary data samplers
    (sequential, random, WeightedRandomSampler, etc.) with this batch
    sampler.
    
    The `interleave` argument specifies how to distribute a batch. A value
    of True combined with the above random sampler is equivalent to pytorch's
    torch.utils.data.distributed.DistributedSampler.

    For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2 
    specifying True will result in the following samples for each gpu:
        GPU0: [0,2,4,6] GPU1: [1,3,5,7]
    specifying False will result in the following samples:
        GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""

    def __init__(self, sampler, batch_size, drop_last, rank=-1,
                 world_size=2, wrap_last=False, interleave=False):
        super(DistributedBatchSampler, self).__init__(sampler, batch_size,
                                                      drop_last)
        if rank == -1:
            assert False, 'should not be here'
            rank = torch.distributed.get_rank()
        self.rank = rank
        self.world_size = world_size
        self.sampler.wrap_around = 0
        self.wrap_around = 0
        self.wrap_last = wrap_last
        self.start_iter = 0
        self.interleave = interleave

    def __iter__(self):
        batch = []
        i = 0
        for idx in self.data_iterator(self.sampler, wrap_around=False):
            batch.append(idx)
            if len(batch) == self.batch_size:
                tbatch = self._batch(batch)
                if i >= self.start_iter:
                    yield tbatch
                    self.start_iter = 0
                i += 1
                batch = []
        batch_len = len(batch)
        if batch_len > 0 and not self.drop_last:
            if self.wrap_last:
                self.sampler.wrap_around -= (self.batch_size)
                self.wrap_around += (len(batch))
                self.wrap_around %= self.batch_size
            yield self._batch(batch)
        if self.wrap_last:
            self.sampler.wrap_around += self.batch_size

    def data_iterator(self, _iter, wrap_around=False):
        """iterates through data and handles wrap around"""
        for i, idx in enumerate(_iter):
            if i < self.wrap_around % self.batch_size:
                continue
            if wrap_around:
                self.wrap_around += 1
                self.wrap_around %= self.batch_size
            yield idx

    def _batch(self, batch):
        """extracts samples only pertaining to this worker's batch"""
        if self.interleave:
            return batch[self.rank:self.batch_size:self.world_size]
        start = self.rank * self.batch_size // self.world_size
        end = (self.rank + 1) * self.batch_size // self.world_size
        return batch[start:end]
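
For comparison with the new sampler above: the deleted _batch supported both an interleaved and a contiguous layout, whereas MegatronPretrainingSampler always shards contiguously. A self-contained illustration of the two slicings described in the docstring:

    batch = list(range(8))  # one global batch
    batch_size, world_size = 8, 2
    for rank in range(world_size):
        interleaved = batch[rank:batch_size:world_size]
        contiguous = batch[rank * batch_size // world_size:
                           (rank + 1) * batch_size // world_size]
        print(rank, interleaved, contiguous)
    # Prints: 0 [0, 2, 4, 6] [0, 1, 2, 3]
    #         1 [1, 3, 5, 7] [4, 5, 6, 7]
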
+0 −141
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils for creating datasets"""
import os
import math

import torch

from .samplers import DistributedBatchSampler
from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer
from . import corpora

TRAIN_DATA = 0
VAL_DATA = 1
TEST_DATA = 2


def should_split(split):
    """
    given split proportions checks if should split
    Examples:
    >>> should_split([10,0,0])
    False
    >>> should_split([1,.1,.2])
    True
    """
    return max(split) / sum(split) != 1.


def get_ext(path):
    """gets path extension"""
    return os.path.splitext(path)[1]


def get_dataset(path, **kwargs):
    """gets dataset object based on keyword args and file at `path`"""
    if supported_corpus(path):
        return corpora.NAMED_CORPORA[path](**kwargs)
    ext = get_ext(path)
    if '.json' in ext:
        text = json_dataset(path, **kwargs)
    elif ext in ['.csv', '.tsv']:
        text = csv_dataset(path, **kwargs)
    else:
        raise NotImplementedError('data file type %s is not supported' % (ext))
    return text


def supported_corpus(corpus_name):
    """checks if corpus name is defined in `corpora.py`"""
    return corpus_name in corpora.NAMED_CORPORA


def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
                 delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
                 tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
                 model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
                 parallel_group=None, **kwargs):
    """function to create datasets+tokenizers for common options"""
    if isinstance(process_fn, str):
        process_fn = eval(process_fn)
    if non_binary_cols is not None:
        # multilabel dataset support (only for csvs)
        label_key = non_binary_cols

    def get_dataset_from_path(path_):
        if lazy:
            # get lazily loaded dataset
            named_corpora = False
            if supported_corpus(path_):
                named_corpora = True
                name = path_
                path_ = corpora.NAMED_CORPORA[path_].PATH
            if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
                # create cached version of dataset for lazy loading if it doesn't exist
                text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
                                   delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
                make_lazy(path_, text.X, data_type='data')
            # This should be a barrier but nccl barrier assumes
            # device_index=rank which is not the case for model
            # parallel case
            counts = torch.cuda.LongTensor([1])
            torch.distributed.all_reduce(counts, group=parallel_group)
            assert counts[0].item() == torch.distributed.get_world_size(
                group=parallel_group)

            text = lazy_array_loader(path_, data_type='data', map_fn=process_fn)
        else:
            # get dataset
            text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
                               delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn)
        return text
    # get one or multiple datasets and concatenate
    if isinstance(path, str):
        path = [path]
    datasets = [get_dataset_from_path(p) for p in path]
    if len(datasets) == 1:
        ds = datasets[0]
    else:
        ds = ConcatDataset(datasets)
    # make tokenizer for dataset
    if tokenizer is None:
        tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
                                   pad_token, character_converage, **kwargs)

    ds_type = ''
    if 'ds_type' in kwargs:
        ds_type = kwargs['ds_type']
    ds.SetTokenizer(tokenizer)
    # Split dataset into train/val/test (and wrap bert dataset)
    if should_split(split):
        ds = split_ds(ds, split)
        if 'bert' in ds_type.lower():
            presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
            dstype = bert_sentencepair_dataset
            ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
                  if d is not None else None for d in ds]
        elif ds_type.lower() == 'gpt2':
            ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
    else:
        if 'bert' in ds_type.lower():
            presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
            dstype = bert_sentencepair_dataset
            ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
        elif ds_type.lower() == 'gpt2':
            ds = GPT2Dataset(ds, max_seq_len=seq_length)
    return ds, tokenizer