megatron/checkpointing.py  +5 −3

@@ -371,7 +371,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     return iteration


-def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+def load_ict_checkpoint(model, only_query_model=False, only_context_model=False, from_realm_chkpt=False):
     """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
     args = get_args()

@@ -393,14 +393,16 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     state_dict = torch.load(checkpoint_name, map_location='cpu')
     ict_state_dict = state_dict['model']
+    print(ict_state_dict)
+    sys.exit()

     if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
         print(" loading ICT state dict from REALM", flush=True)
         ict_state_dict = ict_state_dict['retriever']['ict_model']

     if only_query_model:
         ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
+    if only_context_model:
+        ict_state_dict.pop('query_model')

     model.load_state_dict(ict_state_dict)
     torch.distributed.barrier()
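For orientation, the pop() logic above implies a two-branch state dict under the new names. The sketch below is an assumption for illustration only: the top-level key names ('query_model', 'context_model') and the REALM nesting come from this diff, the leaf contents are placeholders.

```python
# Assumed layout of ict_state_dict after the biencoder rename (sketch only).
ict_state_dict = {
    'query_model':   {},   # weights of the query/question encoder
    'context_model': {},   # weights of the context/block encoder
}

# For a REALM checkpoint the same dict sits one level deeper and is first
# unwrapped via state_dict['model']['retriever']['ict_model'] before the
# unwanted branch is popped off.
```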
megatron/data/biencoder_dataset_utils.py  +27 −0

@@ -9,6 +9,33 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 from megatron import get_args, get_tokenizer, print_rank_0, mpu


+def get_one_epoch_dataloader(dataset, micro_batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if micro_batch_size is None:
+        micro_batch_size = args.micro_batch_size
+    global_batch_size = micro_batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # importantly, drop_last must be False to get all the data.
+    assert False, 'DistributedBatchSampler deprecated, change the implementation'
+    from megatron.data.samplers import DistributedBatchSampler
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=False,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
 def get_ict_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_mask',
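The new get_one_epoch_dataloader is guarded by an assert because DistributedBatchSampler has been deprecated. Below is a minimal sketch of one way the loader could be rebuilt on torch's built-in DistributedSampler; this is an assumption, not the fix adopted in this change, and it relies on the file's existing get_args/mpu imports.

```python
# Hypothetical replacement for the deprecated DistributedBatchSampler, using
# torch.utils.data.distributed.DistributedSampler. Sketch only.
import torch
from megatron import get_args, mpu   # same imports the file already uses

def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """One sequential pass over `dataset`, sharded across data-parallel ranks."""
    args = get_args()

    rank = mpu.get_data_parallel_rank()
    world_size = mpu.get_data_parallel_world_size()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size

    # shuffle=False keeps rows in order; drop_last=False keeps the final
    # partial batch. Caveat: DistributedSampler pads the index list so every
    # rank sees the same number of samples, so a few rows can be emitted
    # twice -- downstream add_block_data would then need allow_overwrite=True.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        shuffle=False, drop_last=False)

    return torch.utils.data.DataLoader(dataset,
                                       batch_size=micro_batch_size,
                                       sampler=sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
```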
megatron/data/realm_index.py  +23 −20

@@ -14,28 +14,29 @@ def detach(tensor):
     return tensor.detach().cpu().numpy()


-class BlockData(object):
-    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
+class OpenRetreivalDataStore(object):
+    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for Retriever"""
+    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
-        self.meta_data = dict()
-        if block_data_path is None:
+        #self.meta_data = dict()
+        if embedding_path is None:
             args = get_args()
-            block_data_path = args.block_data_path
+            embedding_path = args.embedding_path
             rank = args.rank
-        self.block_data_path = block_data_path
+        self.embedding_path = embedding_path
         self.rank = rank

         if load_from_path:
             self.load_from_file()

-        block_data_name = os.path.splitext(self.block_data_path)[0]
+        block_data_name = os.path.splitext(self.embedding_path)[0]
         self.temp_dir_name = block_data_name + '_tmp'

     def state(self):
         return {
             'embed_data': self.embed_data,
-            'meta_data': self.meta_data,
+            #'meta_data': self.meta_data,
         }

     def clear(self):

@@ -50,26 +51,28 @@ class BlockData(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(self.block_data_path, 'rb'))
+        state_dict = pickle.load(open(self.embedding_path, 'rb'))
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)

         self.embed_data = state_dict['embed_data']
-        self.meta_data = state_dict['meta_data']
+        #self.meta_data = state_dict['meta_data']

-    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+    #def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
         """Add data for set of blocks
-        :param block_indices: 1D array of unique int ids for the blocks
+        :param row_id: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
-        :param block_metas: 2D array of metadata for the blocks.
+        #:param block_metas: 2D array of metadata for the blocks.
             In the case of REALM this will be [start_idx, end_idx, doc_idx]
         """
-        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        #for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        for idx, embed in zip(row_id, block_embeds):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")

             self.embed_data[idx] = np.float16(embed)
-            self.meta_data[idx] = meta
+            #self.meta_data[idx] = meta

     def save_shard(self):
         """Save the block data that was created this in this process"""

@@ -77,8 +80,8 @@ class BlockData(object):
         os.makedirs(self.temp_dir_name, exist_ok=True)

         # save the data for each shard
-        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
-            pickle.dump(self.state(), data_file)
+        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as writer:
+            pickle.dump(self.state(), writer)

     def merge_shards_and_save(self):
         """Combine all the shards made using self.save_shard()"""

@@ -98,13 +101,13 @@ class BlockData(object):
                 # add the shard's data and check to make sure there is no overlap
                 self.embed_data.update(data['embed_data'])
-                self.meta_data.update(data['meta_data'])
+                #self.meta_data.update(data['meta_data'])
                 assert len(self.embed_data) == old_size + shard_size

         assert seen_own_shard

         # save the consolidated shards and remove temporary directory
-        with open(self.block_data_path, 'wb') as final_file:
+        with open(self.embedding_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)

         shutil.rmtree(self.temp_dir_name, ignore_errors=True)
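A minimal usage sketch of the renamed OpenRetreivalDataStore follows, inferred from the methods above; the batch iterator and the rank-0 merge condition are assumptions, not code from this change.

```python
# Sketch of how OpenRetreivalDataStore appears intended to be used in a
# sharded indexing job. `batches` is a placeholder for whatever yields
# (row_ids, embeddings) pairs on this rank.
import torch
from megatron import mpu
from megatron.data.realm_index import OpenRetreivalDataStore

store = OpenRetreivalDataStore(load_from_path=False)

for row_ids, embeds in batches:
    # row_ids: 1D array of unique int ids, embeds: 2D array of embeddings
    store.add_block_data(row_ids, embeds)

# Each data-parallel rank writes <embedding_path>_tmp/<rank>.pkl ...
store.save_shard()
torch.distributed.barrier()

# ... and a single rank (assumed rank 0 here) consolidates the shards into
# args.embedding_path and removes the temporary directory.
if mpu.get_data_parallel_rank() == 0:
    store.merge_shards_and_save()
```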
megatron/indexer.py  +13 −8

+import sys
 import torch
 import torch.distributed as dist

@@ -5,10 +6,11 @@ from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import load_ict_checkpoint
 from megatron.data.ict_dataset import get_ict_dataset
-from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
-from megatron.data.realm_index import detach, BlockData
-from megatron.data.realm_dataset_utils import get_ict_batch
-from megatron.model.realm_model import general_ict_model_provider
+from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
+from megatron.data.realm_index import detach, OpenRetreivalDataStore
+from megatron.data.biencoder_dataset_utils import get_ict_batch
+from megatron.model.biencoder_model import biencoder_model_provider
+#from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model

@@ -34,12 +36,15 @@ class IndexBuilder(object):
     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
-        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        model = get_model(lambda: biencoder_model_provider(only_context_model=True))
+        self.model = load_ict_checkpoint(model, only_context_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        sys.exit()
         self.model.eval()
         self.dataset = get_ict_dataset()
         self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
-        self.block_data = BlockData(load_from_path=False)
+        self.block_data = OpenRetreivalDataStore(load_from_path=False)
+        print("load_attributes is done", flush=True)
+        sys.exit()

     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""