Commit a6e00d97 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_retriver_merge_ict_eval' into 'main'

ICT zeroshot evaluation

See merge request ADLR/megatron-lm!248
parents c5346794 fcfd0949
Loading
Loading
Loading
Loading
+36 −0
Original line number Diff line number Diff line
#!/bin/bash

# Evaluate natural question test data given Wikipedia embeddings and pretrained
# ICT model

# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py

EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

QA_FILE=<Path of the natural question test dataset>

python tasks/main.py \
    --task ICT-ZEROSHOT-NQ \
    --tokenizer-type BertWordPieceLowerCase \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --retriever-seq-length 256 \
    --vocab-file  bert-vocab.txt\
    --qa-data-test ${QA_FILE} \
    --num-workers 2 \
    --faiss-use-gpu \
    --retriever-report-topk-accuracies 1 5 20 100 \
    --fp16
+0 −2
Original line number Diff line number Diff line
@@ -712,8 +712,6 @@ def _add_biencoder_args(parser):
                        'square root of hidden size')

    # faiss index
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether create the FaissMIPSIndex on GPU')
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
+0 −3
Original line number Diff line number Diff line
@@ -24,11 +24,8 @@ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    # Use megatron's sampler with consumed samples set to 0 as
+57 −52
Original line number Diff line number Diff line
@@ -116,18 +116,22 @@ class OpenRetreivalDataStore(object):


class FaissMIPSIndex(object):
    """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
    def __init__(self, embed_size, block_data=None, use_gpu=False):
    """
    Wrapper object for a BlockData which similarity search via FAISS under the hood
    """
    def __init__(self, embed_size, embed_data=None, use_gpu=False):
        self.embed_size = embed_size
        self.block_data = block_data
        self.embed_data = embed_data
        self.use_gpu = use_gpu
        self.id_map = dict()

        self.block_mips_index = None
        self._set_block_index()
        self.mips_index = None
        self._set_mips_index()

    def _set_block_index(self):
        """Create a Faiss Flat index with inner product as the metric to search against"""
    def _set_mips_index(self):
        """
        Create a Faiss Flat index with inner product as the metric
        to search against
        """
        try:
            import faiss
        except ImportError:
@@ -135,85 +139,86 @@ class FaissMIPSIndex(object):

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

        cpu_index = faiss.IndexFlatIP(self.embed_size)

        if self.use_gpu:
            # create resources and config for GpuIndex
            res = faiss.StandardGpuResources()
            config = faiss.GpuIndexFlatConfig()
            config.device = torch.cuda.current_device()
            config = faiss.GpuMultipleClonerOptions()
            config.shard = True
            config.useFloat16 = True

            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
            self.mips_index = faiss.IndexIDMap(gpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
                print(">> Initialized index on GPU", flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
            self.mips_index = faiss.IndexIDMap(cpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on CPU", flush=True)

        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
        if self.block_data is not None:
            self.add_block_embed_data(self.block_data)
        # if we were constructed with a BlockData, then automatically load it
        # when the FAISS structure is built
        if self.embed_data is not None:
            self.add_embed_data(self.embed_data)

    def reset_index(self):
        """Delete existing index and create a new"""
        del self.block_mips_index
        del self.mips_index

        # reset the block data so that _set_block_index will reload it as well
        if self.block_data is not None:
            block_data_path = self.block_data.block_data_path
            del self.block_data
            self.block_data = BlockData(block_data_path)
        if self.embed_data is not None:
            embed_data_path = self.embed_data.embedding_path
            del self.embed_data
            self.embed_data = OpenRetreivalDataStore(embed_data_path)

        self._set_mips_index()

    def update_index(self):
        """Delete existing index and create a new"""
        del self.mips_index

        self._set_block_index()
        # reset the block data so that _set_mips_index will reload it as well
        if self.embed_data is not None:
            self.embed_data.load_from_file()
        self._set_mips_index()

    def add_block_embed_data(self, all_block_data):
    def add_embed_data(self, all_embed_data):
        """Add the embedding of each block to the underlying FAISS index"""

        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())

        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
        block_embeds_arr = np.float32(np.array(block_embeds))
        block_indices_arr = np.array(block_indices)
        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx
        # the embeddings have to be entered in as float32 even though the math
        # internally is done with float16.
        embeds_arr = np.float32(np.array(block_embeds))
        indices_arr = np.array(block_indices)

        # we no longer need the embedding data since it's in the index now
        all_block_data.clear()
        all_embed_data.clear()

        if self.use_gpu:
            self.block_mips_index.add(block_embeds_arr)
        else:
            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
        self.mips_index.add_with_ids(embeds_arr, indices_arr)

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.
        """
        Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of distances, and another for indices
        :param reconstruct: if True: return a [num_queries x k x embed_dim]
                                array of blocks
                            if False: return [num_queries x k] array of
                                distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        if reconstruct:
            # get the vectors themselves
            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
            top_k_block_embeds = self.mips_index.search_and_reconstruct(\
                query_embeds, top_k)
            return top_k_block_embeds

        else:
            # get distances and indices of closest vectors
            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
            if self.use_gpu:
                fresh_indices = np.zeros(block_indices.shape)
                for i, j in itertools.product(block_indices.shape):
                    fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                block_indices = fresh_indices
            distances, block_indices = self.mips_index.search(query_embeds, top_k)
            return distances, block_indices
+16 −0
Original line number Diff line number Diff line
@@ -47,6 +47,20 @@ def get_tasks_args(parser):
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')
    # Retriever args
    group.add_argument('--qa-data-dev', type=str, default=None,
                       help='Path to the QA dataset dev file.')
    group.add_argument('--qa-data-test', type=str, default=None,
                       help='Path to the QA dataset test file.')

    # Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether create the FaissMIPSIndex on GPU')
    group.add_argument('--faiss-match', type=str, default='string', \
                        choices=['regex', 'string'], help="Answer matching '\
                        'logic type")
    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                       help='Number of blocks to use as top-k during retrieval')

    return parser

@@ -62,6 +76,8 @@ if __name__ == '__main__':
        from glue.finetune import main
    elif args.task in ['LAMBADA', 'WIKITEXT103']:
        from zeroshot_gpt.evaluate import main
    elif args.task in ['ICT-ZEROSHOT-NQ']:
        from orqa.evaluate_orqa import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))
Loading