Commit 66719e97 authored by Raul Puri, committed by Mohammad Shoeybi

Faster dataloader merge (#1)

* threaded tf_dl+presplit sentences+shuffled dataset with resume

* elaborate in readme
parent fb4cbdc2
+2 −1
@@ -33,6 +33,7 @@ python pretrain_bert.py \
    --tokenizer-model-type bert-large-uncased \
    --vocab-size 30522 \
    --train-data wikipedia \
+    --presplit-sentences \
    --loose-json \
    --text-key text \
    --split 1000,1,1 \
@@ -79,7 +80,7 @@ This script runs BERT pretraining with a `sentencepiece` tokenizer. If no senten
# Collecting Wikipedia Training Data
We recommend following the wikipedia data extraction process specified by google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text." 

-We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase.
+We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase. We recommend further preprocessing this json dataset with nltk punctuation standardization and presplitting each document into newline-separated sentences. This can be done with the provided script `./scripts/presplit_sentences_json.py` and will allow for faster data processing during training time. Pretraining with presplit data should be run with the `--presplit-sentences` flag as shown above.

Once the json dataset is ready make sure to set the path in line 27 of `data_utils/corpora.py`.
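For illustration, a minimal sketch of the presplitting step described above, assuming loose-json input with a `text` key; the provided `./scripts/presplit_sentences_json.py` is the authoritative version and may differ in details (e.g. how the nltk punctuation standardization is applied):

```python
import json

import nltk

# sent_tokenize needs the punkt models; download once if missing.
nltk.download('punkt')

input_path = 'wikipedia.json'            # hypothetical input: loose json, one document per line
output_path = 'wikipedia_presplit.json'  # hypothetical output path

with open(input_path) as fin, open(output_path, 'w') as fout:
    for line in fin:
        doc = json.loads(line)
        # Rewrite the document body as one sentence per newline.
        sentences = nltk.sent_tokenize(doc['text'])
        doc['text'] = '\n'.join(sentences)
        fout.write(json.dumps(doc) + '\n')
```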

+6 −0
@@ -184,6 +184,9 @@ def add_data_args(parser):

    group = parser.add_argument_group('data', 'data configurations')

+    group.add_argument('--shuffle', action='store_true',
+                       help='Shuffle data. Shuffling is deterministic '
+                       'based on seed and current epoch.')
    group.add_argument('--train-data', nargs='+', required=True,
                       help='Filename (or whitespace separated filenames) '
                       'for training.')
@@ -208,6 +211,9 @@ def add_data_args(parser):
                       help='Use loose json (one json-formatted string per '
                       'newline), instead of tight json (data file is one '
                       'json string)')
+    group.add_argument('--presplit-sentences', action='store_true',
+                       help='Dataset content consists of documents where '
+                       'each document consists of newline separated sentences')
    group.add_argument('--num-workers', type=int, default=2,
                       help="""Number of workers to use for dataloading""")
    group.add_argument('--tokenizer-model-type', type=str,
+7 −5
@@ -46,7 +46,7 @@ def make_data_loader(dataset, batch_size, args):

    shuffle = args.shuffle
    if shuffle:
-        sampler = torch.utils.data.RandomSampler(dataset)
+        sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    world_size = args.world_size
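As an aside, the stock `torch.utils.data.RandomSampler` accepts the same `replacement`/`num_samples` arguments used here; a small sketch of the sampling behaviour (the repo's `data_utils.samplers.RandomSampler` presumably also carries the resume state mentioned in the commit message, which this sketch does not reproduce):

```python
import torch
from torch.utils.data import RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))
batch_size, train_iters = 4, 8

# Sample with replacement and draw exactly batch_size * train_iters indices,
# so the sampler spans the whole training schedule instead of a single epoch.
sampler = RandomSampler(dataset, replacement=True,
                        num_samples=batch_size * train_iters)
indices = list(sampler)
assert len(indices) == batch_size * train_iters
```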
@@ -81,8 +81,10 @@ def make_tfrecord_loaders(args):
                     'max_seq_len': args.seq_length,
                     'max_preds_per_seq': args.max_preds_per_seq,
                     'train': True,
-                     'num_workers': args.num_workers,
-                     'seed': args.seed+args.rank+1}
+                     'num_workers': max(args.num_workers, 1),
+                     'seed': args.seed + args.rank + 1,
+                     'threaded_dl': args.num_workers > 0
+                     }
    train = data_utils.tf_dl.TFRecordDataLoader(args.train_data,
                                                **data_set_args)
    data_set_args['train'] = False
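The new `threaded_dl` flag enables threaded dataloading whenever `--num-workers` is positive; the actual behaviour lives in `data_utils/tf_dl.py`. Purely as an illustration of the general pattern (not the repo's implementation), a background-thread prefetch wrapper might look like:

```python
import queue
import threading

class ThreadedIterator:
    """Iterate an underlying iterable in a background thread, buffering items in a queue."""

    _SENTINEL = object()

    def __init__(self, iterable, max_prefetch=4):
        self._queue = queue.Queue(maxsize=max_prefetch)
        self._thread = threading.Thread(target=self._produce, args=(iterable,), daemon=True)
        self._thread.start()

    def _produce(self, iterable):
        for item in iterable:
            self._queue.put(item)
        self._queue.put(self._SENTINEL)

    def __iter__(self):
        while True:
            item = self._queue.get()
            if item is self._SENTINEL:
                return
            yield item

# Wrap any batch iterator; here a stand-in range plays the role of the loader.
for batch in ThreadedIterator(range(5), max_prefetch=2):
    print(batch)
```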
@@ -140,7 +142,8 @@ def make_loaders(args):
        'vocab_size': args.vocab_size,
        'model_type': args.tokenizer_model_type,
        'cache_dir': args.cache_dir,
-        'max_preds_per_seq': args.max_preds_per_seq}
+        'max_preds_per_seq': args.max_preds_per_seq,
+        'presplit_sentences': args.presplit_sentences}

    eval_set_args = copy.copy(data_set_args)
    eval_set_args['split'] = [1.]
@@ -218,7 +221,6 @@ def configure_data():
        'rank': -1,
        'persist_state': 0,
        'lazy': False,
-        'shuffle': False,
        'transpose': False,
        'data_set_type': 'supervised',
        'seq_length': 256,
+5 −3
@@ -46,7 +46,7 @@ def get_dataset(path, **kwargs):
    if supported_corpus(path):
        return corpora.NAMED_CORPORA[path](**kwargs)
    ext = get_ext(path)
-    if ext =='.json':
+    if '.json' in ext:
        text = json_dataset(path, **kwargs)
    elif ext in ['.csv', '.tsv']:
        text = csv_dataset(path, **kwargs)
@@ -108,8 +108,10 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
    if should_split(split):
        ds = split_ds(ds, split)
        if ds_type.lower() == 'bert':
-            ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length) for d in ds]
+            presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
+            ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) for d in ds]
    else:
        if ds_type.lower() == 'bert':
-            ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length)
+            presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
+            ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
    return ds, tokenizer
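A minor style note on the change above: the explicit membership test on `kwargs` is equivalent to `dict.get` with a default, e.g.:

```python
def get_presplit_flag(**kwargs):
    # Same result as: kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
    return kwargs.get('presplit_sentences', False)

assert get_presplit_flag(presplit_sentences=True) is True
assert get_presplit_flag() is False
```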
+10 −2
@@ -449,7 +449,7 @@ class bert_sentencepair_dataset(data.Dataset):
        dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)

    """
-    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, **kwargs):
+    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, **kwargs):
        self.ds = ds
        self.ds_len = len(self.ds)
        self.tokenizer = self.ds.GetTokenizer()
@@ -464,6 +464,7 @@ class bert_sentencepair_dataset(data.Dataset):
        self.dataset_size = dataset_size
        if self.dataset_size is None:
            self.dataset_size = self.ds_len * (self.ds_len-1)
+        self.presplit_sentences = presplit_sentences

    def __len__(self):
        return self.dataset_size
@@ -494,7 +495,14 @@ class bert_sentencepair_dataset(data.Dataset):

    def sentence_split(self, document):
        """split document into sentences"""
-        return tokenize.sent_tokenize(document)
+        lines = document.split('\n')
+        if self.presplit_sentences:
+            return [line for line in lines if line]
+        rtn = []
+        for line in lines:
+            if line != '':
+                rtn.extend(tokenize.sent_tokenize(line))
+        return rtn

    def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
        """tokenize sentence and get token types"""