Commit da0562fc authored by Jared Casper's avatar Jared Casper
Browse files

Updates to preprocess_data.py and indexed_dataset.

preprocess_data:
- Adds ability to not split sentences. This is used for gpt2 datasets.

- Adds ability to create multiple datasets from different json keys,
this is current untested.

indexed_dataset:
- Add new "get" function to get a portion of an entry.
parent f66c58a9
Loading
Loading
Loading
Loading
+17 −2
Original line number Diff line number Diff line
@@ -42,6 +42,7 @@ def infer_dataset_impl(path):
            else:
                return None
    else:
        print(f"Dataset path does not exist: {path}")
        return None


@@ -61,6 +62,7 @@ def make_dataset(path, impl, skip_warmup=False):
        return IndexedCachedDataset(path)
    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)
    print(f"Unknown dataset implementation: {impl}")
    return None


@@ -466,7 +468,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=size, offset=ptr)
            if self._index.dtype != np.int64:
                np_array = np_array.astype(np.int64)
            return np_array
@@ -478,10 +481,22 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=total_size, offset=ptr)
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                 count=length, offset=ptr)
        if self._index.dtype != np.int64:
            np_array = np_array.astype(np.int64)
        return np_array

    @property
    def sizes(self):
        return self._index.sizes
+67 −26
Original line number Diff line number Diff line
@@ -7,52 +7,90 @@ import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))

from megatron.data import indexed_dataset, FullBertTokenizer, AlbertDataset
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset

