Commit 0295bb89 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

WIP: main_retriver_merge

parent 17d897e0
Loading
Loading
Loading
Loading
+202 −0
Original line number Diff line number Diff line
import os
import time

import numpy as np
import torch

from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_args, get_tokenizer, print_rank_0, mpu


def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    assert False, 'DistributedBatchSampler deprecated, change the implementation'
    from megatron.data.samplers import DistributedBatchSampler
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_mask',
            'context_tokens', 'context_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_mask = data_b['query_mask'] < 0.5
    context_tokens = data_b['context_tokens'].long()
    context_mask = data_b['context_mask'] < 0.5
    block_indices = data_b['block_data'].long()

    return query_tokens, query_mask,\
           context_tokens, context_mask, block_indices


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result


class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx


class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

        # compile/bind the C++ helper code
        from megatron.data.dataset_utils import compile_helper
        compile_helper()

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)


        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
            time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
+291 −0
Original line number Diff line number Diff line
import os
import torch
import sys

from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.module import MegatronModule
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal


def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             shared_query_context_model=False):
    """Build the model."""
    args = get_args()

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # simpler to just keep using 2 tokentypes since 
    # the LM we initialize with has 2 tokentypes
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        shared_query_context_model=shared_query_context_model)

    return model


class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model."""

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 shared_query_context_model=False):
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

        self.shared_query_context_model = shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.projection_dim = args.projection_dim

        if self.shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # this model embeds (pseudo-)queries - Embed_input in the paper
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # this model embeds evidence blocks - Embed_doc in the paper
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and 
        return the respective embeddings."""

        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")
        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed block without the block model.")
        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model"""
        logits = model(tokens,
                              attention_mask,
                              token_types)
        return logits

    def state_dict_for_save_checkpoint(self, destination=None, \
        prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(destination,
                                                          prefix,
                                                          keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key], \
                strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict( \
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict( \
                    state_dict[self._context_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model 
        on iteration zero of ICT pretraining"""
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        #for param in self.query_model.language_model.parameters():
        #    print(param.data)
            #break
            #sys.exit()

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except BaseException:
            raise ValueError("Could not load BERT checkpoint")

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']

        if self.shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
        else:
            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # give each model the same ict_head to begin with as well
                if self.projection_dim > 0:
                    query_proj_state_dict = \
                        self.state_dict_for_save_checkpoint()\
                        [self._query_key]['projection_enc']
            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                if self.query_model is not None and self.projection_dim > 0:
                    self.context_model.projection_enc.load_state_dict\
                        (query_proj_state_dict)
        #for param in self.query_model.language_model.parameters():
        #    print(param.data)
        #    #sys.exit()



class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for 
    learned information retrieval."""

    def __init__(self, num_tokentypes=2, 
            parallel_output=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.pool_type = args.pool_type
        self.projection_dim = args.projection_dim
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        if args.projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = attention_mask.unsqueeze(1)
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)


        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)
        # This mask will be used in average-pooling and max-pooling
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)
        
         # Taking the representation of the [CLS] token of BERT
        if self.pool_type == "cls-token":
            pooled_output = lm_output[:, 0, :]
        elif self.pool_type == "avg":    # Average Pooling
            pooled_output = lm_output.masked_fill(pool_mask, 0)
            pooled_output = pooled_output.sum(1) / (pool_mask.size(1) \
                - pool_mask.float().sum(1))
        elif self.pool_type == "max":    # Max-Pooling
            pooled_output = lm_output.masked_fill(pool_mask, -1000)
            pooled_output = torch.max(pooled_output, 1)[0]

        # Converting to float16 dtype
        pooled_output = pooled_output.to(lm_output.dtype)
        
        # Output.
        if self.projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)

        if self.projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading BERT weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)