Commit c125d247 authored by Mohammad Shoeybi

built simple test for dataset

parent 7120e931
+62 −14
@@ -11,11 +11,11 @@ from torch.utils.data import Dataset
# WILL BE REPLACED WITH JARED'S
class JaredDataset(object):

-    def __init__(self):
-        self.doc_idx = []
+    def __init__(self, doc_idx, sizes, sentences):
+        self.doc_idx = doc_idx
        self.num_docs = len(self.doc_idx) - 1
-        self.sizes = []
-        self.sentences = []
+        self.sizes = sizes
+        self.sentences = sentences

    def __getitem__(self, idx):
        return self.sentences[idx]
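As a reading aid (not part of the commit): doc_idx holds prefix sums of per-document sentence counts, so document d owns the sentence range [doc_idx[d], doc_idx[d+1]) and len(doc_idx) - 1 is the document count. A minimal sketch of that convention:

# Illustrative only: two documents with 3 and 2 sentences.
doc_idx = [0, 3, 5]                   # doc 0 -> sentences [0, 3), doc 1 -> [3, 5)
sizes = [7, 4, 9, 5, 6]               # token count of each sentence
sentences = [[1] * n for n in sizes]  # dummy token ids, one list per sentence

ds = JaredDataset(doc_idx, sizes, sentences)
assert ds.num_docs == 2
assert ds[3] == sentences[3]          # __getitem__ indexes sentences globally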
@@ -62,7 +62,7 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
            # Document sentences are in [sent_index_first, sent_index_last).
            sent_index_first = indexed_dataset.doc_idx[doc_index]
            sent_index_last = indexed_dataset.doc_idx[doc_index+1]
-            assert sent_index_last >= sent_index_first:
+            assert sent_index_last >= sent_index_first

            # Empty docs.
            if (sent_index_last - sent_index_first) == 0:
@@ -82,7 +82,7 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
            # Loop through sentences.
            sent_index = sent_index_first
            target_seq_length = get_target_seq_length(max_num_tokens,
-                                                      short_seq_prob, rng)
+                                                      short_seq_prob, np_rng)
            size = 0
            while sent_index < sent_index_last:

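get_target_seq_length itself is outside this diff. Below is a minimal sketch of what the call sites suggest, following the usual BERT-style short-sequence trick; the generator type and the exact bounds are assumptions, not the committed code:

import numpy as np

def get_target_seq_length(max_num_tokens, short_seq_prob, np_rng):
    # Assumption: np_rng is a seeded numpy.random.RandomState, matching
    # the rng -> np_rng rename in this commit.
    # With probability short_seq_prob, train on a shorter sequence so the
    # model also sees short inputs; otherwise use the full token budget.
    if np_rng.random_sample() < short_seq_prob:
        return np_rng.randint(2, max_num_tokens + 1)
    return max_num_tokens

np_rng = np.random.RandomState(1234)
assert 2 <= get_target_seq_length(510, 0.1, np_rng) <= 510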
@@ -94,19 +94,22 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
                exceeded_target_size = (size >= target_seq_length)
                # If only one sentence is left in the document.
                only_one_sent_left = (sent_index == (sent_index_last - 1))
+                # If we have at least two sentences.
+                have_more_than_one_sent = (sent_index - sent_index_first) > 1
                # If we have reached end of the document.
                reached_end_of_doc = (sent_index == sent_index_last)
-                if (exceeded_target_size and not only_one_sent_left) or \
-                   reached_end_of_doc:
+                if (exceeded_target_size and not only_one_sent_left and
+                    have_more_than_one_sent) or reached_end_of_doc:
                    assert (sent_index - sent_index_first) > 1
                    assert size > 1
                    # Add the sample.
-                    samples.append([sent_index_first, sent_index])
+                    samples.append([sent_index_first, sent_index,
+                                    target_seq_length])
                    # Reset indices
                    sent_index_first = sent_index
                    target_seq_length = get_target_seq_length(max_num_tokens,
                                                              short_seq_prob,
-                                                              rng)
+                                                              np_rng)
                    size = 0
                    num_sentences = 0
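The size accumulation and the sent_index increment sit in elided context, so the following is a reconstruction rather than the committed code, but the cut conditions above can be read as this standalone function over one document's sentence sizes (document-local indices; get_target_seq_length as sketched earlier):

def split_document(sent_sizes, max_num_tokens, short_seq_prob, np_rng):
    # Emits [first, last, target_seq_length] triples over sentence indices,
    # cutting a sample once the target size is reached, unless that would
    # leave a trailing one-sentence sample or take fewer than two sentences.
    samples, first, size, sent = [], 0, 0, 0
    target = get_target_seq_length(max_num_tokens, short_seq_prob, np_rng)
    while sent < len(sent_sizes):
        size += sent_sizes[sent]
        sent += 1
        exceeded = size >= target
        only_one_left = sent == (len(sent_sizes) - 1)
        more_than_one_taken = (sent - first) > 1
        at_end = sent == len(sent_sizes)
        if (exceeded and not only_one_left and more_than_one_taken) or at_end:
            samples.append([first, sent, target])
            first, size = sent, 0
            target = get_target_seq_length(max_num_tokens,
                                           short_seq_prob, np_rng)
    return samples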

@@ -132,16 +135,16 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,

class AlbertDataSet(Dataset):

-    def __init__(self, tokenizer, num_epochs, masked_lm_prob, max_seq_length
-                 short_seq_prob, seed):
+    def __init__(self, indexed_dataset, tokenizer, num_epochs,
+                 masked_lm_prob, max_seq_length, short_seq_prob, seed):

        # Params to store.
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

-        # Build the indexed dataset.
-        self.indexed_dataset = JaredDataset()
+        # Indexed dataset.
+        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = build_training_samples_mapping(
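How the stored [first, last, target_seq_length] triples get consumed is outside this diff (the __getitem__ hunk is not shown); a hypothetical sketch, with everything beyond the triple layout assumed:

    def __getitem__(self, idx):
        # Hypothetical: concatenate the sentences of the mapped span.
        # The committed code would presumably also add special tokens,
        # apply masking with self.masked_lm_prob, and truncate or pad
        # to target_seq_length.
        first, last, target_seq_length = self.samples_mapping[idx]
        tokens = []
        for sent_index in range(first, last):
            tokens.extend(self.indexed_dataset[sent_index])
        return tokens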
@@ -181,3 +184,48 @@ class AlbertDataSet(Dataset):
if __name__ == '__main__':

    print('dataset ...')

+    from bert_tokenization import FullTokenizer
+    import json
+    import nltk
+    nltk.download('punkt')
+
+    def document_generator_provider(input_file):
+        with open(input_file, 'r') as ifile:
+            for document in ifile:
+                data = json.loads(document)
+                text = data['text']
+                sentences = []
+                for line in text.split('\n'):
+                    if line:  # skip empty lines ('' after split)
+                        sentences.extend(nltk.tokenize.sent_tokenize(line))
+                yield sentences
+
+    input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json'
+    vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt'
+
+    tokenizer = FullTokenizer(vocab_file, do_lower_case=True)
+    document_generator = document_generator_provider(input_file)
+
+    doc_idx = [0]
+    sizes = []
+    sentences_list = []
+
+    for sentences in document_generator:
+        doc_idx.append(len(sentences))
+        for sentence in sentences:
+            tokens = tokenizer.tokenize(sentence)
+            ids = tokenizer.convert_tokens_to_ids(tokens)
+            sizes.append(len(ids))
+            sentences_list.append(ids)
+    for i in range(1, len(doc_idx)):
+        doc_idx[i] += doc_idx[i-1]
+
+    indexed_dataset = JaredDataset(doc_idx, sizes, sentences_list)
+    dataset = AlbertDataSet(indexed_dataset=indexed_dataset,
+                            tokenizer=tokenizer,
+                            num_epochs=3,
+                            masked_lm_prob=0.15,
+                            max_seq_length=512,
+                            short_seq_prob=0.1,
+                            seed=1234)
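To push the smoke test past construction, the dataset could be iterated with a standard DataLoader, assuming __len__ and __getitem__ are filled in by the elided part of the class; the collate function below is a placeholder since the commit does not yet define batching or padding:

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=4,
                        collate_fn=lambda batch: batch)  # no padding yet
    for i, batch in enumerate(loader):
        print('batch {} with {} samples'.format(i, len(batch)))
        if i == 2:
            break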