def test_indexed_dataset(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
    tokenizer = build_tokenizer(args)
    print(len(ds.doc_idx))
    print(len(ds))
    print(ds.doc_idx[-1])
    if ds.supports_prefetch:
        # just prefetch the whole thing in test (so assume it is small)
        ds.prefetch(range(len(ds)))
    for i in range(len(ds.doc_idx)-1):
    if args.count > len(ds.doc_idx)-1:
        args.count = len(ds.doc_idx)-1

    for i in range(args.count):
        start = ds.doc_idx[i]
        end = ds.doc_idx[i+1]
        ids = ds[start:end]
        print(f"Document {i}:")
        print("--------------")
        for s in ids:
            assert len(s) > 0
            l = s.data.tolist()
            tokens = tokenizer.convert_ids_to_tokens(l)
            for t in tokens:
                if '\n' in t:
                    print("Newline in string!")
        print(i)
            text = tokenizer.detokenize(l)
            print(text)
            print("---")

def test_indexed_dataset_get(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = build_tokenizer(args)
    size = ds.sizes[0]
    print(f"size: {size}")
    full = ds.get(0)
    print(full)
    #print(tokenizer.detokenize(full.data.tolist()))
    print("---")
    end = ds.get(0, offset=size-10)
    print(end)
    #print(tokenizer.detokenize(end.data.tolist()))

    start = ds.get(0, length=10)
    print(start)
    #print(tokenizer.detokenize(start.data.tolist()))

def test_albert_dataset(args):
    # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
    # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    # ds = AlbertDataset(idataset, tokenizer)
    ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
                                  args.epochs, args.max_num_samples,
                                  args.masked_lm_prob, args.seq_length,
                                  args.short_seq_prob, args.seed)
    truncated = 0
    total = 0
    for s in ds:
        ids = s['text']
        tokens = ds.tokenizer.convert_ids_to_tokens(ids)
        print(tokens)
        exit()
    part = ds.get(0, offset=2, length=8)
    print(part)
    #print(tokenizer.detokenize(part.data.tolist()))

# def test_albert_dataset(args):
#     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
#     # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
#     # ds = AlbertDataset(idataset, tokenizer)
#     ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
#                                   args.epochs, args.max_num_samples,
#                                   args.masked_lm_prob, args.seq_length,
#                                   args.short_seq_prob, args.seed)
#     truncated = 0
#     total = 0
#     for i, s in enumerate(ds):
#         ids = s['text']
#         tokens = ds.tokenizer.convert_ids_to_tokens(ids)
#         print(tokens)
#         if i >= args.count-1:
#             exit()

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='prefix to data files')
    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
    parser.add_argument('--dataset-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'])
    parser.add_argument('--count', type=int, default=10,
                        help='Number of samples/documents to print')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')

    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of epochs to plan for')
    parser.add_argument('--max-num-samples', type=int, default=None,
@@ -66,12 +104,15 @@ def main():
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed')
    args = parser.parse_args()
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.model_parallel_size = 1

    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

    test_albert_dataset(args)
#    test_indexed_dataset(args)
#    test_albert_dataset(args)
    test_indexed_dataset_get(args)

if __name__ == "__main__":
    main()
+175 −0
Original line number Diff line number Diff line
import argparse
import json
import multiprocessing
import sys
import time



import torch
try:
    import nltk
    nltk_available = True
except ImportError:
    nltk_available = False

from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset

# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):

    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                       #  <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)     #  <-- Normally you would have \s+ here
        ))"""

class IdentitySplitter(object):
    def tokenize(self, *text):
        return text

class Encoder(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data
        Encoder.tokenizer = build_tokenizer(self.args)
        if self.args.split_sentences:
            if not nltk_available:
                print("NLTK is not available to split sentences.")
                exit()
            splitter = nltk.load("tokenizers/punkt/english.pickle")
            if self.args.keep_newlines:
                # this prevents punkt from eating newlines after sentences
                Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                    train_text = splitter._params,
                    lang_vars = CustomLanguageVars())
            else:
                Encoder.splitter = splitter

        else:
            Encoder.splitter = IdentitySplitter()

    def encode(self, json_line):
        data = json.loads(json_line)
        ids = {}
        for key in self.args.json_keys:
            text = data[key]
            doc_ids = []
            for sentence in Encoder.splitter.tokenize(text):
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.append(sentence_ids)
            if self.args.append_eod:
                doc_ids[-1].append(Encoder.tokenizer.eod)
            ids[key] = doc_ids
        return ids, len(json_line)

def get_args():
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to input JSON')
    group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='space separate listed of keys to extract from json')
    group.add_argument('--split-sentences', action='store_true',
                       help='Split documents into sentences.')
    group.add_argument('--keep-newlines', action='store_true',
                       help='Keep newlines between sentences when splitting.')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group.add_argument('--append-eod', action='store_true',
                       help='Append an <eod> token to the end of a document.')


    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')
    group.add_argument('--dataset-impl', type=str, default='mmap',
                       choices=['lazy', 'cached', 'mmap'])

    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, default=1,
                       help='Number of worker processes to launch')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Interval between progress updates')
    args = parser.parse_args()
    args.keep_empty = False

    if args.tokenizer_type.lower().startswith('bert'):
        if not args.split_sentences:
            print("Bert tokenizer detected, are you sure you don't want to split sentences?")

    # some default/dummy values for the tokenizer
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.model_parallel_size = 1

    return args

def main():
    args = get_args()
    startup_start = time.time()

    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

    if nltk_available and args.split_sentences:
        nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    tokenizer = build_tokenizer(args)
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, 25)
    #encoded_docs = map(encoder.encode, fin)

    print(f"Vocab size: {tokenizer.vocab_size}")
    print(f"Output prefix: {args.output_prefix}")
    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    for key in args.json_keys:
        output_bin_files[key] = "{}_{}.bin".format(args.output_prefix, key)
        output_idx_files[key] = "{}_{}.idx".format(args.output_prefix, key)
        builders[key] = indexed_dataset.make_builder(output_bin_files[key],
                                               impl=args.dataset_impl,
                                               vocab_size=tokenizer.vocab_size)

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)

    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for key, sentences in doc.items():
            for sentence in sentences:
                builders[key].add_item(torch.IntTensor(sentence))
            builders[key].end_document()
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed/elapsed/1024/1024
            print(f"Processed {i} documents",
                  f"({i/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    for key in args.json_keys:
        builders[key].finalize(output_idx_files[key])

if __name__ == '__main__':
    main()