Commit c3691006 authored by Mohammad

testing new gpt2 dataset

parent 836c6776
+1 −1
@@ -24,7 +24,6 @@ from torch.utils.data import Dataset
 
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.data import helpers
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
@@ -249,6 +248,7 @@ def get_samples_mapping_(indexed_dataset,
         start_time = time.time()
         print_rank_0(' > building samples index mapping for {} ...'.format(
             name))
+        from megatron.data import helpers
         samples_mapping = helpers.build_mapping(
             indexed_dataset.doc_idx,
             indexed_dataset.sizes,
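
Note: the net effect of these two hunks is to move the helpers import from module level to inside get_samples_mapping_, so merely importing the dataset module no longer requires the compiled C++ extension to be importable. A minimal sketch of the same lazy-import pattern (the wrapper name here is hypothetical, not from the commit):

def build_mapping_lazily(*args):
    # Load the compiled extension only when a mapping is actually built;
    # importing the enclosing module stays free of the C++ dependency.
    from megatron.data import helpers
    return helpers.build_mapping(*args)
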
+29 −84
@@ -13,26 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""GPT2 Style dataset."""
+"""GPT2 style dataset."""
 
 import os
 import time
 
 import numpy as np
 import torch
-from torch.utils.data import Dataset
 
-import helpers
-#from bert_dataset import get_train_valid_test_split_
-
-
-def print_rank_0(message):
-    print(message)
+from megatron import print_rank_0
+from megatron import mpu
+from megatron.data.bert_dataset import get_train_valid_test_split_
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 
 
 def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
+    """Build train, valid, and test datasets."""
 
     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
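
Note: a hedged usage sketch of the rewritten entry point; every argument value below is illustrative rather than taken from the commit:

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix='my-corpus_text_document',  # hypothetical prefix of the indexed data files
    data_impl='mmap',                       # assumed implementation accepted by make_indexed_dataset
    splits_string='949,50,1',               # relative train/valid/test weights
    train_valid_test_num_samples=[10000, 1000, 100],
    seq_length=1024,
    seed=1234,
    skip_warmup=True)
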
@@ -56,7 +54,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(start=splits[index], end=splits[index+1],
            documents = np.arange(start=splits[index], stop=splits[index+1],
                                  step=1, dtype=np.int32)
            dataset = GPT2Dataset(name, data_prefix,
                                  documents, indexed_dataset,
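
Note on the one-line change above: np.arange has no end keyword; its parameters are start, stop, and step, so the old call would fail with a TypeError at runtime. For instance:

import numpy as np
np.arange(start=3, stop=7, step=1, dtype=np.int32)  # array([3, 4, 5, 6], dtype=int32)
np.arange(start=3, end=7, step=1, dtype=np.int32)   # raises TypeError
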
@@ -72,7 +70,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
 
 
 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
-
+    """Build indexed dataset."""
     print_rank_0(' > building dataset index ...')
 
     start_time = time.time()
@@ -81,25 +79,18 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
                                            skip_warmup)
     print_rank_0(' > finished creating indexed dataset in {:4f} '
                  'seconds'.format(time.time() - start_time))
-
-    print_rank_0(' > indexed dataset stats:')
     print_rank_0('    number of documents: {}'.format(
         indexed_dataset.sizes.shape[0]))
 
     return indexed_dataset
 
 
-class GPT2Dataset(Dataset):
+class GPT2Dataset(torch.utils.data.Dataset):
 
-    def __init__(self, name, data_prefix,
-                 documents, indexed_dataset,
+    def __init__(self, name, data_prefix, documents, indexed_dataset,
                  num_samples, seq_length, seed):
 
         self.name = name
-        self.data_prefix = data_prefix
-        self.num_samples = num_samples
-        self.seq_length = seq_length
-        self.seed = seed
         self.indexed_dataset = indexed_dataset
 
         # Checks
@@ -107,11 +98,9 @@ class GPT2Dataset(Dataset):
         assert np.max(documents) < indexed_dataset.sizes.shape[0]
 
         # Build index mappings.
-        self.num_epochs, self.doc_idx, self.sample_idx, self.shuffle_idx \
-            = _build_index_mappings(self.name, self.data_prefix, documents,
-                                    self.indexed_dataset.sizes,
-                                    self.num_samples, self.seq_length,
-                                    self.seed)
+        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
+            self.name, data_prefix, documents, self.indexed_dataset.sizes,
+            num_samples, seq_length, seed)
 
 
     def __len__(self):
@@ -144,7 +133,7 @@ class GPT2Dataset(Dataset):
                 length=offset_l+1))
             sample = np.concatenate(sample_list)
 
-        return sample
+        return {'text': np.array(sample, dtype=np.int64)}
 


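
Note: __getitem__ now returns a dict of int64 token IDs instead of a bare array. A hedged sketch of how the whole method plausibly reads after this commit, consistent with the fragment shown above and assuming indexed_dataset.get accepts offset and length keywords as in megatron.data.indexed_dataset (variable names beyond the fragment are assumptions):

def __getitem__(self, idx):
    # Map the requested index through the shuffle table first.
    idx = self.shuffle_idx[idx]
    # Consecutive rows of sample_idx delimit one sample:
    # (first document, start offset) and (last document, end offset).
    doc_index_f, offset_f = self.sample_idx[idx]
    doc_index_l, offset_l = self.sample_idx[idx + 1]
    if doc_index_f == doc_index_l:
        # The sample lies within a single document.
        sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                          offset=offset_f,
                                          length=offset_l - offset_f + 1)
    else:
        # The sample spans documents: tail of the first, full middle
        # documents, then the head of the last.
        sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                offset=offset_f)]
        for i in range(doc_index_f + 1, doc_index_l):
            sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
        sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l],
                                                    length=offset_l + 1))
        sample = np.concatenate(sample_list)
    return {'text': np.array(sample, dtype=np.int64)}

Each sample carries seq_length + 1 tokens, so a consumer can split batch['text'] into inputs and one-position-shifted labels.
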
@@ -168,7 +157,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     shuffle_idx_filename = _filename + '_shuffle_idx.npy'
 
     # Build the indexed mapping if it does not exist.
-    if True: #torch.distributed.get_rank() == 0:
+    if torch.distributed.get_rank() == 0:
         if (not os.path.isfile(doc_idx_filename)) or \
            (not os.path.isfile(sample_idx_filename)) or \
            (not os.path.isfile(shuffle_idx_filename)):
@@ -183,7 +172,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
                          '(seconds): {:4f}'.format(time.time() - start_time))
             # sample-idx.
             start_time = time.time()
-            import helpers
+            # Use C++ implementation for speed.
+            from megatron.data import helpers
+            assert doc_idx.dtype == np.int32
+            assert sizes.dtype == np.int32
             sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                                   num_epochs, tokens_per_epoch)
             #sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
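
Note: the two new asserts exist because the C++ helper presumably reads the arrays as raw int32 buffers. A hedged pure-Python rendering of what helpers.build_sample_idx computes; the commented-out _build_sample_idx call above suggests such a fallback exists, but this sketch is illustrative, not the exact fallback:

import numpy as np

def build_sample_idx_py(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
    # Row i holds (index into doc_idx, token offset) where sample i begins;
    # sample i spans rows i..i+1 and covers seq_length + 1 tokens.
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
    doc_idx_index, doc_offset = 0, 0
    sample_idx[0] = (doc_idx_index, doc_offset)
    for sample_index in range(1, num_samples + 1):
        remaining = seq_length + 1
        while remaining > 0:
            doc_length = sizes[doc_idx[doc_idx_index]] - doc_offset
            remaining -= doc_length
            if remaining <= 0:
                # The sample ends inside this document; back up one token so
                # the next sample overlaps it by one (its first input token
                # is this sample's last label token).
                doc_offset += remaining + doc_length - 1
                remaining = 0
            else:
                # Continue from the beginning of the next document.
                doc_idx_index += 1
                doc_offset = 0
        sample_idx[sample_index] = (doc_idx_index, doc_offset)
    return sample_idx
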
@@ -202,9 +194,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     # device_index=rank which is not the case for model
     # parallel case
     counts = torch.cuda.LongTensor([1])
-    #torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    #assert counts[0].item() == torch.distributed.get_world_size(
-    #    group=mpu.get_data_parallel_group())
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group())
 
     # Load mappings.
     start_time = time.time()
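
Note: beyond sanity-checking the data-parallel world size, the restored all_reduce doubles as a barrier. With this commit only rank 0 builds and saves the .npy index files (the old if True: had every rank build them, presumably racing on the same files), and since the collective cannot complete until rank 0 reaches it, every other rank is held back until the files exist. Schematically, condensed from the function above:

if torch.distributed.get_rank() == 0:
    pass  # build doc_idx / sample_idx / shuffle_idx and np.save them (as above)
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
# Only after the collective completes is it safe for every rank to np.load
# the saved mappings.
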
@@ -221,8 +213,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         time.time() - start_time))
     print_rank_0('    total number of samples: {}'.format(
         sample_idx.shape[0]))
+    print_rank_0('    total number of epochs: {}'.format(num_epochs))
 
-    return num_epochs, doc_idx, sample_idx, shuffle_idx
+    return doc_idx, sample_idx, shuffle_idx
 
 
 def _num_tokens(documents, sizes):
@@ -311,10 +304,11 @@ def _build_shuffle_idx(size, np_rng):
     if size >= (np.iinfo(np.uint32).max - 1):
         dtype_ = np.int64
     shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
-    #np_rng.shuffle(shuffle_idx)
+    np_rng.shuffle(shuffle_idx)
     return shuffle_idx
 
 
 '''
 
 class IndexedDataset:
 
@@ -399,53 +393,4 @@ if __name__ == '__main__':
 
         test(seed, data_prefix, seq_length, num_samples,
              num_docs, min_doc_length, max_doc_length)
     exit()
-
-    '''
-
-    num_docs = 5
-    min_doc_length = 2
-    max_doc_length = 10
-    num_samples = 9
-    seq_length = 4
-    seed = 1234
-
-    np.random.seed(seed)
-    indexed_dataset = IndexedDataset(num_docs, min_doc_length,
-                                     max_doc_length, seq_length)
-    print('> indexed dataset:')
-    for s in indexed_dataset.tokens:
-        print('   {}'.format(s))
-
-    documents = np.array([1,2,3], dtype=np.int32)
-
-    dataset = GPT2Dataset('gpt2', documents, indexed_dataset,
-                          num_samples, seq_length, seed)
-    indexed_dataset.build_tokens_flat(dataset.doc_idx)
-
-    print(indexed_dataset.get_sample(6))
-    print(dataset[6])
-    '''
-'''
-    myds = MyDataset(ds, num_samples, seq_length)
-    num_docs = myds._num_docs()
-    print('> number of documents: {}'.format(num_docs))
-    tokens_per_epoch = myds._num_tokens()
-    print('> number of tokens: {}'.format(tokens_per_epoch))
-    num_epochs = myds._num_epochs(tokens_per_epoch)
-    print('> number of epochs: {}'.format(num_epochs))
-    doc_idx = myds._build_doc_idx(num_docs, num_epochs)
-    print('> doc_idx: {}'.format(doc_idx))
-
-    ds.build_tokens_flat(doc_idx)
-    sample_idx = myds._build_sample_idx(num_epochs, tokens_per_epoch, doc_idx)
-
-    for s in sample_idx:
-        print(s)
-
-    print(ds.tokens_flat)
-    print(myds.get_sample(8))
-    print(ds.get_sample(8))
-'''