megatron/arguments.py  +8 −2

@@ -411,7 +411,7 @@ def _add_realm_args(parser):
                        help='Path to titles dataset used for ICT')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
-    group.add_argument('--ict-one-sent', action='store_true',
+    group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')

     # training
@@ -421,7 +421,13 @@ def _add_realm_args(parser):
     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str,
+    group.add_argument('--block-data-path', type=str, default=None,
                        help='Where to save/load BlockData to/from')

+    # indexer
+    group.add_argument('--indexer-batch-size', type=int, default=128,
+                       help='How large of batches to use when doing indexing jobs')
+    group.add_argument('--indexer-log-interval', type=int, default=1000,
+                       help='After how many batches should the indexer report progress')
+
     return parser
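Note for callers: the rename means `args.ict_one_sent` becomes `args.use_one_sent_docs`. A minimal, self-contained sketch of how the renamed and new flags parse (plain argparse, outside Megatron; the parse_args argument list is illustrative only):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('realm')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer report progress')

    # Dashes become underscores on the parsed namespace.
    args = parser.parse_args(['--use-one-sent-docs', '--indexer-batch-size', '64'])
    assert args.use_one_sent_docs is True
    assert args.indexer_batch_size == 64        # overridden on the command line
    assert args.indexer_log_interval == 1000    # default kept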
megatron/checkpointing.py  +42 −1

@@ -21,9 +21,10 @@
 import sys
 import numpy as np

 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import mpu
+from megatron import mpu, get_args
 from megatron import get_args
 from megatron import print_rank_0

@@ -244,3 +245,43 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         print(' successfully loaded {}'.format(checkpoint_name))

     return iteration
+
+
+def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+    args = get_args()
+
+    if isinstance(model, torchDDP):
+        model = model.module
+
+    load_path = args.load if from_realm_chkpt else args.ict_load
+
+    tracker_filename = get_checkpoint_tracker_filename(load_path)
+    with open(tracker_filename, 'r') as f:
+        iteration = int(f.read().strip())
+
+    # assert iteration > 0
+    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    if mpu.get_data_parallel_rank() == 0:
+        print('global rank {} is loading checkpoint {}'.format(
+            torch.distributed.get_rank(), checkpoint_name))
+
+    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    ict_state_dict = state_dict['model']
+    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
+        print(" loading ICT state dict from REALM", flush=True)
+        ict_state_dict = ict_state_dict['retriever']['ict_model']
+
+    if only_query_model:
+        ict_state_dict.pop('context_model')
+    if only_block_model:
+        ict_state_dict.pop('question_model')
+
+    model.load_state_dict(ict_state_dict)
+    torch.distributed.barrier()
+
+    if mpu.get_data_parallel_rank() == 0:
+        print(' successfully loaded {}'.format(checkpoint_name))
+
+    return model
\ No newline at end of file
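A hedged usage sketch of the new loader for an indexing job: passing only_block_model=True keeps the block (context) encoder and pops the query encoder's weights from the state dict. `get_model` and `model_provider` are hypothetical stand-ins for however the caller builds the ICT model; they are not part of this diff:

    from megatron.checkpointing import load_ict_checkpoint

    # Hypothetical builder; stands in for the caller's model construction.
    model = get_model(model_provider)

    # Load block-encoder weights only, from the ICT checkpoint at
    # args.ict_load (from_realm_chkpt=False) rather than a REALM checkpoint.
    model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=False)
    model.eval()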
megatron/data/dataset_utils.py  +1 −1

@@ -426,7 +426,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             block_dataset=indexed_dataset,
             title_dataset=title_dataset,
             query_in_block_prob=args.query_in_block_prob,
-            use_one_sent_docs=args.ict_one_sent,
+            use_one_sent_docs=args.use_one_sent_docs,
             **kwargs
         )
     else:
megatron/data/ict_dataset.py  +28 −1

@@ -5,9 +5,36 @@
 import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
+from megatron import get_args
+from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping
+
+
+def get_ict_dataset(use_titles=True, query_in_block_prob=1):
+    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
+    rather than for training, since it is only built with a single epoch sample mapping.
+    """
+    args = get_args()
+    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
+    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
+
+    kwargs = dict(
+        name='full',
+        block_dataset=block_dataset,
+        title_dataset=titles_dataset,
+        data_prefix=args.data_path,
+        num_epochs=1,
+        max_num_samples=None,
+        max_seq_length=args.seq_length,
+        seed=1,
+        query_in_block_prob=query_in_block_prob,
+        use_titles=use_titles,
+        use_one_sent_docs=args.use_one_sent_docs
+    )
+    dataset = ICTDataset(**kwargs)
+    return dataset


 class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,

@@ -35,7 +62,7 @@ class ICTDataset(Dataset):
         self.pad_id = self.tokenizer.pad

     def __len__(self):
-        return self.samples_mapping.shape[0]
+        return len(self.samples_mapping)

     def __getitem__(self, idx):
         """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
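A hedged sketch of consuming the one-epoch dataset, grounded only in the `__len__`/`__getitem__` shown in this diff (the richer accessors such as `get_block()` mentioned in the docstring are not shown here); assumes Megatron's global args are already initialized:

    from megatron.data.ict_dataset import get_ict_dataset

    # query_in_block_prob=1 keeps the query sentence inside its block, which
    # is what block indexing (as opposed to ICT training) wants.
    dataset = get_ict_dataset(use_titles=True, query_in_block_prob=1)

    print(len(dataset))   # number of entries in the one-epoch samples mapping
    sample = dataset[0]   # a pseudo-query / block pair, per __getitem__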
megatron/data/realm_dataset_utils.py  +55 −3

@@ -6,9 +6,59 @@
 import torch

-from megatron import mpu, print_rank_0
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
+from megatron.data.samplers import DistributedBatchSampler
+from megatron import get_args, get_tokenizer, print_rank_0, mpu
+
+
+def get_one_epoch_dataloader(dataset, batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if batch_size is None:
+        batch_size = args.batch_size
+    global_batch_size = batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # importantly, drop_last must be False to get all the data.
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=False,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
+def get_ict_batch(data_iterator):
+    # Items and their type.
+    keys = ['query_tokens', 'query_pad_mask',
+            'block_tokens', 'block_pad_mask', 'block_data']
+    datatype = torch.int64
+
+    # Broadcast data.
+    if data_iterator is None:
+        data = None
+    else:
+        data = next(data_iterator)
+    data_b = mpu.broadcast_data(keys, data, datatype)
+
+    # Unpack.
+    query_tokens = data_b['query_tokens'].long()
+    query_pad_mask = data_b['query_pad_mask'].long()
+    block_tokens = data_b['block_tokens'].long()
+    block_pad_mask = data_b['block_pad_mask'].long()
+    block_indices = data_b['block_data'].long()
+
+    return query_tokens, query_pad_mask,\
+        block_tokens, block_pad_mask, block_indices
+

 def join_str_list(str_list):
     """Join a list of strings, handling spaces appropriately"""
     result = ""

@@ -46,10 +96,12 @@ class BlockSamplesMapping(object):
         # make sure that the array is compatible with BlockSampleData
         assert mapping_array.shape[1] == 4
         self.mapping_array = mapping_array
         self.shape = self.mapping_array.shape

+    def __len__(self):
+        return self.mapping_array.shape[0]
+
     def __getitem__(self, idx):
-        """Get the data associated with a particular sample."""
+        """Get the data associated with an indexed sample."""
         sample_data = BlockSampleData(*self.mapping_array[idx])
         return sample_data

@@ -144,6 +196,6 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
         time.time() - start_time))
     print_rank_0(' total number of samples: {}'.format(
-        samples_mapping.shape[0]))
+        mapping_array.shape[0]))

     return samples_mapping
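Taken together, these helpers suggest a sequential indexing loop over every block. A hedged sketch of such a driver, assuming Megatron is already initialized; the loop body and the use of `args.indexer_batch_size` are illustrative, not code from this PR:

    from megatron import get_args
    from megatron.data.ict_dataset import get_ict_dataset
    from megatron.data.realm_dataset_utils import (
        get_one_epoch_dataloader, get_ict_batch)

    args = get_args()
    dataset = get_ict_dataset()
    # SequentialSampler + drop_last=False: every block is seen exactly once.
    data_iter = iter(get_one_epoch_dataloader(dataset, args.indexer_batch_size))

    while True:
        try:
            query_tokens, query_pad_mask, \
                block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iter)
        except StopIteration:
            break
        # ... embed block_tokens with the block encoder and store the
        # resulting vectors keyed by block_indices (e.g. in BlockData) ...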