megatron/arguments.py  +2 −0

@@ -349,6 +349,8 @@ def _add_data_args(parser):
                        help='Path to combined dataset to split.')
     group.add_argument('--titles-data-path', type=str, default=None,
                        help='Path to titles dataset used for ICT')
+    group.add_argument('--block-data-path', type=str, default=None,
+                       help='Path for loading and saving block data')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
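For reference, a minimal self-contained sketch (not the Megatron parser itself) of how the new flag behaves once registered; the bare ArgumentParser, the group title, and the example path are illustrative only:

import argparse

# Stand-in for Megatron's _add_data_args(parser): register the ICT-related paths.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='data and dataloader')
group.add_argument('--titles-data-path', type=str, default=None,
                   help='Path to titles dataset used for ICT')
group.add_argument('--block-data-path', type=str, default=None,
                   help='Path for loading and saving block data')

# Hypothetical invocation; the path is made up for illustration.
args = parser.parse_args(['--block-data-path', '/tmp/block_data'])
assert args.block_data_path == '/tmp/block_data'  # argparse maps dashes to underscores
assert args.titles_data_path is None              # still optional, defaults to None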
megatron/data/realm_dataset.py  +5 −5

@@ -6,14 +6,14 @@
 import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
-from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list
+from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list


 class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
-                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
-                 short_seq_prob, seed, use_titles=True):
+                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
+                 short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length

@@ -26,7 +26,7 @@ class ICTDataset(Dataset):
         self.samples_mapping = get_block_samples_mapping(
             block_dataset, title_dataset, data_prefix, num_epochs,
-            max_num_samples, max_seq_length, seed, name)
+            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
         self.tokenizer = get_tokenizer()
         self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
         self.vocab_id_to_token_list = self.tokenizer.inv_vocab

@@ -50,7 +50,7 @@ class ICTDataset(Dataset):
             title = None
             title_pad_offset = 2
         block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
-        assert len(block) > 1
+        assert len(block) > 1 or self.query_in_block_prob == 1

         # randint() is inclusive for Python rng
         rand_sent_idx = self.rng.randint(0, len(block) - 1)
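The relaxed assertion is what makes use_one_sent_docs workable: with probability query_in_block_prob the query sentence is left inside the block, otherwise it is removed, so a one-sentence block can only be served when that probability is 1 (removing the query would otherwise leave an empty context). Below is a minimal sketch of that selection logic, simplified from the real __getitem__ and operating on plain lists of token ids; apart from query_in_block_prob and the inclusive randint noted in the diff, the names are illustrative:

import random

def split_query_and_context(block_sentences, query_in_block_prob, rng=random):
    """Pick one sentence as the ICT query; keep it in the block with the given probability."""
    # A one-sentence block is only valid when the query always stays in the block.
    assert len(block_sentences) > 1 or query_in_block_prob == 1

    # randint() is inclusive for Python rng
    rand_sent_idx = rng.randint(0, len(block_sentences) - 1)
    query = block_sentences[rand_sent_idx]

    if rng.random() < query_in_block_prob:
        context = list(block_sentences)  # query sentence stays in the block
    else:
        context = block_sentences[:rand_sent_idx] + block_sentences[rand_sent_idx + 1:]
    return query, context

# Illustrative toy data: two "sentences" of fake token ids.
query, context = split_query_and_context([[101, 7, 8], [101, 9, 10]], query_in_block_prob=0.1)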
megatron/data/realm_dataset_utils.py  +8 −4

@@ -46,10 +46,11 @@ class BlockSamplesMapping(object):
         # make sure that the array is compatible with BlockSampleData
         assert mapping_array.shape[1] == 4
         self.mapping_array = mapping_array
+        self.shape = self.mapping_array.shape

     def __getitem__(self, idx):
         """Get the data associated with a particular sample."""
-        sample_data = BlockSamplesData(*self.mapping_array[idx])
+        sample_data = BlockSampleData(*self.mapping_array[idx])
         return sample_data

@@ -113,10 +114,10 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
                 seed, verbose, use_one_sent_docs)
+            samples_mapping = BlockSamplesMapping(mapping_array)
             print_rank_0(' > done building samples index mapping')
-            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+            np.save(indexmap_filename, mapping_array, allow_pickle=True)
             print_rank_0(' > saved the index mapping in {}'.format(
                 indexmap_filename))
     # Make sure all the ranks have built the mapping

@@ -136,7 +137,10 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
         print_rank_0(' > loading indexed mapping from {}'.format(
             indexmap_filename))
         start_time = time.time()
-        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
+        mapping_array = np.load(indexmap_filename, allow_pickle=True)
+        samples_mapping = BlockSamplesMapping(mapping_array)
         print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
             time.time() - start_time))
         print_rank_0('    total number of samples: {}'.format(
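The serialization change stores only the raw (N, 4) NumPy mapping array with np.save and rebuilds the lightweight BlockSamplesMapping wrapper after np.load, rather than pickling the wrapper object itself; the new self.shape attribute keeps the array's shape reachable through the wrapper. A minimal self-contained sketch of this pattern, with a simplified stand-in for BlockSampleData (its field names beyond start_idx/end_idx are assumptions here) and an illustrative file path:

import numpy as np
from collections import namedtuple

# Simplified stand-in for the real BlockSampleData; only start_idx/end_idx are
# confirmed by the diff, the remaining column names are assumptions.
BlockSampleData = namedtuple('BlockSampleData',
                             ['start_idx', 'end_idx', 'doc_idx', 'block_idx'])

class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array
        self.shape = self.mapping_array.shape

    def __getitem__(self, idx):
        """Get the data associated with a particular sample."""
        return BlockSampleData(*self.mapping_array[idx])

# Save only the raw array (illustrative path) ...
mapping_array = np.arange(8, dtype=np.int64).reshape(2, 4)
np.save('/tmp/indexmap_example.npy', mapping_array, allow_pickle=True)

# ... and rebuild the wrapper after loading, mirroring get_block_samples_mapping.
loaded_array = np.load('/tmp/indexmap_example.npy', allow_pickle=True)
samples_mapping = BlockSamplesMapping(loaded_array)
print(samples_mapping.shape[0], samples_mapping[0])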