examples/evaluate_ict_zeroshot_nq.sh (new file, mode 100644) +36 −0

#!/bin/bash

# Evaluate natural question test data given Wikipedia embeddings and pretrained
# ICT model

# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py

EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

QA_FILE=<Path of the natural question test dataset>

python tasks/main.py \
    --task ICT-ZEROSHOT-NQ \
    --tokenizer-type BertWordPieceLowerCase \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --retriever-seq-length 256 \
    --vocab-file bert-vocab.txt \
    --qa-data-test ${QA_FILE} \
    --num-workers 2 \
    --faiss-use-gpu \
    --retriever-report-topk-accuracies 1 5 20 100 \
    --fp16
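The flag --retriever-report-topk-accuracies 1 5 20 100 asks the evaluation to report top-k retrieval accuracy at several cutoffs: the fraction of test questions for which at least one of the top k retrieved blocks contains an answer string. A minimal sketch of that metric under this convention; topk_accuracies and hit_ranks are hypothetical names for illustration, not part of the repo:

    import numpy as np

    def topk_accuracies(hit_ranks, topk_values=(1, 5, 20, 100)):
        # hit_ranks[i] is the 0-based rank of the first retrieved block that
        # contains an answer for question i, or np.inf if none of the
        # retrieved blocks matched
        hit_ranks = np.asarray(hit_ranks, dtype=float)
        return {k: float(np.mean(hit_ranks < k)) for k in topk_values}

    # e.g. 4 questions whose first answer-bearing blocks sit at ranks 0, 3, 42, none
    print(topk_accuracies([0, 3, 42, np.inf]))
    # {1: 0.25, 5: 0.5, 20: 0.5, 100: 0.75}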
megatron/arguments.py +0 −2

@@ -712,8 +712,6 @@ def _add_biencoder_args(parser):
                             'square root of hidden size')

     # faiss index
-    group.add_argument('--faiss-use-gpu', action='store_true',
-                       help='Whether create the FaissMIPSIndex on GPU')
     group.add_argument('--block-data-path', type=str, default=None,
                        help='Where to save/load BlockData to/from')
     group.add_argument('--embedding-path', type=str, default=None,
megatron/data/biencoder_dataset_utils.py +0 −3

@@ -24,11 +24,8 @@
 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     """Specifically one epoch to be used in an indexing job."""
     args = get_args()
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
     if micro_batch_size is None:
         micro_batch_size = args.micro_batch_size
-    global_batch_size = micro_batch_size * world_size
     num_workers = args.num_workers

     # Use megatron's sampler with consumed samples set to 0 as
megatron/data/realm_index.py +57 −52

@@ -116,18 +116,22 @@ class OpenRetreivalDataStore(object):

 class FaissMIPSIndex(object):
-    """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
-    def __init__(self, embed_size, block_data=None, use_gpu=False):
+    """
+    Wrapper object for a BlockData which similarity search via FAISS
+    under the hood
+    """
+    def __init__(self, embed_size, embed_data=None, use_gpu=False):
         self.embed_size = embed_size
-        self.block_data = block_data
+        self.embed_data = embed_data
         self.use_gpu = use_gpu
-        self.id_map = dict()

-        self.block_mips_index = None
-        self._set_block_index()
+        self.mips_index = None
+        self._set_mips_index()

-    def _set_block_index(self):
-        """Create a Faiss Flat index with inner product as the metric to search against"""
+    def _set_mips_index(self):
+        """
+        Create a Faiss Flat index with inner product as the metric
+        to search against
+        """
         try:
             import faiss
         except ImportError:

@@ -135,85 +139,86 @@ class FaissMIPSIndex(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Building index", flush=True)
-        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+
+        cpu_index = faiss.IndexFlatIP(self.embed_size)

         if self.use_gpu:
             # create resources and config for GpuIndex
-            res = faiss.StandardGpuResources()
-            config = faiss.GpuIndexFlatConfig()
-            config.device = torch.cuda.current_device()
+            config = faiss.GpuMultipleClonerOptions()
+            config.shard = True
             config.useFloat16 = True
-            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
+            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
+            self.mips_index = faiss.IndexIDMap(gpu_index)

             if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
-                print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
+                print(">> Initialized index on GPU", flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
-            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
+            self.mips_index = faiss.IndexIDMap(cpu_index)

             if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                 print(">> Initialized index on CPU", flush=True)

-        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
-        if self.block_data is not None:
-            self.add_block_embed_data(self.block_data)
+        # if we were constructed with a BlockData, then automatically load it
+        # when the FAISS structure is built
+        if self.embed_data is not None:
+            self.add_embed_data(self.embed_data)

     def reset_index(self):
         """Delete existing index and create a new"""
-        del self.block_mips_index
+        del self.mips_index

-        # reset the block data so that _set_block_index will reload it as well
-        if self.block_data is not None:
-            block_data_path = self.block_data.block_data_path
-            del self.block_data
-            self.block_data = BlockData(block_data_path)
+        # reset the block data so that _set_mips_index will reload it as well
+        if self.embed_data is not None:
+            embed_data_path = self.embed_data.embedding_path
+            del self.embed_data
+            self.embed_data = OpenRetreivalDataStore(embed_data_path)

-        self._set_block_index()
+        self._set_mips_index()
+
+    def update_index(self):
+        """Delete existing index and create a new"""
+        del self.mips_index
+
+        # reset the block data so that _set_mips_index will reload it as well
+        if self.embed_data is not None:
+            self.embed_data.load_from_file()
+
+        self._set_mips_index()

-    def add_block_embed_data(self, all_block_data):
+    def add_embed_data(self, all_embed_data):
         """Add the embedding of each block to the underlying FAISS index"""

         # this assumes the embed_data is a dict : {int: np.array<float>}
-        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
-
-        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
-        block_embeds_arr = np.float32(np.array(block_embeds))
-        block_indices_arr = np.array(block_indices)
+        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

-        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
-        if self.use_gpu:
-            for i, idx in enumerate(block_indices):
-                self.id_map[i] = idx
+        # the embeddings have to be entered in as float32 even though the math
+        # internally is done with float16.
+        embeds_arr = np.float32(np.array(block_embeds))
+        indices_arr = np.array(block_indices)

         # we no longer need the embedding data since it's in the index now
-        all_block_data.clear()
+        all_embed_data.clear()

-        if self.use_gpu:
-            self.block_mips_index.add(block_embeds_arr)
-        else:
-            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
+        self.mips_index.add_with_ids(embeds_arr, indices_arr)

         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
-        """Get the top-k blocks by the index distance metric.
-        :param reconstruct: if True: return a [num_queries x k x embed_dim]
-                                array of blocks
-                            if False: return [num_queries x k] array of
-                                distances, and another for indices
+        """
+        Get the top-k blocks by the index distance metric.
+
+        :param reconstruct: if True: return a [num_queries x k x embed_dim]
+                                array of blocks
+                            if False: return [num_queries x k] array of
+                                distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))

         if reconstruct:
             # get the vectors themselves
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            top_k_block_embeds = self.mips_index.search_and_reconstruct(\
+                query_embeds, top_k)
             return top_k_block_embeds
         else:
             # get distances and indices of closest vectors
-            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
-            if self.use_gpu:
-                fresh_indices = np.zeros(block_indices.shape)
-                for i, j in itertools.product(block_indices.shape):
-                    fresh_indices[i, j] = self.id_map[block_indices[i, j]]
-                block_indices = fresh_indices
+            distances, block_indices = self.mips_index.search(query_embeds, top_k)
             return distances, block_indices
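The substantive change above replaces a single-GPU faiss.GpuIndexFlat (which cannot carry caller-supplied ids, hence the old id_map bookkeeping) with a flat inner-product index cloned and sharded across all visible GPUs, then wrapped in IndexIDMap so FAISS tracks block ids directly. A minimal standalone sketch of that pattern, assuming faiss-gpu is installed and at least one GPU is visible; sizes and data are illustrative:

    import numpy as np
    import faiss

    embed_size, num_blocks, top_k = 128, 10000, 5

    # flat inner-product (MIPS) index built on the CPU
    cpu_index = faiss.IndexFlatIP(embed_size)

    # shard the index across all visible GPUs, storing vectors as float16
    co = faiss.GpuMultipleClonerOptions()
    co.shard = True
    co.useFloat16 = True
    gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=co)

    # IndexIDMap preserves caller-supplied integer ids, so no manual id_map
    mips_index = faiss.IndexIDMap(gpu_index)

    # embeddings cross the API boundary as float32 even with useFloat16
    embeds = np.float32(np.random.randn(num_blocks, embed_size))
    ids = np.arange(num_blocks, dtype=np.int64)
    mips_index.add_with_ids(embeds, ids)

    queries = np.float32(np.random.randn(4, embed_size))
    distances, block_ids = mips_index.search(queries, top_k)  # each [4, top_k]

One consequence worth noting: because IndexIDMap now handles id bookkeeping on both the CPU and GPU paths, the GPU-only remapping loop in the old search_mips_index (which also iterated itertools.product over a shape tuple instead of over index pairs, a latent bug) could be deleted outright.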
tasks/main.py +16 −0

@@ -47,6 +47,20 @@ def get_tasks_args(parser):
                        help='Sliding window for overlapping evaluation.')
     group.add_argument('--strict-lambada', action='store_true',
                        help='Use more difficult formulation of lambada.')
+    # Retriever args
+    group.add_argument('--qa-data-dev', type=str, default=None,
+                       help='Path to the QA dataset dev file.')
+    group.add_argument('--qa-data-test', type=str, default=None,
+                       help='Path to the QA dataset test file.')
+
+    # Faiss arguments for retriever
+    group.add_argument('--faiss-use-gpu', action='store_true',
+                       help='Whether to create the FaissMIPSIndex on GPU')
+    group.add_argument('--faiss-match', type=str, default='string',
+                       choices=['regex', 'string'],
+                       help='Answer matching logic type')
+    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
+                       help='Number of blocks to use as top-k during retrieval')

     return parser

@@ -62,6 +76,8 @@ if __name__ == '__main__':
         from glue.finetune import main
     elif args.task in ['LAMBADA', 'WIKITEXT103']:
         from zeroshot_gpt.evaluate import main
+    elif args.task in ['ICT-ZEROSHOT-NQ']:
+        from orqa.evaluate_orqa import main
     else:
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))
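Since the retriever flags are plain argparse arguments, their behavior can be sanity-checked in isolation. A minimal sketch of an equivalent standalone parser (not the repo's actual argument plumbing, just the same add_argument calls; the test file name is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--qa-data-test', type=str, default=None,
                        help='Path to the QA dataset test file.')
    parser.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether to create the FaissMIPSIndex on GPU')
    parser.add_argument('--faiss-match', type=str, default='string',
                        choices=['regex', 'string'],
                        help='Answer matching logic type')
    parser.add_argument('--faiss-topk-retrievals', type=int, default=100,
                        help='Number of blocks to use as top-k during retrieval')

    args = parser.parse_args(['--qa-data-test', 'nq-test.csv', '--faiss-use-gpu'])
    assert args.faiss_use_gpu and args.faiss_match == 'string'
    assert args.faiss_topk_retrievals == 100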