Commit 6d03d7af authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

DPR finetune and evaluation

parent d2d5086e
Loading
Loading
Loading
Loading
+3 −5
Original line number Diff line number Diff line
@@ -413,11 +413,9 @@ def load_biencoder_checkpoint(model, only_query_model=False,
    if only_context_model:
        ret_state_dict.pop('query_model')

    #print_rank_0(len(model))
    #sys.exit()
    #assert len(model) == 1
    #model[0].load_state_dict(ret_state_dict)
    model.load_state_dict(ret_state_dict)
    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)

    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
+12 −27
Original line number Diff line number Diff line
@@ -45,26 +45,25 @@ class IndexBuilder(object):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        args = get_args()
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(lambda: biencoder_model_provider(only_context_model \
            = only_context_model, biencoder_shared_query_context_model = \
            self.biencoder_shared_query_context_model, \
            pre_process=self.pre_process, post_process=self.post_process))
        args.only_context_model = only_context_model
        args.only_query_model = False

        model = get_model(biencoder_model_provider)

        #model = biencoder_model_provider(only_context_model \
        #model = get_model(lambda: biencoder_model_provider(only_context_model \
        #    = only_context_model, biencoder_shared_query_context_model = \
        #    self.biencoder_shared_query_context_model, \
        #    pre_process=self.pre_process, post_process=self.post_process)
        #    self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)

        #assert len(self.model) == 1
        #self.model[0].eval()
        self.model.eval()
        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
@@ -92,12 +91,11 @@ class IndexBuilder(object):
        distributed setting will be consolidated by the rank 0 process
        and saved as a final pickled BlockData.
        """
        #assert len(self.model) == 1
        #unwrapped_model = self.model[0]
        unwrapped_model = self.model
        assert len(self.model) == 1
        unwrapped_model = self.model[0]

        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module
            print_rank_0("hasattr")

        while True:
            try:
@@ -108,17 +106,6 @@ class IndexBuilder(object):
            except (StopIteration, IndexError):
                break

            print_rank_0(context_tokens)
            print_rank_0(context_mask)
            print_rank_0(context_types)
            #if torch.cuda.is_available():
            #    print_rank_0("cuda available")
            #print_rank_0(torch.cuda.current_device())
            #print_rank_0(torch.cuda.get_device_name())
            print_rank_0(next(unwrapped_model.parameters()).device)
            print_rank_0(next(unwrapped_model.context_model.parameters()).device)
            #print_rank_0("After get_open_retrieval_batch")

            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
@@ -126,8 +113,6 @@ class IndexBuilder(object):
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)

            sys.exit()

            context_logits = detach(context_logits)
            row_id = detach(row_id)

+11 −9
Original line number Diff line number Diff line
@@ -15,14 +15,21 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule

def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True, 
#def biencoder_model_provider(only_query_model=False,
#                             only_context_model=False,
#                             biencoder_shared_query_context_model=False,
#                             pre_process=True,
#                             post_process=True):

def biencoder_model_provider(pre_process=True, 
                             post_process=True):
    """Build the model."""
    args = get_args()

    biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
    only_context_model = args.only_context_model
    only_query_model = args.only_query_model

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"
@@ -266,11 +273,6 @@ class PretrainedBertModel(MegatronModule):
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        print_rank_0(input_ids.device)
        print_rank_0(position_ids.device)
        print_rank_0(extended_attention_mask.device)
        print_rank_0(tokentype_ids.device)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
+0 −5
Original line number Diff line number Diff line
@@ -338,11 +338,6 @@ class TransformerLanguageModel(MegatronModule):
                get_key_value=False, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):

        print_rank_0("before self.embedding")
        print_rank_0(enc_input_ids.device)
        print_rank_0(enc_position_ids.device)
        print_rank_0(tokentype_ids.device)

        # Embeddings.
        if self.pre_process:
            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
+9 −5
Original line number Diff line number Diff line
@@ -33,11 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group

def pretrain_ict_model_provider():
    args = get_args()
    model = biencoder_model_provider(
                only_context_model=False,
                only_query_model=False,
                biencoder_shared_query_context_model=\
                    args.biencoder_shared_query_context_model)
    args.only_context_model = False
    args.only_query_model = False
    model = biencoder_model_provider()
 
    #model = biencoder_model_provider(
    #            only_context_model=False,
    #            only_query_model=False,
    #            biencoder_shared_query_context_model=\
    #                args.biencoder_shared_query_context_model)
    return model

def get_group_world_size_rank():
Loading