Loading megatron/indexer.py +10 −7 Original line number Diff line number Diff line Loading @@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader from megatron.data.realm_index import detach, OpenRetreivalDataStore from megatron.model.biencoder_model import biencoder_model_provider from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model Loading Loading @@ -50,16 +50,19 @@ class IndexBuilder(object): if self.biencoder_shared_query_context_model: only_context_model = False args.only_context_model = only_context_model args.only_query_model = False #args.only_context_model = only_context_model #args.only_query_model = False #model = get_model(biencoder_model_provider) model = get_model(get_model_provider(only_context_model=only_context_model, biencoder_shared_query_context_model=self.biencoder_shared_query_context_model)) #model = get_model(lambda: biencoder_model_provider(only_context_model \ #model = get_model(lambda: biencoder_model_provider(only_context_model \ model = get_model(biencoder_model_provider(only_context_model \ = only_context_model, biencoder_shared_query_context_model = \ self.biencoder_shared_query_context_model, pre_process=True, post_process=True)) # = only_context_model, biencoder_shared_query_context_model = \ # self.biencoder_shared_query_context_model, # pre_process=True, post_process=True) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) Loading megatron/model/biencoder_model.py +19 −0 Original line number Diff line number Diff line Loading @@ -15,6 +15,25 @@ from megatron.model.utils import init_method_normal from megatron.model.utils import scaled_init_method_normal from .module import MegatronModule def get_model_provider(only_query_model=False, only_context_model=False, 
biencoder_shared_query_context_model=False): def model_provider(pre_process=True, post_process=True): """Build the model.""" print_rank_0('building Bienoder model ...') model = biencoder_model_provider(only_query_model=only_query_model, only_context_model = only_context_model, biencoder_shared_query_context_model = \ biencoder_shared_query_context_model, pre_process=True, post_process=True) return model return model_provider #def biencoder_model_provider(pre_process=True, # post_process=True): Loading tasks/orqa/evaluate_orqa.py +1 −1 Original line number Diff line number Diff line Loading @@ -27,7 +27,7 @@ import sys # ) #) from megatron import get_args from megatron import get_args, print_rank_0 from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator Loading tasks/orqa/evaluate_utils.py +9 −5 Original line number Diff line number Diff line Loading @@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader from tasks.orqa.natural_questions.nq import process_nq_batch from tasks.orqa.natural_questions.qa_utils import calculate_matches from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex from megatron.model.biencoder_model import biencoder_model_provider from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model class ORQAEvaluator(object): Loading @@ -47,11 +47,15 @@ class ORQAEvaluator(object): #args.only_query_model = only_query_model #args.only_context_model = False model = get_model(get_model_provider(only_query_model=only_query_model, biencoder_shared_query_context_model=args.biencoder_shared_query_context_model)) #model = get_model(lambda: biencoder_model_provider(only_query_model=\ #model = get_model(lambda: biencoder_model_provider(only_query_model=\ model = get_model(lambda: biencoder_model_provider(only_query_model=\ only_query_model, biencoder_shared_query_context_model=\ args.biencoder_shared_query_context_model, 
pre_process=True, post_process=True)) # only_query_model, biencoder_shared_query_context_model=\ # args.biencoder_shared_query_context_model, # pre_process=True, post_process=True)) #model = get_model(biencoder_model_provider) Loading Loading
megatron/indexer.py +10 −7 Original line number Diff line number Diff line Loading @@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader from megatron.data.realm_index import detach, OpenRetreivalDataStore from megatron.model.biencoder_model import biencoder_model_provider from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model Loading Loading @@ -50,16 +50,19 @@ class IndexBuilder(object): if self.biencoder_shared_query_context_model: only_context_model = False args.only_context_model = only_context_model args.only_query_model = False #args.only_context_model = only_context_model #args.only_query_model = False #model = get_model(biencoder_model_provider) model = get_model(get_model_provider(only_context_model=only_context_model, biencoder_shared_query_context_model=self.biencoder_shared_query_context_model)) #model = get_model(lambda: biencoder_model_provider(only_context_model \ #model = get_model(lambda: biencoder_model_provider(only_context_model \ model = get_model(biencoder_model_provider(only_context_model \ = only_context_model, biencoder_shared_query_context_model = \ self.biencoder_shared_query_context_model, pre_process=True, post_process=True)) # = only_context_model, biencoder_shared_query_context_model = \ # self.biencoder_shared_query_context_model, # pre_process=True, post_process=True) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) Loading
# megatron/model/biencoder_model.py +19 −0  @@ -15,6 +15,25 @@
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


def get_model_provider(only_query_model=False, only_context_model=False,
                       biencoder_shared_query_context_model=False):
    """Return a ``model_provider`` callable bound to the given biencoder options.

    The returned callable is what callers hand to ``get_model(...)``; it
    builds the model through ``biencoder_model_provider`` using the options
    captured here.

    Args:
        only_query_model: forwarded unchanged to ``biencoder_model_provider``.
        only_context_model: forwarded unchanged to
            ``biencoder_model_provider``.
        biencoder_shared_query_context_model: forwarded unchanged to
            ``biencoder_model_provider``.

    Returns:
        A function ``model_provider(pre_process=True, post_process=True)``
        that returns whatever ``biencoder_model_provider`` returns.
    """

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        print_rank_0('building Bienoder model ...')
        # Forward the flags received from the caller. They were previously
        # hard-coded to True here, which silently discarded any value the
        # caller passed in.
        # NOTE(review): presumably get_model() sets these per pipeline
        # stage -- confirm against megatron.training.get_model.
        model = biencoder_model_provider(
            only_query_model=only_query_model,
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=(
                biencoder_shared_query_context_model),
            pre_process=pre_process,
            post_process=post_process)
        return model

    return model_provider


#def biencoder_model_provider(pre_process=True,
#                             post_process=True):
tasks/orqa/evaluate_orqa.py +1 −1 Original line number Diff line number Diff line Loading @@ -27,7 +27,7 @@ import sys # ) #) from megatron import get_args from megatron import get_args, print_rank_0 from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator Loading
tasks/orqa/evaluate_utils.py +9 −5 Original line number Diff line number Diff line Loading @@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader from tasks.orqa.natural_questions.nq import process_nq_batch from tasks.orqa.natural_questions.qa_utils import calculate_matches from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex from megatron.model.biencoder_model import biencoder_model_provider from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model class ORQAEvaluator(object): Loading @@ -47,11 +47,15 @@ class ORQAEvaluator(object): #args.only_query_model = only_query_model #args.only_context_model = False model = get_model(get_model_provider(only_query_model=only_query_model, biencoder_shared_query_context_model=args.biencoder_shared_query_context_model)) #model = get_model(lambda: biencoder_model_provider(only_query_model=\ #model = get_model(lambda: biencoder_model_provider(only_query_model=\ model = get_model(lambda: biencoder_model_provider(only_query_model=\ only_query_model, biencoder_shared_query_context_model=\ args.biencoder_shared_query_context_model, pre_process=True, post_process=True)) # only_query_model, biencoder_shared_query_context_model=\ # args.biencoder_shared_query_context_model, # pre_process=True, post_process=True)) #model = get_model(biencoder_model_provider) Loading