Commit dedb2ef7 authored by Mohammad

removed building tokenizer from bert dataset

parent 1788c910
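
In short: `BertDataset` previously received a `vocab_file` and built its own `FullBertTokenizer`; after this commit it fetches the single process-wide tokenizer through `megatron.get_tokenizer()`. To support that, the abstract tokenizer interface gains `vocab` and `inv_vocab` properties plus a `mask` special-token accessor, and what looks like the pretraining driver drops the `vocab_file` argument from the dataset-builder call. A minimal before/after sketch of the access pattern, assuming the global tokenizer has already been built during Megatron initialization (that setup is outside this diff):

```python
# Before: each BertDataset constructed (and duplicated) a tokenizer itself.
from megatron.tokenizer.bert_tokenization import FullTokenizer as FullBertTokenizer
tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
mask_id = tokenizer.vocab['[MASK]']   # reaches into the raw vocab dict

# After: the dataset asks for the shared tokenizer and uses generic accessors.
from megatron import get_tokenizer
tokenizer = get_tokenizer()
mask_id = tokenizer.mask              # tokenizer-agnostic special-token id
```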
+12 −22
@@ -22,24 +22,19 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
 
+from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data import helpers
-from megatron.tokenizer.bert_tokenization import FullTokenizer as FullBertTokenizer
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
 
 
-def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
-                                    splits_string, train_valid_test_num_samples,
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                    train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup):
 
-    # Tokenizer is the same
-    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
-    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
-        tokenizer.vocab_size()))
-
     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
                                            data_impl,
@@ -82,7 +77,6 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
             dataset = BertDataset(
                 name=name,
                 indexed_dataset=indexed_dataset,
-                tokenizer=tokenizer,
                 data_prefix=data_prefix,
                 num_epochs=None,
                 max_num_samples=train_valid_test_num_samples[index],
@@ -107,7 +101,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,

 class BertDataset(Dataset):
 
-    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
+    def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
                  max_seq_length, short_seq_prob, seed):

@@ -117,8 +111,7 @@ class BertDataset(Dataset):
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
 
-        # Tokenizer and dataset.
-        self.tokenizer = tokenizer
+        # Dataset.
         self.indexed_dataset = indexed_dataset


@@ -133,16 +126,13 @@ class BertDataset(Dataset):
                                                     self.name)
 
         # Vocab stuff.
-        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
-        self.vocab_id_to_token_dict = self.tokenizer.inv_vocab
-        self.cls_id = self.tokenizer.vocab['[CLS]']
-        self.sep_id = self.tokenizer.vocab['[SEP]']
-        self.mask_id = self.tokenizer.vocab['[MASK]']
-        self.pad_id = self.tokenizer.vocab['[PAD]']
-
-
-    def num_tokens(self):
-        return self.tokenizer.vocab_size()
+        tokenizer = get_tokenizer()
+        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = tokenizer.inv_vocab
+        self.cls_id = tokenizer.cls
+        self.sep_id = tokenizer.sep
+        self.mask_id = tokenizer.mask
+        self.pad_id = tokenizer.pad
 
 
     def __len__(self):
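
Note that `BertDataset.num_tokens()` is deleted outright rather than ported: with the tokenizer globally reachable, the dataset no longer needs to proxy vocabulary size. Any caller can presumably query the tokenizer directly (the replacement call site is not shown in this diff):

```python
# Assumed replacement for the removed train_ds.num_tokens():
from megatron import get_tokenizer
vocab_size = get_tokenizer().vocab_size  # a property on the new interface
```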
+37 −0
@@ -75,6 +75,18 @@ class AbstractTokenizer(ABC):
     def vocab_size(self):
         pass
 
+    @property
+    @abstractmethod
+    def vocab(self):
+        """Dictionary from vocab text token to id token."""
+        pass
+
+    @property
+    @abstractmethod
+    def inv_vocab(self):
+        """Dictionary from vocab id token to text token."""
+        pass
+
     @abstractmethod
     def tokenize(self, text):
         pass
@@ -99,6 +111,11 @@ class AbstractTokenizer(ABC):
         raise NotImplementedError('EOD is not provided for {} '
                                   'tokenizer'.format(self.name))
 
+    @property
+    def mask(self):
+        raise NotImplementedError('MASK is not provided for {} '
+                                  'tokenizer'.format(self.name))
+
 
 class _BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""
@@ -113,11 +130,20 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
         self.cls_id = self.tokenizer.vocab['[CLS]']
         self.sep_id = self.tokenizer.vocab['[SEP]']
         self.pad_id = self.tokenizer.vocab['[PAD]']
+        self.mask_id = self.tokenizer.vocab['[MASK]']
 
     @property
     def vocab_size(self):
         return self.tokenizer.vocab_size()
 
+    @property
+    def vocab(self):
+        return self.tokenizer.vocab
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.inv_vocab
+
     def tokenize(self, text):
         text_tokens = self.tokenizer.tokenize(text)
         return self.tokenizer.convert_tokens_to_ids(text_tokens)
@@ -134,6 +160,9 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
     def pad(self):
         return self.pad_id
 
+    @property
+    def mask(self):
+        return self.mask_id
 
 class _GPT2BPETokenizer(AbstractTokenizer):
     """Original GPT2 BPE tokenizer."""
@@ -150,6 +179,14 @@ class _GPT2BPETokenizer(AbstractTokenizer):
     def vocab_size(self):
         return len(self.tokenizer.encoder)
 
+    @property
+    def vocab(self):
+        return self.tokenizer.encoder
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.decoder
+
     def tokenize(self, text):
         return self.tokenizer.encode(text)

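The net effect of these additions is that `vocab` and `inv_vocab` become part of the required tokenizer contract, while `mask` gets a base-class default that raises, mirroring the existing `eod` fallback. A hypothetical minimal subclass, only to illustrate the extended interface (it assumes this module is `megatron/tokenizer/tokenizer.py` and that `AbstractTokenizer.__init__` takes a display name, as the `NotImplementedError` messages referencing `self.name` suggest):

```python
from megatron.tokenizer.tokenizer import AbstractTokenizer  # assumed module path

class ToyCharTokenizer(AbstractTokenizer):
    """Illustrative character-level tokenizer; not part of this commit."""

    def __init__(self):
        super().__init__('toy character')  # assumed: base class stores self.name
        chars = [chr(i) for i in range(32, 127)]
        self._vocab = {c: i for i, c in enumerate(chars)}
        self._inv_vocab = {i: c for c, i in self._vocab.items()}

    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def vocab(self):
        # Required by the new abstract property: text token -> id.
        return self._vocab

    @property
    def inv_vocab(self):
        # Required by the new abstract property: id -> text token.
        return self._inv_vocab

    def tokenize(self, text):
        return [self._vocab[c] for c in text]

# With no `mask` override, ToyCharTokenizer().mask raises NotImplementedError
# via the base-class property added in this commit.
```
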
+0 −1
@@ -137,7 +137,6 @@ def get_train_val_test_data():
         print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
 
         train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            vocab_file=args.vocab_file,
             data_prefix=args.data_path,
             data_impl=args.data_impl,
             splits_string=args.split,
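
With this last change, `args.vocab_file` no longer threads through the data pipeline at all; presumably it is consumed once when the shared tokenizer is constructed during initialization, which is the single remaining place that needs it.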