Commit 44860f8d authored by Neel Kant
Browse files

Changes after running on draco

parent a00b3c79
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -349,6 +349,8 @@ def _add_data_args(parser):
                       help='Path to combined dataset to split.')
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Path for loading and saving block data')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
+5 −5
Original line number Diff line number Diff line
@@ -6,14 +6,14 @@ import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list
from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 query_in_block_prob, short_seq_prob, seed, use_titles=True):
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
@@ -26,7 +26,7 @@ class ICTDataset(Dataset):

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name)
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
@@ -50,7 +50,7 @@ class ICTDataset(Dataset):
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1
        assert len(block) > 1 or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)
+8 −4
Original line number Diff line number Diff line
@@ -46,10 +46,11 @@ class BlockSamplesMapping(object):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array
        self.shape = self.mapping_array.shape

    def __getitem__(self, idx):
        """Get the data associated with a particular sample."""
        sample_data = BlockSamplesData(*self.mapping_array[idx])
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data


@@ -113,10 +114,10 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
            seed,
            verbose,
            use_one_sent_docs)
        samples_mapping = BlockSamplesMapping(mapping_array)


        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
@@ -136,7 +137,10 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)

    mapping_array = np.load(indexmap_filename, allow_pickle=True)
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(