Loading megatron/data/__init__.py 0 → 100644 +2 −0 Original line number Diff line number Diff line from . import indexed_dataset from .bert_tokenization import FullTokenizer as FullBertTokenizer megatron/data/indexed_dataset.py 0 → 100644 +463 −0 Original line number Diff line number Diff line # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # copied from fairseq/fairseq/data/indexed_dataset.py # Removed IndexedRawTextDataset since it relied on Fairseq dictionary # other slight modifications to remove fairseq dependencies from functools import lru_cache import os import shutil import struct import numpy as np import torch def __best_fitting_dtype(vocab_size=None): if vocab_size is not None and vocab_size < 65500: return np.uint16 else: return np.int32 def get_available_dataset_impl(): return ['lazy', 'cached', 'mmap'] def infer_dataset_impl(path): if IndexedDataset.exists(path): with open(index_file_path(path), 'rb') as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: return 'cached' elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: return 'mmap' else: return None else: return None def make_builder(out_file, impl, vocab_size=None): if impl == 'mmap': return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) else: return IndexedDatasetBuilder(out_file) def make_dataset(path, impl, fix_lua_indexing=False): if impl == 'lazy' and IndexedDataset.exists(path): return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing) elif impl == 'cached' and IndexedDataset.exists(path): return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing) elif impl == 'mmap' and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path) return None def dataset_exists(path, impl): if impl == 'mmap': return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) def read_longs(f, n): a = np.empty(n, dtype=np.int64) f.readinto(a) return a def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) dtypes = { 1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16 } def code(dtype): for k in dtypes.keys(): if dtypes[k] == dtype: return k raise ValueError(dtype) def index_file_path(prefix_path): return prefix_path + '.idx' def data_file_path(prefix_path): return prefix_path + '.bin' class IndexedDataset(torch.utils.data.Dataset): """Loader for IndexedDataset""" _HDR_MAGIC = b'TNTIDX\x00\x00' def __init__(self, path, fix_lua_indexing=False): super().__init__() self.path = path self.fix_lua_indexing = fix_lua_indexing self.data_file = None self.read_index(path) def read_index(self, path): with open(index_file_path(path), 'rb') as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( 'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.' ) version = f.read(8) assert struct.unpack('<Q', version) == (1,) code, self.element_size = struct.unpack('<QQ', f.read(16)) self.dtype = dtypes[code] self._len, self.s = struct.unpack('<QQ', f.read(16)) self.dim_offsets = read_longs(f, self._len + 1) self.data_offsets = read_longs(f, self._len + 1) self.sizes = read_longs(f, self.s) def read_data(self, path): self.data_file = open(data_file_path(path), 'rb', buffering=0) def check_index(self, i): if i < 0 or i >= self._len: raise IndexError('index out of range') def __del__(self): if self.data_file: self.data_file.close() @lru_cache(maxsize=8) def __getitem__(self, i): if not self.data_file: self.read_data(self.path) self.check_index(i) tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) item = torch.from_numpy(a).long() if self.fix_lua_indexing: item -= 1 # subtract 1 for 0-based indexing return item def __len__(self): return self._len def num_tokens(self, index): return self.sizes[index] def size(self, index): return self.sizes[index] @staticmethod def exists(path): return ( os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) ) @property def supports_prefetch(self): return False # avoid prefetching to save memory class IndexedCachedDataset(IndexedDataset): def __init__(self, path, fix_lua_indexing=False): super().__init__(path, fix_lua_indexing=fix_lua_indexing) self.cache = None self.cache_index = {} @property def supports_prefetch(self): return True def prefetch(self, indices): if all(i in self.cache_index for i in indices): return if not self.data_file: self.read_data(self.path) indices = sorted(set(indices)) total_size = 0 for i in indices: total_size += self.data_offsets[i + 1] - self.data_offsets[i] self.cache = np.empty(total_size, dtype=self.dtype) ptx = 0 self.cache_index.clear() for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] a = self.cache[ptx: ptx + size] self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size if self.data_file: # close and delete data file after prefetch so we can pickle self.data_file.close() self.data_file = None @lru_cache(maxsize=8) def __getitem__(self, i): self.check_index(i) tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] np.copyto(a, self.cache[ptx: ptx + a.size]) item = torch.from_numpy(a).long() if self.fix_lua_indexing: item -= 1 # subtract 1 for 0-based indexing return item class IndexedDatasetBuilder(object): element_sizes = { np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float: 4, np.double: 8 } def __init__(self, out_file, dtype=np.int32): self.out_file = open(out_file, 'wb') self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] self.sizes = [] self.element_size = self.element_sizes[self.dtype] def add_item(self, tensor): # +1 for Lua compatibility bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype)) self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) for s in tensor.size(): self.sizes.append(s) self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) def merge_file_(self, another_file): index = IndexedDataset(another_file) assert index.dtype == self.dtype begin = self.data_offsets[-1] for offset in index.data_offsets[1:]: self.data_offsets.append(begin + offset) self.sizes.extend(index.sizes) begin = self.dim_offsets[-1] for dim_offset in index.dim_offsets[1:]: self.dim_offsets.append(begin + dim_offset) with open(data_file_path(another_file), 'rb') as f: while True: data = f.read(1024) if data: self.out_file.write(data) else: break def finalize(self, index_file): self.out_file.close() index = open(index_file, 'wb') index.write(b'TNTIDX\x00\x00') index.write(struct.pack('<Q', 1)) index.write(struct.pack('<QQ', code(self.dtype), self.element_size)) index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes))) write_longs(index, self.dim_offsets) write_longs(index, self.data_offsets) write_longs(index, self.sizes) index.close() def _warmup_mmap_file(path): with open(path, 'rb') as stream: while stream.read(100 * 1024 * 1024): pass class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b'MMIDIDX\x00\x00' @classmethod def writer(cls, path, dtype): class _Writer(object): def __enter__(self): self._file = open(path, 'wb') self._file.write(cls._HDR_MAGIC) self._file.write(struct.pack('<Q', 1)) self._file.write(struct.pack('<B', code(dtype))) return self @staticmethod def _get_pointers(sizes): dtype_size = dtype().itemsize address = 0 pointers = [] for size in sizes: pointers.append(address) address += size * dtype_size return pointers def write(self, sizes): pointers = self._get_pointers(sizes) self._file.write(struct.pack('<Q', len(sizes))) sizes = np.array(sizes, dtype=np.int32) self._file.write(sizes.tobytes(order='C')) del sizes pointers = np.array(pointers, dtype=np.int64) self._file.write(pointers.tobytes(order='C')) del pointers def __exit__(self, exc_type, exc_val, exc_tb): self._file.close() return _Writer() def __init__(self, path): with open(path, 'rb') as stream: magic_test = stream.read(9) assert self._HDR_MAGIC == magic_test, ( 'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.' ) version = struct.unpack('<Q', stream.read(8)) assert (1,) == version dtype_code, = struct.unpack('<B', stream.read(1)) self._dtype = dtypes[dtype_code] self._dtype_size = self._dtype().itemsize self._len = struct.unpack('<Q', stream.read(8))[0] offset = stream.tell() _warmup_mmap_file(path) self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') self._bin_buffer = memoryview(self._bin_buffer_mmap) self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset) self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes) def __del__(self): self._bin_buffer_mmap._mmap.close() del self._bin_buffer_mmap @property def dtype(self): return self._dtype @property def sizes(self): return self._sizes @lru_cache(maxsize=8) def __getitem__(self, i): return self._pointers[i], self._sizes[i] def __len__(self): return self._len def __init__(self, path): super().__init__() self._path = None self._index = None self._bin_buffer = None self._do_init(path) def __getstate__(self): return self._path def __setstate__(self, state): self._do_init(state) def _do_init(self, path): self._path = path self._index = self.Index(index_file_path(self._path)) _warmup_mmap_file(data_file_path(self._path)) self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C') self._bin_buffer = memoryview(self._bin_buffer_mmap) def __del__(self): self._bin_buffer_mmap._mmap.close() del self._bin_buffer_mmap del self._index def __len__(self): return len(self._index) @lru_cache(maxsize=8) def __getitem__(self, i): ptr, size = self._index[i] np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr) if self._index.dtype != np.int64: np_array = np_array.astype(np.int64) return torch.from_numpy(np_array) @property def sizes(self): return self._index.sizes @property def supports_prefetch(self): return False @staticmethod def exists(path): return ( os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) ) class MMapIndexedDatasetBuilder(object): def __init__(self, out_file, dtype=np.int64): self._data_file = open(out_file, 'wb') self._dtype = dtype self._sizes = [] def add_item(self, tensor): np_array = np.array(tensor.numpy(), dtype=self._dtype) self._data_file.write(np_array.tobytes(order='C')) self._sizes.append(np_array.size) def merge_file_(self, another_file): # Concatenate index index = MMapIndexedDataset.Index(index_file_path(another_file)) assert index.dtype == self._dtype for size in index.sizes: self._sizes.append(size) # Concatenate data with open(data_file_path(another_file), 'rb') as f: shutil.copyfileobj(f, self._data_file) def finalize(self, index_file): self._data_file.close() with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: index.write(self._sizes) megatron/data/preprocess_data.py +106 −22 Original line number Diff line number Diff line import argparse import json import multiprocessing import nltk nltk.download('punkt') import sys import time import torch from bert_tokenization import FullTokenizer import indexed_dataset class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): def document_generator_provider(input_file): with open(input_file, 'r') as ifile: for document in ifile: data = json.loads(document) text = data['text'] sentences = [] for line in text.split('\n'): if line != '\n': sentences.extend(nltk.tokenize.sent_tokenize(line)) yield sentences _period_context_fmt = r""" \S* # some word material %(SentEndChars)s # a potential sentence ending \s* # <-- THIS is what I changed (?=(?P<after_tok> %(NonWord)s # either other punctuation | (?P<next_tok>\S+) # <-- Normally you would have \s+ here ))""" class Encoder(object): def __init__(self, args): self.args = args if __name__ == '__main__': def initializer(self): # Use Encoder class as a container for global data Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True) spliter = nltk.load("tokenizers/punkt/english.pickle") if self.args.keep_newlines: # this prevents punkt from eating newlines after sentences Encoder.spliter = nltk.tokenize.punkt.PunktSentenceTokenizer( train_text = spliter._params, lang_vars = CustomLanguageVars()) else: Encoder.splitter = spliter def encode(self, json_line): text = json.loads(json_line)[self.args.json_key] doc_ids = [] for sentence in Encoder.splitter.tokenize(text): tokens = Encoder.tokenizer.tokenize(sentence) ids = Encoder.tokenizer.convert_tokens_to_ids(tokens) doc_ids.append(ids) doc_ids.append([]) return doc_ids, len(json_line) def main(): parser = argparse.ArgumentParser() parser.add_argument('--input', type=str, help='Path to input JSON') parser.add_argument('--vocab', type=str, help='Path to vocab.txt') parser.add_argument('--json-key', type=str, default='text', help='Key to extract from json') parser.add_argument('--output-prefix', type=str, help='Path to binary output file without suffix') parser.add_argument('--workers', type=int, default=20, help='Number of worker processes to launch') parser.add_argument('--log-interval', type=int, default=100, help='Interval between progress updates') parser.add_argument('--keep-newlines', action='store_true', help='Keep newlines between sentences.') parser.add_argument('--dataset-impl', type=str, default='mmap', choices=['lazy', 'cached', 'mmap']) args = parser.parse_args() args.keep_empty = False print('processing data ...') startup_start = time.time() input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json' vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt' print("Opening", args.input) fin = open(args.input, 'r', encoding='utf-8') tokenizer = FullTokenizer(vocab_file, do_lower_case=True) document_generator = document_generator_provider(input_file) for sentences in document_generator: for sentence in sentences: tokens = tokenizer.tokenize(sentence) vocab_size = 1 nltk.download("punkt", quiet=True) encoder = Encoder(args) tokenizer = FullTokenizer(args.vocab, do_lower_case=True) pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) encoded_docs = pool.imap(encoder.encode, fin, 25) output_bin_file = "{}.bin".format(args.output_prefix) output_idx_file = "{}.idx".format(args.output_prefix) ds = indexed_dataset.make_builder(output_bin_file, impl=args.dataset_impl, vocab_size=vocab_size) startup_end = time.time() proc_start = time.time() total_bytes_processed = 0 print("Time to startup:", startup_end - startup_start) for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): total_bytes_processed += bytes_processed for sentence in doc: print(sentence) print(tokens) print(tokenizer.convert_ids_to_tokens(sentence)) ds.add_item(torch.IntTensor(sentence)) if i % args.log_interval == 0: current = time.time() elapsed = current - proc_start mbs = total_bytes_processed/elapsed/1024/1024 print(f'Processed {i} documents', f"({i/elapsed} docs/s, {mbs} MB/s).", file=sys.stderr) ds.finalize(output_idx_file) if __name__ == '__main__': main() # print('processing data ...') # input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json' # vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt' # tokenizer = FullTokenizer(vocab_file, do_lower_case=True) # document_generator = document_generator_provider(input_file) # for sentences in document_generator: # for sentence in sentences: # tokens = tokenizer.tokenize(sentence) # print(sentence) # print(tokens) megatron/data/test/test_indexed_dataset.py 0 → 100644 +31 −0 Original line number Diff line number Diff line import argparse import os import sys import torch script_dir = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.join(script_dir, "../../../")) from megatron.data import indexed_dataset, FullBertTokenizer def test_indexed_dataset(args): ds_impl = indexed_dataset.infer_dataset_impl(args.data) ds = indexed_dataset.make_dataset(args.data, ds_impl) tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) for sample in ds: print(sample) print(sample.data.tolist()) print(tokenizer.convert_ids_to_tokens(sample.data.tolist())) print("---") def main(): parser = argparse.ArgumentParser() parser.add_argument('--data', type=str, help='prefix to data files') parser.add_argument('--vocab', type=str, help='Path to vocab.txt') args = parser.parse_args() test_indexed_dataset(args) if __name__ == "__main__": main() megatron/data/test/test_preprocess_data.sh 0 → 100755 +8 −0 Original line number Diff line number Diff line #!/bin/bash python ../preprocess_data.py \ --input test_samples.json \ --vocab vocab.txt \ --output-prefix test_samples \ --workers 1 \ --log-interval 2 Loading
megatron/data/__init__.py 0 → 100644 +2 −0 Original line number Diff line number Diff line from . import indexed_dataset from .bert_tokenization import FullTokenizer as FullBertTokenizer
megatron/data/indexed_dataset.py 0 → 100644 +463 −0 Original line number Diff line number Diff line # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # copied from fairseq/fairseq/data/indexed_dataset.py # Removed IndexedRawTextDataset since it relied on Fairseq dictionary # other slight modifications to remove fairseq dependencies from functools import lru_cache import os import shutil import struct import numpy as np import torch def __best_fitting_dtype(vocab_size=None): if vocab_size is not None and vocab_size < 65500: return np.uint16 else: return np.int32 def get_available_dataset_impl(): return ['lazy', 'cached', 'mmap'] def infer_dataset_impl(path): if IndexedDataset.exists(path): with open(index_file_path(path), 'rb') as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: return 'cached' elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: return 'mmap' else: return None else: return None def make_builder(out_file, impl, vocab_size=None): if impl == 'mmap': return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) else: return IndexedDatasetBuilder(out_file) def make_dataset(path, impl, fix_lua_indexing=False): if impl == 'lazy' and IndexedDataset.exists(path): return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing) elif impl == 'cached' and IndexedDataset.exists(path): return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing) elif impl == 'mmap' and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path) return None def dataset_exists(path, impl): if impl == 'mmap': return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) def read_longs(f, n): a = np.empty(n, dtype=np.int64) f.readinto(a) return a def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) dtypes = { 1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16 } def code(dtype): for k in dtypes.keys(): if dtypes[k] == dtype: return k raise ValueError(dtype) def index_file_path(prefix_path): return prefix_path + '.idx' def data_file_path(prefix_path): return prefix_path + '.bin' class IndexedDataset(torch.utils.data.Dataset): """Loader for IndexedDataset""" _HDR_MAGIC = b'TNTIDX\x00\x00' def __init__(self, path, fix_lua_indexing=False): super().__init__() self.path = path self.fix_lua_indexing = fix_lua_indexing self.data_file = None self.read_index(path) def read_index(self, path): with open(index_file_path(path), 'rb') as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( 'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.' ) version = f.read(8) assert struct.unpack('<Q', version) == (1,) code, self.element_size = struct.unpack('<QQ', f.read(16)) self.dtype = dtypes[code] self._len, self.s = struct.unpack('<QQ', f.read(16)) self.dim_offsets = read_longs(f, self._len + 1) self.data_offsets = read_longs(f, self._len + 1) self.sizes = read_longs(f, self.s) def read_data(self, path): self.data_file = open(data_file_path(path), 'rb', buffering=0) def check_index(self, i): if i < 0 or i >= self._len: raise IndexError('index out of range') def __del__(self): if self.data_file: self.data_file.close() @lru_cache(maxsize=8) def __getitem__(self, i): if not self.data_file: self.read_data(self.path) self.check_index(i) tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) item = torch.from_numpy(a).long() if self.fix_lua_indexing: item -= 1 # subtract 1 for 0-based indexing return item def __len__(self): return self._len def num_tokens(self, index): return self.sizes[index] def size(self, index): return self.sizes[index] @staticmethod def exists(path): return ( os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) ) @property def supports_prefetch(self): return False # avoid prefetching to save memory class IndexedCachedDataset(IndexedDataset): def __init__(self, path, fix_lua_indexing=False): super().__init__(path, fix_lua_indexing=fix_lua_indexing) self.cache = None self.cache_index = {} @property def supports_prefetch(self): return True def prefetch(self, indices): if all(i in self.cache_index for i in indices): return if not self.data_file: self.read_data(self.path) indices = sorted(set(indices)) total_size = 0 for i in indices: total_size += self.data_offsets[i + 1] - self.data_offsets[i] self.cache = np.empty(total_size, dtype=self.dtype) ptx = 0 self.cache_index.clear() for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] a = self.cache[ptx: ptx + size] self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size if self.data_file: # close and delete data file after prefetch so we can pickle self.data_file.close() self.data_file = None @lru_cache(maxsize=8) def __getitem__(self, i): self.check_index(i) tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] np.copyto(a, self.cache[ptx: ptx + a.size]) item = torch.from_numpy(a).long() if self.fix_lua_indexing: item -= 1 # subtract 1 for 0-based indexing return item class IndexedDatasetBuilder(object): element_sizes = { np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float: 4, np.double: 8 } def __init__(self, out_file, dtype=np.int32): self.out_file = open(out_file, 'wb') self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] self.sizes = [] self.element_size = self.element_sizes[self.dtype] def add_item(self, tensor): # +1 for Lua compatibility bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype)) self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) for s in tensor.size(): self.sizes.append(s) self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) def merge_file_(self, another_file): index = IndexedDataset(another_file) assert index.dtype == self.dtype begin = self.data_offsets[-1] for offset in index.data_offsets[1:]: self.data_offsets.append(begin + offset) self.sizes.extend(index.sizes) begin = self.dim_offsets[-1] for dim_offset in index.dim_offsets[1:]: self.dim_offsets.append(begin + dim_offset) with open(data_file_path(another_file), 'rb') as f: while True: data = f.read(1024) if data: self.out_file.write(data) else: break def finalize(self, index_file): self.out_file.close() index = open(index_file, 'wb') index.write(b'TNTIDX\x00\x00') index.write(struct.pack('<Q', 1)) index.write(struct.pack('<QQ', code(self.dtype), self.element_size)) index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes))) write_longs(index, self.dim_offsets) write_longs(index, self.data_offsets) write_longs(index, self.sizes) index.close() def _warmup_mmap_file(path): with open(path, 'rb') as stream: while stream.read(100 * 1024 * 1024): pass class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b'MMIDIDX\x00\x00' @classmethod def writer(cls, path, dtype): class _Writer(object): def __enter__(self): self._file = open(path, 'wb') self._file.write(cls._HDR_MAGIC) self._file.write(struct.pack('<Q', 1)) self._file.write(struct.pack('<B', code(dtype))) return self @staticmethod def _get_pointers(sizes): dtype_size = dtype().itemsize address = 0 pointers = [] for size in sizes: pointers.append(address) address += size * dtype_size return pointers def write(self, sizes): pointers = self._get_pointers(sizes) self._file.write(struct.pack('<Q', len(sizes))) sizes = np.array(sizes, dtype=np.int32) self._file.write(sizes.tobytes(order='C')) del sizes pointers = np.array(pointers, dtype=np.int64) self._file.write(pointers.tobytes(order='C')) del pointers def __exit__(self, exc_type, exc_val, exc_tb): self._file.close() return _Writer() def __init__(self, path): with open(path, 'rb') as stream: magic_test = stream.read(9) assert self._HDR_MAGIC == magic_test, ( 'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.' ) version = struct.unpack('<Q', stream.read(8)) assert (1,) == version dtype_code, = struct.unpack('<B', stream.read(1)) self._dtype = dtypes[dtype_code] self._dtype_size = self._dtype().itemsize self._len = struct.unpack('<Q', stream.read(8))[0] offset = stream.tell() _warmup_mmap_file(path) self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') self._bin_buffer = memoryview(self._bin_buffer_mmap) self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset) self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes) def __del__(self): self._bin_buffer_mmap._mmap.close() del self._bin_buffer_mmap @property def dtype(self): return self._dtype @property def sizes(self): return self._sizes @lru_cache(maxsize=8) def __getitem__(self, i): return self._pointers[i], self._sizes[i] def __len__(self): return self._len def __init__(self, path): super().__init__() self._path = None self._index = None self._bin_buffer = None self._do_init(path) def __getstate__(self): return self._path def __setstate__(self, state): self._do_init(state) def _do_init(self, path): self._path = path self._index = self.Index(index_file_path(self._path)) _warmup_mmap_file(data_file_path(self._path)) self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C') self._bin_buffer = memoryview(self._bin_buffer_mmap) def __del__(self): self._bin_buffer_mmap._mmap.close() del self._bin_buffer_mmap del self._index def __len__(self): return len(self._index) @lru_cache(maxsize=8) def __getitem__(self, i): ptr, size = self._index[i] np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr) if self._index.dtype != np.int64: np_array = np_array.astype(np.int64) return torch.from_numpy(np_array) @property def sizes(self): return self._index.sizes @property def supports_prefetch(self): return False @staticmethod def exists(path): return ( os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) ) class MMapIndexedDatasetBuilder(object): def __init__(self, out_file, dtype=np.int64): self._data_file = open(out_file, 'wb') self._dtype = dtype self._sizes = [] def add_item(self, tensor): np_array = np.array(tensor.numpy(), dtype=self._dtype) self._data_file.write(np_array.tobytes(order='C')) self._sizes.append(np_array.size) def merge_file_(self, another_file): # Concatenate index index = MMapIndexedDataset.Index(index_file_path(another_file)) assert index.dtype == self._dtype for size in index.sizes: self._sizes.append(size) # Concatenate data with open(data_file_path(another_file), 'rb') as f: shutil.copyfileobj(f, self._data_file) def finalize(self, index_file): self._data_file.close() with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: index.write(self._sizes)
megatron/data/preprocess_data.py +106 −22 Original line number Diff line number Diff line import argparse import json import multiprocessing import nltk nltk.download('punkt') import sys import time import torch from bert_tokenization import FullTokenizer import indexed_dataset class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): def document_generator_provider(input_file): with open(input_file, 'r') as ifile: for document in ifile: data = json.loads(document) text = data['text'] sentences = [] for line in text.split('\n'): if line != '\n': sentences.extend(nltk.tokenize.sent_tokenize(line)) yield sentences _period_context_fmt = r""" \S* # some word material %(SentEndChars)s # a potential sentence ending \s* # <-- THIS is what I changed (?=(?P<after_tok> %(NonWord)s # either other punctuation | (?P<next_tok>\S+) # <-- Normally you would have \s+ here ))""" class Encoder(object): def __init__(self, args): self.args = args if __name__ == '__main__': def initializer(self): # Use Encoder class as a container for global data Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True) spliter = nltk.load("tokenizers/punkt/english.pickle") if self.args.keep_newlines: # this prevents punkt from eating newlines after sentences Encoder.spliter = nltk.tokenize.punkt.PunktSentenceTokenizer( train_text = spliter._params, lang_vars = CustomLanguageVars()) else: Encoder.splitter = spliter def encode(self, json_line): text = json.loads(json_line)[self.args.json_key] doc_ids = [] for sentence in Encoder.splitter.tokenize(text): tokens = Encoder.tokenizer.tokenize(sentence) ids = Encoder.tokenizer.convert_tokens_to_ids(tokens) doc_ids.append(ids) doc_ids.append([]) return doc_ids, len(json_line) def main(): parser = argparse.ArgumentParser() parser.add_argument('--input', type=str, help='Path to input JSON') parser.add_argument('--vocab', type=str, help='Path to vocab.txt') parser.add_argument('--json-key', type=str, default='text', help='Key to extract from json') parser.add_argument('--output-prefix', type=str, help='Path to binary output file without suffix') parser.add_argument('--workers', type=int, default=20, help='Number of worker processes to launch') parser.add_argument('--log-interval', type=int, default=100, help='Interval between progress updates') parser.add_argument('--keep-newlines', action='store_true', help='Keep newlines between sentences.') parser.add_argument('--dataset-impl', type=str, default='mmap', choices=['lazy', 'cached', 'mmap']) args = parser.parse_args() args.keep_empty = False print('processing data ...') startup_start = time.time() input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json' vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt' print("Opening", args.input) fin = open(args.input, 'r', encoding='utf-8') tokenizer = FullTokenizer(vocab_file, do_lower_case=True) document_generator = document_generator_provider(input_file) for sentences in document_generator: for sentence in sentences: tokens = tokenizer.tokenize(sentence) vocab_size = 1 nltk.download("punkt", quiet=True) encoder = Encoder(args) tokenizer = FullTokenizer(args.vocab, do_lower_case=True) pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) encoded_docs = pool.imap(encoder.encode, fin, 25) output_bin_file = "{}.bin".format(args.output_prefix) output_idx_file = "{}.idx".format(args.output_prefix) ds = indexed_dataset.make_builder(output_bin_file, impl=args.dataset_impl, vocab_size=vocab_size) startup_end = time.time() proc_start = time.time() total_bytes_processed = 0 print("Time to startup:", startup_end - startup_start) for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): total_bytes_processed += bytes_processed for sentence in doc: print(sentence) print(tokens) print(tokenizer.convert_ids_to_tokens(sentence)) ds.add_item(torch.IntTensor(sentence)) if i % args.log_interval == 0: current = time.time() elapsed = current - proc_start mbs = total_bytes_processed/elapsed/1024/1024 print(f'Processed {i} documents', f"({i/elapsed} docs/s, {mbs} MB/s).", file=sys.stderr) ds.finalize(output_idx_file) if __name__ == '__main__': main() # print('processing data ...') # input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json' # vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt' # tokenizer = FullTokenizer(vocab_file, do_lower_case=True) # document_generator = document_generator_provider(input_file) # for sentences in document_generator: # for sentence in sentences: # tokens = tokenizer.tokenize(sentence) # print(sentence) # print(tokens)
megatron/data/test/test_indexed_dataset.py 0 → 100644 +31 −0 Original line number Diff line number Diff line import argparse import os import sys import torch script_dir = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.join(script_dir, "../../../")) from megatron.data import indexed_dataset, FullBertTokenizer def test_indexed_dataset(args): ds_impl = indexed_dataset.infer_dataset_impl(args.data) ds = indexed_dataset.make_dataset(args.data, ds_impl) tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) for sample in ds: print(sample) print(sample.data.tolist()) print(tokenizer.convert_ids_to_tokens(sample.data.tolist())) print("---") def main(): parser = argparse.ArgumentParser() parser.add_argument('--data', type=str, help='prefix to data files') parser.add_argument('--vocab', type=str, help='Path to vocab.txt') args = parser.parse_args() test_indexed_dataset(args) if __name__ == "__main__": main()
megatron/data/test/test_preprocess_data.sh 0 → 100755 +8 −0 Original line number Diff line number Diff line #!/bin/bash python ../preprocess_data.py \ --input test_samples.json \ --vocab vocab.txt \ --output-prefix test_samples \ --workers 1 \ --log-interval 2