Commit eaa5d877 authored by Neel Kant's avatar Neel Kant
Browse files

Add additional assertion on Indexer to test correctness, and limit verbosity in other classes

parent c2a32e12
Loading
Loading
Loading
Loading
+14 −7
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import numpy as np
import torch

from megatron import get_args
from megatron import mpu


def detach(tensor):
@@ -47,8 +48,10 @@ class BlockData(object):
    def load_from_file(self):
        """Populate members from instance saved to file"""

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Unpickling BlockData", flush=True)
        state_dict = pickle.load(open(self.block_data_path, 'rb'))
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Finished unpickling BlockData\n", flush=True)

        self.embed_data = state_dict['embed_data']
@@ -127,6 +130,7 @@ class FaissMIPSIndex(object):
        except ImportError:
            raise Exception("Error: Please install faiss to use FaissMIPSIndex")

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

@@ -138,10 +142,12 @@ class FaissMIPSIndex(object):
            config.useFloat16 = True

            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on CPU", flush=True)

        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
@@ -156,7 +162,7 @@ class FaissMIPSIndex(object):
        if self.block_data is not None:
            block_data_path = self.block_data.block_data_path
            del self.block_data
            self.block_data = BlockData.load_from_file(block_data_path)
            self.block_data = BlockData(block_data_path)

        self._set_block_index()

@@ -183,6 +189,7 @@ class FaissMIPSIndex(object):
        else:
            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
+5 −4
Original line number Diff line number Diff line
@@ -37,7 +37,8 @@ class IndexBuilder(object):
        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size))
        self.dataset = get_ict_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
        self.block_data = BlockData(load_from_path=False)

    def track_and_report_progress(self, batch_size):
@@ -58,7 +59,7 @@ class IndexBuilder(object):
            try:
                # batch also has query_tokens and query_pad_data
                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
            except StopIteration:
            except (StopIteration, IndexError):
                break

            unwrapped_model = self.model
@@ -85,6 +86,6 @@ class IndexBuilder(object):
        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.block_data.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.block_data.embed_data) == len(self.dataset)
        self.block_data.clear()

+1 −0
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@ from .data import broadcast_data

from .grads import clip_grad_norm

from .initialize import is_unitialized
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
+5 −0
Original line number Diff line number Diff line
@@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None
_MPU_RANK = None


def is_unitialized():
    """Report whether mpu (model-parallel utilities) has NOT been initialized yet.

    Lets code paths that may run either before or after
    ``initialize_model_parallel`` (e.g. standalone index building) check
    safely before calling rank/group accessors.

    NOTE(review): the name is misspelled ("unitialized" for "uninitialized")
    but is exported through the package ``__init__`` and used by callers, so
    it must be kept as-is for compatibility.

    Returns:
        bool: True when no data-parallel group has been created yet.
    """
    # The data-parallel group is only assigned inside initialize_model_parallel,
    # so its absence is the marker for "mpu not set up".
    uninitialized = _DATA_PARALLEL_GROUP is None
    return uninitialized


def initialize_model_parallel(model_parallel_size_):
    """
    Initialize model data parallel groups.