Commit 9ea9d50f authored by Mohammad Shoeybi

added write option for index map

parent 690291dd
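
This commit replaces a build-every-time samples mapping with a cached one: a filename is derived from every parameter that affects the mapping, rank 0 builds and saves the mapping as a .npy file if it is missing, and all ranks then load it from disk after a barrier. Below is a minimal, self-contained sketch of that pattern; build_fn is a stand-in for helpers.build_mapping and is not a name from this commit. It assumes torch.distributed.init_process_group has already been called.

import os

import numpy as np
import torch


def cached_samples_mapping(indexmap_filename, build_fn):
    # Rank 0 builds and writes the mapping if the file is missing;
    # every other rank skips straight to the barrier.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        samples_mapping = build_fn()
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
    # All ranks block here until rank 0 has finished writing, so the
    # np.load below always sees a complete file.
    torch.distributed.barrier()
    return np.load(indexmap_filename, allow_pickle=True)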
+102 −47
@@ -2,6 +2,7 @@

import random
import time
+import os

import numpy as np
import torch
@@ -11,55 +12,37 @@ from .dataset_utils import build_training_sample
#from data.mapping import build_training_samples_mapping

from . import helpers
-from megatron.data import FullBertTokenizer, indexed_dataset
+from megatron.data import FullBertTokenizer
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.utils import print_rank_0


class AlbertDataset(Dataset):

-    def __init__(self, indexed_dataset, tokenizer, num_epochs, max_num_samples,
+    def __init__(self,
+                 vocab_file, data_prefix, data_impl, skip_warmup,
+                 num_epochs, max_num_samples,
                 masked_lm_prob, max_seq_length, short_seq_prob, seed):

        # Params to store.
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
-        self.tokenizer = tokenizer
+        self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)

        # Indexed dataset.
-        self.indexed_dataset = indexed_dataset
+        self.indexed_dataset = self._get_indexed_dataset(data_prefix, data_impl,
+                                                         skip_warmup)

        # Build the samples mapping.
-        if not num_epochs:
-            if not max_num_samples:
-                raise ValueError("Need to specify either max_num_samples "
-                                 "or num_epochs")
-            num_epochs = np.iinfo(np.int32).max - 1
-        if not max_num_samples:
-            max_num_samples = np.iinfo(np.int64).max - 1
-
-        # Make sure the types match the helpers input types.
-        assert indexed_dataset.doc_idx.dtype == np.int64
-        assert indexed_dataset.sizes.dtype == np.int32
-
-        # Build samples mapping
-        verbose = torch.distributed.get_rank()==0
-        start_time = time.time()
-        self.samples_mapping = helpers.build_mapping(
-            indexed_dataset.doc_idx,
-            indexed_dataset.sizes,
-            num_epochs,
-            max_num_samples,
-            self.max_seq_length-3, # account for added tokens
-            short_seq_prob,
-            self.seed,
-            verbose)
-        # Make sure all the ranks have built the mapping
-        torch.distributed.barrier()
-        print_rank_0('> elapsed time to build samples mapping (seconds): '
-                     '{:2f}'.format(time.time() - start_time))
-
-        exit()
+        self.samples_mapping = self._get_samples_mapping(self.indexed_dataset,
+                                                         data_prefix,
+                                                         num_epochs,
+                                                         max_num_samples,
+                                                         self.max_seq_length,
+                                                         short_seq_prob,
+                                                         self.seed)

        # Vocab stuff.
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
@@ -68,27 +51,19 @@ class AlbertDataset(Dataset):
        self.sep_id = tokenizer.vocab['[SEP]']
        self.mask_id = tokenizer.vocab['[MASK]']
        self.pad_id = tokenizer.vocab['[PAD]']
+        exit()


-    @classmethod
-    def from_paths(cls, vocab, data_prefix, data_impl,
-                   num_epochs, max_num_samples, masked_lm_prob,
-                   max_seq_length, short_seq_prob, seed, skip_warmup=False):
-        tokenizer = FullBertTokenizer(vocab, do_lower_case=True)
-        print_rank_0("> Reading dataset index ...")
-        idx_ds = indexed_dataset.make_dataset(data_prefix, data_impl,
-                                              skip_warmup)
-        print_rank_0("> Finished creating indexed dataset")
-        return cls(idx_ds, tokenizer, num_epochs, max_num_samples,
-                   masked_lm_prob, max_seq_length, short_seq_prob, seed)

    def num_tokens(self):
        return self.tokenizer.vocab_size()


    def __len__(self):
        return self.samples_mapping.shape[0]


    def __getitem__(self, idx):

        rng = random.Random(self.seed + idx)
        start_index, end_index, seq_length = self.samples_mapping[idx]
        sample = []
@@ -98,13 +73,93 @@ class AlbertDataset(Dataset):
            if len(s) > 1000:
                print(self.tokenizer.convert_ids_to_tokens(s))
        return build_training_sample(sample, seq_length,
-                                     self.max_seq_length,
+                                     self.max_seq_length, # needed for padding
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, rng)



+    def _get_indexed_dataset(self, data_prefix, data_impl, skip_warmup):
+        start_time = time.time()
+        print_rank_0("> Reading dataset index ...")
+        indexed_dataset = make_indexed_dataset(data_prefix,
+                                               data_impl,
+                                               skip_warmup)
+        print_rank_0("> Finished creating indexed dataset in {:4f} "
+                     "seconds".format(time.time() - start_time))
+        return indexed_dataset
+
+
+    def _get_samples_mapping(self,
+                             indexed_dataset,
+                             data_prefix,
+                             num_epochs,
+                             max_num_samples,
+                             max_seq_length,
+                             short_seq_prob,
+                             seed):
+        if not num_epochs:
+            if not max_num_samples:
+                raise ValueError("Need to specify either max_num_samples "
+                                 "or num_epochs")
+            num_epochs = np.iinfo(np.int32).max - 1
+        if not max_num_samples:
+            max_num_samples = np.iinfo(np.int64).max - 1
+
+        # Filename of the index mapping
+        indexmap_filename = data_prefix
+        indexmap_filename += '_indexmap'
+        indexmap_filename += '_{}ep'.format(num_epochs)
+        indexmap_filename += '_{}mns'.format(max_num_samples)
+        indexmap_filename += '_{}msl'.format(max_seq_length)
+        indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
+        indexmap_filename += '_{}s'.format(seed)
+        indexmap_filename += '.npy'
+
+        # Build the index mapping if it does not exist.
+        if torch.distributed.get_rank() == 0 and \
+           not os.path.isfile(indexmap_filename):
+            print('WARNING: could not find index map file {}, building '
+                  'the indices on rank 0 ...'.format(indexmap_filename))
+            # Make sure the types match the helpers input types.
+            assert indexed_dataset.doc_idx.dtype == np.int64
+            assert indexed_dataset.sizes.dtype == np.int32
+
+            # Build samples mapping
+            verbose = torch.distributed.get_rank()==0
+            start_time = time.time()
+            samples_mapping = helpers.build_mapping(
+                indexed_dataset.doc_idx,
+                indexed_dataset.sizes,
+                num_epochs,
+                max_num_samples,
+                max_seq_length-3, # account for added tokens
+                short_seq_prob,
+                seed,
+                verbose)
+            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+            print_rank_0('> elapsed time to build and save samples mapping '
+                         '(seconds): {:4f}'.format(
+                             time.time() - start_time))
+        # Make sure all the ranks have built the mapping
+        torch.distributed.barrier()
+
+        # Load the index mapping.
+        print_rank_0('> loading indexed mapping from {}'.format(
+            indexmap_filename))
+        start_time = time.time()
+        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
+        print_rank_0('  loaded indexed file in {:3.3f} seconds'.format(
+            time.time() - start_time))
+        print_rank_0('  total number of samples: {}'.format(
+            samples_mapping.shape[0]))
+
+        return samples_mapping


'''
def get_target_seq_length(max_num_tokens, short_seq_prob, np_rng):
    """With probability `short_seq_prob` generate a smaller sequence lenght."""
+13 −8
@@ -121,14 +121,19 @@ def get_train_val_test_data(args):
            if not args.data_path:
                print("Albert currently only supports a unified dataset specified with --data-path")
                exit(1)
            print("Creating AlbertDataset...")
            full_data = AlbertDataset.from_paths(args.vocab, args.data_path,
                                                 args.data_impl, args.data_epochs,
                                                 args.max_num_samples,
                                                 args.mask_prob, args.seq_length,
                                                 args.short_seq_prob,
                                                 args.seed, args.skip_mmap_warmup)
            print("Finished creating AlbertDataset...")
            print_rank_0("Creating AlbertDataset...")
            full_data = AlbertDataset(
                vocab_file=args.vocab,
                data_prefix=args.data_path,
                data_impl=args.data_impl,
                skip_warmup=args.skip_mmap_warmup,
                num_epochs=args.data_epochs,
                max_num_samples=args.max_num_samples,
                masked_lm_prob=args.mask_prob,
                max_seq_length=args.seq_length,
                short_seq_prob=args.short_seq_prob,
                seed=args.seed)
            print_rank_0("Finished creating AlbertDataset...")
            split = split_dataset.get_split(args)
            if split_dataset.should_split(split):
                train_ds, val_ds, test_ds = split_dataset.split_ds(full_data, split, args.shuffle)
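
Both hunks route log lines through print_rank_0, which the diff imports from megatron.utils but does not show. A typical implementation of that helper looks like the sketch below (an assumption about the utility, not the repository's exact code): it prints only on rank 0 so a multi-process run emits one copy of each message instead of one per process.

import torch

def print_rank_0(message):
    # Print only on the first rank; fall back to a plain print when
    # torch.distributed has not been initialized (e.g. single-GPU runs).
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)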