Commit f64977fd authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

evaluation works!

parent 7e335e15
Loading
Loading
Loading
Loading
+10 −7
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model


@@ -50,16 +50,19 @@ class IndexBuilder(object):
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        args.only_context_model = only_context_model
        args.only_query_model = False
        #args.only_context_model = only_context_model
        #args.only_query_model = False

        #model = get_model(biencoder_model_provider)

        model = get_model(get_model_provider(only_context_model=only_context_model, 
            biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))

        #model = get_model(lambda: biencoder_model_provider(only_context_model \
        #model = get_model(lambda: biencoder_model_provider(only_context_model \
        model = get_model(biencoder_model_provider(only_context_model \
            = only_context_model, biencoder_shared_query_context_model = \
            self.biencoder_shared_query_context_model,
            pre_process=True, post_process=True))
        #    = only_context_model, biencoder_shared_query_context_model = \
        #    self.biencoder_shared_query_context_model,
        #    pre_process=True, post_process=True)

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)
+19 −0
Original line number Diff line number Diff line
@@ -15,6 +15,25 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule

def get_model_provider(only_query_model=False, only_context_model=False, 
        biencoder_shared_query_context_model=False):

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""

        print_rank_0('building Bienoder model ...')
        model = biencoder_model_provider(only_query_model=only_query_model, 
                only_context_model = only_context_model, 
                biencoder_shared_query_context_model = \
                biencoder_shared_query_context_model, 
                pre_process=True, post_process=True)

        return model

    return model_provider



#def biencoder_model_provider(pre_process=True, 
#                             post_process=True):
 
+1 −1
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ import sys
#    )
#)

from megatron import get_args
from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator

+9 −5
Original line number Diff line number Diff line
@@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model

class ORQAEvaluator(object):
@@ -47,11 +47,15 @@ class ORQAEvaluator(object):
        #args.only_query_model = only_query_model
        #args.only_context_model = False

        model = get_model(get_model_provider(only_query_model=only_query_model, 
            biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))


        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
        model = get_model(lambda: biencoder_model_provider(only_query_model=\
            only_query_model, biencoder_shared_query_context_model=\
            args.biencoder_shared_query_context_model,
            pre_process=True, post_process=True))
        #    only_query_model, biencoder_shared_query_context_model=\
        #    args.biencoder_shared_query_context_model,
        #    pre_process=True, post_process=True))

        #model = get_model(biencoder_model_provider)