Commit 612f438a authored by Mostofa Patwary

evaluation code ongoing

parent ebc95c35
+5 −3
@@ -371,7 +371,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     return iteration
 
 
-def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+def load_ict_checkpoint(model, only_query_model=False, only_context_model=False, from_realm_chkpt=False):
     """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
 
     args = get_args()
@@ -393,14 +393,16 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
 
     state_dict = torch.load(checkpoint_name, map_location='cpu')
     ict_state_dict = state_dict['model']
+    print(ict_state_dict)
+    sys.exit()
     if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
         print(" loading ICT state dict from REALM", flush=True)
         ict_state_dict = ict_state_dict['retriever']['ict_model']
 
     if only_query_model:
         ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
+    if only_context_model:
+        ict_state_dict.pop('query_model')
 
     model.load_state_dict(ict_state_dict)
     torch.distributed.barrier()
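
Note: a minimal usage sketch (not part of this commit) of the selective-load path above, for the opposite case of a retrieval job that only needs the query encoder; the only_query_model flag on biencoder_model_provider is an assumption here:

    from megatron.checkpointing import load_ict_checkpoint
    from megatron.model.biencoder_model import biencoder_model_provider
    from megatron.training import get_model

    # Build only the query tower; load_ict_checkpoint pops the context tower's
    # weights ('context_model') from the checkpoint state dict before loading.
    model = get_model(lambda: biencoder_model_provider(only_query_model=True))
    model = load_ict_checkpoint(model, only_query_model=True, from_realm_chkpt=False)
    model.eval()
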
+27 −0
@@ -9,6 +9,33 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
 
 
+def get_one_epoch_dataloader(dataset, micro_batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if micro_batch_size is None:
+        micro_batch_size = args.micro_batch_size
+    global_batch_size = micro_batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # importantly, drop_last must be False to get all the data.
+    assert False, 'DistributedBatchSampler deprecated, change the implementation'
+    from megatron.data.samplers import DistributedBatchSampler
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=False,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
 def get_ict_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_mask',
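
Note: get_one_epoch_dataloader above currently asserts because DistributedBatchSampler is deprecated. One possible replacement, sketched here with torch's built-in DistributedSampler (an assumption, not the commit's stated plan): it partitions the dataset across data-parallel ranks and pads the last round by repeating samples rather than dropping any, so every block still gets embedded, but the shard-merge size assertion in the data store may then need to tolerate duplicates.

    from torch.utils.data import BatchSampler, DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def get_one_epoch_dataloader_alt(dataset, micro_batch_size, num_workers,
                                     rank, world_size):
        # Hypothetical replacement: shard the dataset across ranks, no shuffling.
        sampler = DistributedSampler(dataset, num_replicas=world_size,
                                     rank=rank, shuffle=False)
        # Each rank batches its own shard; drop_last=False keeps all the data.
        batch_sampler = BatchSampler(sampler, batch_size=micro_batch_size,
                                     drop_last=False)
        return DataLoader(dataset, batch_sampler=batch_sampler,
                          num_workers=num_workers, pin_memory=True)
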
+23 −20
@@ -14,28 +14,29 @@ def detach(tensor):
     return tensor.detach().cpu().numpy()
 
 
-class BlockData(object):
-    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
+class OpenRetreivalDataStore(object):
+    """Serializable data structure for holding data for blocks -- embeddings 
+    and necessary metadata for Retriever"""
+    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
-        self.meta_data = dict()
-        if block_data_path is None:
+        #self.meta_data = dict()
+        if embedding_path is None:
             args = get_args()
-            block_data_path = args.block_data_path
+            embedding_path = args.embedding_path
             rank = args.rank
-        self.block_data_path = block_data_path
+        self.embedding_path = embedding_path
         self.rank = rank
 
         if load_from_path:
             self.load_from_file()
 
-        block_data_name = os.path.splitext(self.block_data_path)[0]
+        block_data_name = os.path.splitext(self.embedding_path)[0]
         self.temp_dir_name = block_data_name + '_tmp'
 
     def state(self):
         return {
             'embed_data': self.embed_data,
-            'meta_data': self.meta_data,
+            #'meta_data': self.meta_data,
         }
 
     def clear(self):
@@ -50,26 +51,28 @@ class BlockData(object):
 
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(self.block_data_path, 'rb'))
+        state_dict = pickle.load(open(self.embedding_path, 'rb'))
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)
 
         self.embed_data = state_dict['embed_data']
-        self.meta_data = state_dict['meta_data']
+        #self.meta_data = state_dict['meta_data']
 
-    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+    #def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
         """Add data for set of blocks
-        :param block_indices: 1D array of unique int ids for the blocks
+        :param row_id: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
-        :param block_metas: 2D array of metadata for the blocks.
+        #:param block_metas: 2D array of metadata for the blocks.
             In the case of REALM this will be [start_idx, end_idx, doc_idx]
         """
-        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        #for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        for idx, embed in zip(row_id, block_embeds):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")
 
             self.embed_data[idx] = np.float16(embed)
-            self.meta_data[idx] = meta
+            #self.meta_data[idx] = meta
 
     def save_shard(self):
         """Save the block data that was created this in this process"""
@@ -77,8 +80,8 @@ class BlockData(object):
             os.makedirs(self.temp_dir_name, exist_ok=True)
 
         # save the data for each shard
-        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
-            pickle.dump(self.state(), data_file)
+        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as writer:
+            pickle.dump(self.state(), writer)
 
     def merge_shards_and_save(self):
         """Combine all the shards made using self.save_shard()"""
@@ -98,13 +101,13 @@ class BlockData(object):
 
                 # add the shard's data and check to make sure there is no overlap
                 self.embed_data.update(data['embed_data'])
-                self.meta_data.update(data['meta_data'])
+                #self.meta_data.update(data['meta_data'])
                 assert len(self.embed_data) == old_size + shard_size
 
         assert seen_own_shard
 
         # save the consolidated shards and remove temporary directory
-        with open(self.block_data_path, 'wb') as final_file:
+        with open(self.embedding_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)
         shutil.rmtree(self.temp_dir_name, ignore_errors=True)
 
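
Note: a rough usage sketch of the renamed store, based only on the methods visible above and assuming it runs inside an initialized Megatron job (so args.embedding_path is set); the ids and shapes are illustrative:

    import numpy as np
    from megatron.data.realm_index import OpenRetreivalDataStore

    store = OpenRetreivalDataStore(load_from_path=False)

    # Each data-parallel rank adds the embeddings it computed; values are
    # stored as float16, keyed by row id.
    row_id = np.array([0, 1, 2])
    block_embeds = np.random.randn(3, 128).astype(np.float32)
    store.add_block_data(row_id, block_embeds)

    # Each rank writes its own shard into the temporary directory ...
    store.save_shard()

    # ... and, after a barrier, a single rank consolidates the shards into the
    # pickle at args.embedding_path and removes the temporary directory.
    store.merge_shards_and_save()
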
+13 −8
+import sys
 import torch
 import torch.distributed as dist
 
@@ -5,10 +6,11 @@ from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import load_ict_checkpoint
 from megatron.data.ict_dataset import get_ict_dataset
-from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
-from megatron.data.realm_index import detach, BlockData
-from megatron.data.realm_dataset_utils import get_ict_batch
-from megatron.model.realm_model import general_ict_model_provider
+from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
+from megatron.data.realm_index import detach, OpenRetreivalDataStore
+from megatron.data.biencoder_dataset_utils import get_ict_batch
+from megatron.model.biencoder_model import biencoder_model_provider
+#from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model
 
 
@@ -34,12 +36,15 @@ class IndexBuilder(object):
 
     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
-        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        model = get_model(lambda: biencoder_model_provider(only_context_model=True))
+        self.model = load_ict_checkpoint(model, only_context_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        sys.exit()
         self.model.eval()
         self.dataset = get_ict_dataset()
         self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
-        self.block_data = BlockData(load_from_path=False)
+        self.block_data = OpenRetreivalDataStore(load_from_path=False)
+        print("load_attributes is done", flush=True)
+        sys.exit()
 
     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""