Commit 87bbe9be authored by Jared Casper

Add document index to index file. An empty sentence no longer separates documents.

parent 1237533e
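In outline: rather than separating documents with an empty sentence, the builders now record an explicit doc_idx array in the index file, where doc_idx[d] is the flat sentence index of the first sentence of document d and a trailing entry marks the end. A minimal sketch (not part of the commit) of iterating documents through the new accessor, using the names introduced in the diff below:

    ds = indexed_dataset.make_dataset(prefix, impl)
    for d in range(len(ds.doc_idx) - 1):
        # document d occupies the half-open sentence range
        # [doc_idx[d], doc_idx[d+1])
        for j in range(ds.doc_idx[d], ds.doc_idx[d + 1]):
            token_ids = ds[j]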
+3 −0
@@ -180,6 +180,9 @@ class FullTokenizer(object):
  def convert_ids_to_tokens(self, ids):
    return convert_by_vocab(self.inv_vocab, ids)

  def vocab_size(self):
    return len(self.vocab)


class BasicTokenizer(object):
  """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+64 −5
@@ -7,6 +7,8 @@
# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
#    An empty sentence no longer separates documents.

from functools import lru_cache
import os
@@ -101,6 +103,12 @@ def index_file_path(prefix_path):
def data_file_path(prefix_path):
    return prefix_path + '.bin'

def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i+1)
    return doc_idx
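A quick worked example, assuming the legacy on-disk convention in which each document ends with an empty (size-0) sentence:

    # create_doc_idx([3, 5, 0, 2, 0]) -> [0, 3, 5]
    # document d spans sentence indices [doc_idx[d], doc_idx[d+1]), so
    # doc 0 is sentences 0-2 and doc 1 is sentences 3-4, each span
    # including its trailing empty separator sentence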

class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""
@@ -125,9 +133,11 @@ class IndexedDataset(torch.utils.data.Dataset):
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code]
            self._len, self.s = struct.unpack('<QQ', f.read(16))
            self.doc_count = struct.unpack('<Q', f.read(8))[0]
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb', buffering=0)
@@ -240,15 +250,18 @@ class IndexedDatasetBuilder(object):
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]

    def add_item(self, tensor):
        # +1 for Lua compatibility
        bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
        bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def end_document(self):
        self.doc_idx.append(len(self.sizes))

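Sketch of the intended builder protocol, mirroring the preprocess_data.py changes further down (one add_item per sentence, one end_document per document, then finalize); bin_file, idx_file, and documents are hypothetical names here:

    builder = indexed_dataset.make_builder(bin_file, impl=impl,
                                           vocab_size=vocab_size)
    for doc in documents:            # doc: list of token-id lists
        for sentence in doc:
            builder.add_item(torch.IntTensor(sentence))
        builder.end_document()       # records len(self.sizes) in doc_idx
    builder.finalize(idx_file)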
    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype
@@ -276,9 +289,11 @@ class IndexedDatasetBuilder(object):
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack('<Q', len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()


@@ -316,10 +331,11 @@ class MMapIndexedDataset(torch.utils.data.Dataset):

                    return pointers

                def write(self, sizes):
                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack('<Q', len(sizes)))
                    self._file.write(struct.pack('<Q', len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order='C'))
@@ -329,6 +345,9 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                    self._file.write(pointers.tobytes(order='C'))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order='C'))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

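For reference, the on-disk layout of the mmap index file implied by write() above and the reader below (a sketch; the header fields before len are untouched by this commit):

    header: magic / version / dtype code   (unchanged)
    <Q>  len            number of sizes/pointers entries
    <Q>  doc_count      number of doc_idx entries          (new)
    int32[len]          sizes
    int64[len]          pointers
    int64[doc_count]    doc_idx                            (new)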
@@ -349,6 +368,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                offset = stream.tell()

            _warmup_mmap_file(path)
@@ -358,7 +378,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)

            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)
        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap
@@ -371,6 +392,10 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]
@@ -422,6 +447,10 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    @property
    def supports_prefetch(self):
        return False
@@ -438,12 +467,16 @@ class MMapIndexedDatasetBuilder(object):
        self._data_file = open(out_file, 'wb')
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, tensor):
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order='C'))
        self._sizes.append(np_array.size)

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
@@ -460,4 +493,30 @@ class MMapIndexedDatasetBuilder(object):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes)
            index.write(self._sizes, self._doc_idx)

class indexed_doc_dataset(torch.utils.data.Dataset):
    def __init__(self, path):
        impl = infer_dataset_impl(path)
        self.ds = make_dataset(path, impl)
        self._docs = []
        doc_idxs = []
        for i, s in enumerate(self.ds.sizes):
            if s > 0:
                doc_idxs.append(i)
            else:
                self._docs.append(doc_idxs)
                doc_idxs = []
        if doc_idxs:
            # data without a trailing empty sentence: keep the final document
            self._docs.append(doc_idxs)

    def __getitem__(self, i):
        if not isinstance(i, tuple):
            raise ValueError("Index into indexed_doc_dataset must be a tuple")
        idx = self._docs[i[0]][i[1]]
        return self.ds[idx]

    def __len__(self):
        """Returns number of documents, not number of sentences"""
        return len(self._docs)

    def doc_len(self, d):
        return len(self._docs[d])
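A brief usage sketch for indexed_doc_dataset; since it rebuilds boundaries from empty sentences, it targets data written in the old format rather than the new doc_idx:

    ds = indexed_doc_dataset(prefix)
    for d in range(len(ds)):             # iterates documents
        for j in range(ds.doc_len(d)):   # sentences within document d
            sentence = ds[(d, j)]        # indexed by a (document, sentence) tuple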
+10 −22
@@ -45,7 +45,6 @@ class Encoder(object):
            tokens = Encoder.tokenizer.tokenize(sentence)
            ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
            doc_ids.append(ids)
        doc_ids.append([])
        return doc_ids, len(json_line)

def main():
@@ -71,8 +70,6 @@ def main():
    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

    vocab_size = 1

    nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
@@ -80,11 +77,13 @@ def main():
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, 25)

    print(f"Vocab size: {tokenizer.vocab_size()}")

    output_bin_file = "{}.bin".format(args.output_prefix)
    output_idx_file = "{}.idx".format(args.output_prefix)
    ds = indexed_dataset.make_builder(output_bin_file,
    builder = indexed_dataset.make_builder(output_bin_file,
                                      impl=args.dataset_impl,
                                      vocab_size=vocab_size)
                                      vocab_size=tokenizer.vocab_size())

    startup_end = time.time()
    proc_start = time.time()
@@ -93,30 +92,19 @@ def main():
    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for sentence in doc:
            print(sentence)
            print(tokenizer.convert_ids_to_tokens(sentence))
            ds.add_item(torch.IntTensor(sentence))
            #print(sentence)
            #print(tokenizer.convert_ids_to_tokens(sentence))
            builder.add_item(torch.IntTensor(sentence))
        builder.end_document()
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed/elapsed/1024/1024
            print(f'Processed {i} documents',
            print(f"Processed {i} documents",
                  f"({i/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    ds.finalize(output_idx_file)
    builder.finalize(output_idx_file)

if __name__ == '__main__':
    main()
    # print('processing data ...')

    # input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json'
    # vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt'

    # tokenizer = FullTokenizer(vocab_file, do_lower_case=True)
    # document_generator = document_generator_provider(input_file)
    # for sentences in document_generator:
    #     for sentence in sentences:
    #         tokens = tokenizer.tokenize(sentence)
    #         print(sentence)
    #         print(tokens)
+22 −7
@@ -10,21 +10,36 @@ sys.path.append(os.path.join(script_dir, "../../../"))
from megatron.data import indexed_dataset, FullBertTokenizer

def test_indexed_dataset(args):
    ds_impl = indexed_dataset.infer_dataset_impl(args.data)
    ds = indexed_dataset.make_dataset(args.data, ds_impl)
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
    for sample in ds:
        print(sample)
        print(sample.data.tolist())
        print(tokenizer.convert_ids_to_tokens(sample.data.tolist()))
        print("---")
    print(len(ds.doc_idx))
    print(len(ds))
    print(ds.doc_idx[-1])
    if ds.supports_prefetch:
        # just prefetch the whole thing in test (so assume it is small)
        ds.prefetch(range(len(ds)))
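    # spot-check document reconstruction on the first document only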
    for i in range(1):
        start = ds.doc_idx[i]
        end = ds.doc_idx[i+1]
        print(start, end)
        for j in range(start, end):
            ids = ds[j].data.tolist()
            print(ids)
            tokens = tokenizer.convert_ids_to_tokens(ids)
            print(tokens)
        print("******** END DOCUMENT **********")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='prefix to data files')
    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
    parser.add_argument('--dataset-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'])
    args = parser.parse_args()

    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

    test_indexed_dataset(args)

if __name__ == "__main__":
    main()
+3 −1
#!/bin/bash

IMPL=cached
python ../preprocess_data.py \
       --input test_samples.json \
       --vocab vocab.txt \
       --output-prefix test_samples \
       --dataset-impl ${IMPL} \
       --output-prefix test_samples_${IMPL} \
       --workers 1 \
       --log-interval 2