Commit b1714c14 authored by Mohammad Shoeybi

Merge branch 'albert_data_loader' of ssh://gitlab-master.nvidia.com:12051/ADLR/megatron-lm into albert_data_loader
parents f51ceb7c 3f4bc91b
+65 −37
@@ -14,6 +14,7 @@ from functools import lru_cache
 import os
 import shutil
 import struct
+from itertools import accumulate

 import numpy as np
 import torch
@@ -50,11 +51,11 @@ def make_builder(out_file, impl, vocab_size=None):
         return IndexedDatasetBuilder(out_file)


-def make_dataset(path, impl, fix_lua_indexing=False):
+def make_dataset(path, impl):
     if impl == 'lazy' and IndexedDataset.exists(path):
-        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
+        return IndexedDataset(path)
     elif impl == 'cached' and IndexedDataset.exists(path):
-        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
+        return IndexedCachedDataset(path)
     elif impl == 'mmap' and MMapIndexedDataset.exists(path):
         return MMapIndexedDataset(path)
     return None
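
For orientation, a minimal usage sketch of the slimmed-down factory; the bare module import and the 'my_corpus' prefix are hypothetical stand-ins, not part of this commit:

    # Hypothetical usage of make_dataset after fix_lua_indexing was dropped.
    import indexed_dataset

    ds = indexed_dataset.make_dataset('my_corpus', impl='lazy')
    if ds is None:       # make_dataset returns None when no dataset matches
        raise FileNotFoundError('no indexed dataset found for prefix my_corpus')
    first = ds[0]        # one sentence: a numpy array of token ids
    batch = ds[0:3]      # a contiguous slice: a list of such arrays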
@@ -114,10 +115,9 @@ class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""
    _HDR_MAGIC = b'TNTIDX\x00\x00'

    def __init__(self, path, fix_lua_indexing=False):
    def __init__(self, path):
        super().__init__()
        self.path = path
        self.fix_lua_indexing = fix_lua_indexing
        self.data_file = None
        self.read_index(path)

@@ -150,19 +150,30 @@ class IndexedDataset(torch.utils.data.Dataset):
         if self.data_file:
             self.data_file.close()

-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
         if not self.data_file:
             self.read_data(self.path)
-        self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
-        a = np.empty(tensor_size, dtype=self.dtype)
-        self.data_file.seek(self.data_offsets[i] * self.element_size)
-        self.data_file.readinto(a)
-        item = torch.from_numpy(a).long()
-        if self.fix_lua_indexing:
-            item -= 1  # subtract 1 for 0-based indexing
-        return item
+        if isinstance(idx, int):
+            i = idx
+            self.check_index(i)
+            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+            a = np.empty(tensor_size, dtype=self.dtype)
+            self.data_file.seek(self.data_offsets[i] * self.element_size)
+            self.data_file.readinto(a)
+            return a
+        elif isinstance(idx, slice):
+            start, stop, step = idx.indices(len(self))
+            if step != 1:
+                raise ValueError("Slices into indexed_dataset must be contiguous")
+            sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
+            size = sum(sizes)
+            a = np.empty(size, dtype=self.dtype)
+            self.data_file.seek(self.data_offsets[start] * self.element_size)
+            self.data_file.readinto(a)
+            offsets = list(accumulate(sizes))
+            sents = np.split(a, offsets[:-1])
+            return sents

     def __len__(self):
         return self._len
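
The slice branch reads every requested sentence with a single seek/readinto and then cuts the flat buffer at the running size offsets. A self-contained sketch of that split step, with made-up sentence sizes:

    import numpy as np
    from itertools import accumulate

    sizes = [3, 5, 2]                     # made-up per-sentence token counts
    flat = np.arange(sum(sizes))          # stands in for the buffer readinto filled
    offsets = list(accumulate(sizes))     # [3, 8, 10]
    sents = np.split(flat, offsets[:-1])  # arrays of lengths 3, 5, 2 again
    assert [len(s) for s in sents] == sizes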
@@ -186,8 +197,8 @@ class IndexedDataset(torch.utils.data.Dataset):

 class IndexedCachedDataset(IndexedDataset):

-    def __init__(self, path, fix_lua_indexing=False):
-        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
+    def __init__(self, path):
+        super().__init__(path)
         self.cache = None
         self.cache_index = {}

@@ -219,17 +230,22 @@ class IndexedCachedDataset(IndexedDataset):
             self.data_file.close()
             self.data_file = None

-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
-        self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
-        a = np.empty(tensor_size, dtype=self.dtype)
-        ptx = self.cache_index[i]
-        np.copyto(a, self.cache[ptx: ptx + a.size])
-        item = torch.from_numpy(a).long()
-        if self.fix_lua_indexing:
-            item -= 1  # subtract 1 for 0-based indexing
-        return item
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            i = idx
+            self.check_index(i)
+            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+            a = np.empty(tensor_size, dtype=self.dtype)
+            ptx = self.cache_index[i]
+            np.copyto(a, self.cache[ptx: ptx + a.size])
+            return a
+        elif isinstance(idx, slice):
+            # Hack just to make this work; can optimize later if necessary
+            sents = []
+            for i in range(*idx.indices(len(self))):
+                sents.append(self[i])
+            return sents


 class IndexedDatasetBuilder(object):
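
As its comment admits, the cached variant takes the simple route: a slice decays into repeated integer lookups, which is cheap once the cache is warm. A hedged sketch (construction of ds is assumed):

    # Assumes ds is an IndexedCachedDataset over some existing prefix.
    ds.prefetch(range(len(ds)))  # warm the cache so ds[i] never hits the file
    sents = ds[10:20]            # equivalent to [ds[i] for i in range(10, 20)]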
@@ -434,14 +450,26 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self._index)

-    @lru_cache(maxsize=8)
-    def __getitem__(self, i):
-        ptr, size = self._index[i]
-        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
-        if self._index.dtype != np.int64:
-            np_array = np_array.astype(np.int64)
-
-        return torch.from_numpy(np_array)
+    #@lru_cache(maxsize=8)
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            ptr, size = self._index[idx]
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
+            if self._index.dtype != np.int64:
+                np_array = np_array.astype(np.int64)
+
+            return torch.from_numpy(np_array)
+        elif isinstance(idx, slice):
+            start, stop, step = idx.indices(len(self))
+            if step != 1:
+                raise ValueError("Slices into indexed_dataset must be contiguous")
+            ptr = self._index._pointers[start]
+            sizes = self._index._sizes[idx]
+            offsets = list(accumulate(sizes))
+            total_size = sum(sizes)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
+            sents = np.split(np_array, offsets[:-1])
+            return sents

     @property
     def sizes(self):
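
The mmap path applies the same trick without a staging buffer: one frombuffer over the whole byte span starting at the first sentence's pointer, then one split at the cumulative sizes. A standalone sketch of that pattern (the bytes object stands in for the memory map; the names are illustrative, not the class's internals):

    import numpy as np
    from itertools import accumulate

    dtype = np.int32
    sizes = [4, 2, 3]                                   # illustrative lengths
    buf = np.arange(sum(sizes), dtype=dtype).tobytes()  # stand-in for the mmap
    ptr = 0                                             # byte offset of sentence 0
    flat = np.frombuffer(buf, dtype=dtype, count=sum(sizes), offset=ptr)
    sents = np.split(flat, list(accumulate(sizes))[:-1])
    assert [len(s) for s in sents] == sizes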
+2 −1
@@ -44,6 +44,7 @@ class Encoder(object):
         for sentence in Encoder.splitter.tokenize(text):
             tokens = Encoder.tokenizer.tokenize(sentence)
             ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
-            doc_ids.append(ids)
+            if len(ids) > 0:
+                doc_ids.append(ids)
         return doc_ids, len(json_line)
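
The new guard keeps zero-length sentences out of the binarized output; without it, a sentence that tokenizes to nothing would land as an empty id list and later break consumers that expect every sentence to be non-empty (see the assert in the test below). A sketch of the same pattern with a stand-in tokenizer:

    def encode_doc(sentences, to_ids):
        # to_ids stands in for the tokenize + convert_tokens_to_ids pipeline.
        doc_ids = []
        for sentence in sentences:
            ids = to_ids(sentence)
            if len(ids) > 0:      # drop sentences that tokenize to nothing
                doc_ids.append(ids)
        return doc_ids

    assert encode_doc(['hi', ''], lambda s: list(s.encode())) == [[104, 105]]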

+10 −8
@@ -18,16 +18,18 @@ def test_indexed_dataset(args):
     if ds.supports_prefetch:
         # just prefetch the whole thing in test (so assume it is small)
         ds.prefetch(range(len(ds)))
-    for i in range(1):
+    for i in range(len(ds.doc_idx)-1):
         start = ds.doc_idx[i]
         end = ds.doc_idx[i+1]
-        print(start, end)
-        for j in range(start, end):
-            ids = ds[j].data.tolist()
-            print(ids)
-            tokens = tokenizer.convert_ids_to_tokens(ids)
-            print(tokens)
-        print("******** END DOCUMENT **********")
+        ids = ds[start:end]
+        for s in ids:
+            assert len(s) > 0
+            l = s.data.tolist()
+            tokens = tokenizer.convert_ids_to_tokens(l)
+            for t in tokens:
+                if '\n' in t:
+                    print("Newline in string!")
+        print(i)

 def main():
     parser = argparse.ArgumentParser()
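
The rewritten test walks every document instead of only the first: doc_idx[i] and doc_idx[i+1] bound document i's sentences, so ds[start:end] pulls one whole document per read. The shape of that index, with made-up values:

    doc_idx = [0, 3, 7, 9]                 # made up: 3 documents over 9 sentences
    for i in range(len(doc_idx) - 1):
        start, end = doc_idx[i], doc_idx[i + 1]
        print(i, list(range(start, end)))  # sentence indices of document i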