Commit 03feecbc authored by Neel Kant's avatar Neel Kant
Browse files

One sentence options

parent 76928caa
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -386,6 +386,8 @@ def _add_data_args(parser):
                       help='Mask loss for the end of document tokens.')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for ICT dataset')
    group.add_argument('--ict-one-sent', action='store_true',
                       help='Whether to use one sentence documents in ICT')

    return parser

+1 −0
Original line number Diff line number Diff line
@@ -427,6 +427,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                    block_dataset=indexed_dataset,
                    title_dataset=title_dataset,
                    query_in_block_prob=args.query_in_block_prob,
                    use_one_sent_docs=args.ict_one_sent,
                    **kwargs
                )
            else:
+5 −4
Original line number Diff line number Diff line
@@ -5,14 +5,14 @@ import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list
from megatron.data.realm_dataset_utils import get_block_samples_mapping


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 query_in_block_prob, short_seq_prob, seed, use_titles=True):
                 query_in_block_prob, short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
@@ -22,10 +22,11 @@ class ICTDataset(Dataset):
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name)
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
@@ -47,7 +48,7 @@ class ICTDataset(Dataset):
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1
        assert len(block) > 1 or self.use_one_sent_docs

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)
+1 −1
Original line number Diff line number Diff line
@@ -91,7 +91,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(