Commit 0601702a authored by Mohammad Shoeybi's avatar Mohammad Shoeybi

zero worker seems to be working

parent f6a6811f
+154 −24
@@ -8,6 +8,7 @@ import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import mpu
from megatron.data import helpers
from megatron.data import FullBertTokenizer
from megatron.data.dataset_utils import build_training_sample
@@ -15,22 +16,97 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.utils import print_rank_0


def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
                                    splits_string, train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup):

    # Tokenizer is the same for all splits.
    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
        tokenizer.vocab_size()))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')
    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can restore it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            dataset = AlbertDataset(
                name=name,
                indexed_dataset=indexed_dataset,
                tokenizer=tokenizer,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                masked_lm_prob=masked_lm_prob,
                max_seq_length=max_seq_length,
                short_seq_prob=short_seq_prob,
                seed=seed)
            # Restore the original doc-idx pointer so the indexed dataset again covers all documents.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


class AlbertDataset(Dataset):

    def __init__(self, vocab_file, data_prefix, data_impl, skip_warmup,
                 num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
                 short_seq_prob, seed):
    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)

        # Indexed dataset.
        self.indexed_dataset = get_indexed_dataset_(data_prefix,
                                                    data_impl,
                                                    skip_warmup)
        # Tokenizer and dataset.
        self.tokenizer = tokenizer
        self.indexed_dataset = indexed_dataset


        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
@@ -39,7 +115,8 @@ class AlbertDataset(Dataset):
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed)
                                                    self.seed,
                                                    self.name)

        # Vocab stuff.
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
@@ -48,7 +125,6 @@ class AlbertDataset(Dataset):
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.mask_id = self.tokenizer.vocab['[MASK]']
        self.pad_id = self.tokenizer.vocab['[PAD]']
        exit()


    def num_tokens(self):
@@ -68,9 +144,11 @@ class AlbertDataset(Dataset):
        sample = []
        for index in range(start_index, end_index):
            sample.append(self.indexed_dataset[index])
        '''
        for s in sample:
            if len(s) > 1000:
                print(self.tokenizer.convert_ids_to_tokens(s))
        '''
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length, # needed for padding
                                     self.vocab_id_list,
@@ -80,25 +158,63 @@ class AlbertDataset(Dataset):
                                     self.masked_lm_prob, rng)



def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    print_rank_0("> Reading dataset index ...")
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    print_rank_0("> Finished creating indexed dataset in {:4f} "
                 "seconds".format(time.time() - start_time))
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0('    number of documents: {}'.format(
        indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0('    number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split/splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index


def get_samples_mapping_(indexed_dataset,
                         data_prefix,
                         num_epochs,
                         max_num_samples,
                         max_seq_length,
                         short_seq_prob,
                         seed):
                         seed,
                         name):
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
@@ -109,8 +225,10 @@ def get_samples_mapping_(indexed_dataset,

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_indexmap'
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
@@ -120,8 +238,9 @@ def get_samples_mapping_(indexed_dataset,
    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print('WARNING: could not find index map file {}, building '
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32
@@ -129,6 +248,8 @@ def get_samples_mapping_(indexed_dataset,
        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
@@ -138,12 +259,21 @@ def get_samples_mapping_(indexed_dataset,
            short_seq_prob,
            seed,
            verbose)
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
                         time.time() - start_time))
    torch.distributed.barrier()
    # This should be a barrier, but the NCCL barrier assumes
    # device_index == rank, which does not hold in the
    # model-parallel case.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
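
The change just above replaces torch.distributed.barrier() with an all-reduce of a ones tensor over the data-parallel group, since the NCCL barrier assumes device_index == rank, which breaks under model parallelism. A minimal sketch of that pseudo-barrier pattern in isolation, assuming an initialized NCCL process group (the group argument stands in for mpu.get_data_parallel_group()):

import torch
import torch.distributed as dist


def pseudo_barrier(group=None):
    # Every rank contributes 1; after the sum-all-reduce each rank holds the
    # group's world size, so reaching the assert implies all ranks got here.
    counts = torch.cuda.LongTensor([1])
    dist.all_reduce(counts, group=group)
    assert counts[0].item() == dist.get_world_size(group=group)
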
+40 −44
@@ -39,12 +39,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
       and sequence-length is the target sequence length.
    */

    if (verbose) {
         cout << " > using " << docs_.shape(0) - 1 <<
	   " documents with " << sizes_.shape(0) << " sentences ..." <<
	   endl << std::flush;
    }

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
@@ -52,16 +46,36 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
    assert(short_seq_prob <= 1.0);
    assert(seed > 0);

    // For efficiency, convert probability to ratio. Note: rand() generates int.
    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();
    if (docs[docs.shape(0) - 1] != sizes.shape(0)) {
        cout << "document values is not consistent with length of sizes: " <<
                docs[docs.shape(0) - 1] << " != " << sizes.shape(0) << endl;
        throw std::length_error("docs and sizes");
    }

    // For efficiency, convert probability to ratio. Note: rand() generates int.
    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));

    if (verbose) {
        const auto sent_start_index = docs[0];
	const auto sent_end_index = docs[docs_.shape(0) - 1];
	const auto num_sentences = sent_end_index - sent_start_index;
	cout << "    using:" << endl << std::flush;
	cout << "     number of documents:            " << docs_.shape(0) - 1 <<
	  endl << std::flush;
	cout << "     sentences range:                [" << sent_start_index <<
	", " << sent_end_index << ")" << endl << std::flush;
	cout << "     total number of sentences:      " << num_sentences <<
	  endl << std::flush;
	cout << "     number of epochs:               " << num_epochs <<
	  endl << std::flush;
	cout << "     maximum number of samples:      " << max_num_samples <<
	  endl << std::flush;
	cout << "     maximum sequence length:        " << max_seq_length <<
	  endl << std::flush;
	cout << "     short sequence probability:     " << short_seq_prob <<
	endl << std::flush;
	cout << "     short sequence ration (1/prob): " << short_seq_ratio <<
	  endl << std::flush;
	cout << "     seed:                           " << seed << endl <<
	  std::flush;
    }

    // Mapping and its length (1D).
@@ -90,7 +104,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
        for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
            if (map_index >= max_num_samples) {
	        if (verbose && (!second)) {
		  cout << " > reached " << max_num_samples << " samples after "
		  cout << "    reached " << max_num_samples << " samples after "
		       << epoch << " epochs ..." << endl << std::flush;
		}
                break;
@@ -181,11 +195,11 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,

        if (!second) {
	    if (verbose) {
	        cout << " > number of empty documents: " << empty_docs <<
	        cout << "   number of empty documents: " << empty_docs <<
		  endl << std::flush;
		cout << " > number of documents with one sentence: " <<
		cout << "   number of documents with one sentence: " <<
		  one_sent_docs << endl << std::flush;
		cout << " > will create mapping for " << map_index <<
		cout << "   will create mapping for " << map_index <<
		  " samples" << endl << std::flush;
	    }
	    assert(maps == NULL);
@@ -210,10 +224,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
      swap(maps[i0 + 2], maps[j0 + 2]);
    }

    if (verbose) {
        cout << "> done building the mapping." << endl;
    }

    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void *mem_) {
            DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
@@ -239,30 +249,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
                        const int seed,
			const bool verbose) {

    if (verbose) {
        cout << "> building sample map using: " << endl << std::flush;
	cout << "     number of epochs:           " << num_epochs << endl
	     << std::flush;
	cout << "     maximum number of samples:  " << max_num_samples << endl
	     << std::flush;
	cout << "     maximum sequence length:    " << max_seq_length << endl
	     << std::flush;
	cout << "     short sequence probability: " << short_seq_prob << endl
	     << std::flush;
	cout << "     seed:                       " << seed << endl
	     << std::flush;
    }

    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) {
	    cout << " > using uint64 for data mapping..." << endl << std::flush;
	   cout << "    using uint64 for data mapping..." << endl << std::flush;
	}
	return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
					    max_num_samples, max_seq_length,
					    short_seq_prob, seed, verbose);
    } else {
       if (verbose) {
	    cout << " > using uint32 for data mapping..." << endl << std::flush;
	   cout << "    using uint32 for data mapping..." << endl << std::flush;
       }
       return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
					   max_num_samples, max_seq_length,
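
For efficiency, build_mapping_impl converts short_seq_prob into an integer ratio (short_seq_ratio = round(1 / short_seq_prob)) because rand() produces integers. A rough Python analogue of that decision, with an arbitrary probability and seed chosen for illustration:

import random


def is_short_sequence(short_seq_prob, rng):
    # round(1 / p) equally likely buckets; landing in bucket 0 happens with
    # probability roughly short_seq_prob.
    short_seq_ratio = int(round(1.0 / short_seq_prob))
    return rng.randint(0, short_seq_ratio - 1) == 0


rng = random.Random(1234)
hits = sum(is_short_sequence(0.1, rng) for _ in range(100000))
# hits lands near 10000, i.e. about 10% of samples get a shortened target length.
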
+13 −9
@@ -391,17 +391,17 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0(">    Warming up index mmap file...")
                print_rank_0("    warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0(">    Reading sizes...")
            print_rank_0("    reading sizes...")
            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
            print_rank_0(">    Reading pointers...")
            print_rank_0("    reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
            print_rank_0(">    Reading document index...")
            print_rank_0("    reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)
        def __del__(self):
@@ -447,13 +447,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0(">    Warming up data mmap file...")
            print_rank_0("    warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0(">    Creating numpy buffer of mmap...")
        print_rank_0("    creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
        print_rank_0(">    Creating memory view of numpy buffer...")
        print_rank_0("    creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)
        print_rank_0(">    Done")

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
@@ -470,7 +469,6 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
            if self._index.dtype != np.int64:
                np_array = np_array.astype(np.int64)

            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
@@ -492,6 +490,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False
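
The get_doc_idx/set_doc_idx accessors added here are what build_train_valid_test_datasets in the first file uses to carve a single indexed dataset into train/valid/test views: the document index is temporarily swapped for a slice, the per-split dataset is built against that view, and the full index is restored. A toy stand-in (not the real MMapIndexedDataset) showing just that view-swapping:

import numpy as np


class TinyIndexedDataset:
    # Minimal stand-in exposing only the doc-idx accessors.
    def __init__(self, doc_idx):
        self._doc_idx = np.asarray(doc_idx, dtype=np.int64)

    def get_doc_idx(self):
        return self._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._doc_idx = doc_idx_


ds = TinyIndexedDataset([0, 3, 5, 9, 12])  # 4 documents; doc-idx has num-docs + 1 entries
full = ds.get_doc_idx()
ds.set_doc_idx(full[1:3 + 1])              # documents 1..2; the +1 keeps the upper bound
# ... a split dataset would be built here against the narrowed view ...
ds.set_doc_idx(full)                       # restore so the next split sees every document
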
+19 −28
@@ -13,43 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset to split one large one into multiple smaller datasets"""

import torch
import numpy as np

def should_split(split):
    """
    given split proportions checks if should split
    Examples:
    >>> should_split([10,0,0])
    False
    >>> should_split([1,.1,.2])
    True
    """
    return max(split)/sum(split) != 1.
def get_train_valid_test_split(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""

def get_split(args):
    """
    Get dataset splits from comma separated string list
    """
    splits = []
    if args.split.find(',') != -1:
        splits = [float(s) for s in args.split.split(',')]
    elif args.split.find('/') != -1:
        splits = [float(s) for s in args.split.split('/')]
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(args.split)]
    split_total = sum(splits)
    if split_total < 1.:
        splits.append(1-split_total)
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    if args.valid_data is not None:
        splits[1] = 0.
    if args.test_data is not None:
        splits[2] = 0.
    final_sum = sum(splits)
    return [s/final_sum for s in splits]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split/splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    return splits_index

class SplitDataset(torch.utils.data.Dataset):
    """
+79 −75

File changed.
