Loading megatron/indexer.py +4 −3 Original line number Diff line number Diff line Loading @@ -53,11 +53,12 @@ class IndexBuilder(object): args.only_context_model = only_context_model args.only_query_model = False model = get_model(biencoder_model_provider) #model = get_model(biencoder_model_provider) #model = get_model(lambda: biencoder_model_provider(only_context_model \ # = only_context_model, biencoder_shared_query_context_model = \ # self.biencoder_shared_query_context_model)) model = get_model(biencoder_model_provider(only_context_model \ = only_context_model, biencoder_shared_query_context_model = \ self.biencoder_shared_query_context_model)) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) Loading megatron/model/biencoder_model.py +10 −10 Original line number Diff line number Diff line Loading @@ -15,20 +15,20 @@ from megatron.model.utils import init_method_normal from megatron.model.utils import scaled_init_method_normal from .module import MegatronModule #def biencoder_model_provider(only_query_model=False, # only_context_model=False, # biencoder_shared_query_context_model=False, # pre_process=True, #def biencoder_model_provider(pre_process=True, # post_process=True): def biencoder_model_provider(pre_process=True, def biencoder_model_provider(only_query_model=False, only_context_model=False, biencoder_shared_query_context_model=False, pre_process=True, post_process=True): """Build the model.""" args = get_args() #args = get_args() biencoder_shared_query_context_model = args.biencoder_shared_query_context_model only_context_model = args.only_context_model only_query_model = args.only_query_model #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model #only_context_model = args.only_context_model #only_query_model = args.only_query_model assert mpu.get_tensor_model_parallel_world_size() == 1 and \ mpu.get_pipeline_model_parallel_world_size() == 1, \ Loading pretrain_ict.py +8 −8 Original line number Diff line number Diff line Loading @@ -33,15 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group def pretrain_ict_model_provider(): args = get_args() args.only_context_model = False args.only_query_model = False model = biencoder_model_provider() #model = biencoder_model_provider( # only_context_model=False, # only_query_model=False, # biencoder_shared_query_context_model=\ # args.biencoder_shared_query_context_model) #args.only_context_model = False #args.only_query_model = False #model = biencoder_model_provider() model = biencoder_model_provider( only_context_model=False, only_query_model=False, biencoder_shared_query_context_model=\ args.biencoder_shared_query_context_model) return model def get_group_world_size_rank(): Loading tasks/main.py +1 −1 Original line number Diff line number Diff line Loading @@ -110,7 +110,7 @@ if __name__ == '__main__': from glue.finetune import main elif args.task in ['LAMBADA', 'WIKITEXT103']: from zeroshot_gpt.evaluate import main elif args.task in ['ICT-ZEROSHOT-NQ']: elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']: from orqa.evaluate_orqa import main elif args.task in ['RET-FINETUNE-NQ']: from orqa.supervised.finetune import main Loading tasks/orqa/evaluate_orqa.py +13 −0 Original line number Diff line number Diff line Loading @@ -18,6 +18,15 @@ import os import sys #sys.path.append( # os.path.abspath( # os.path.join( # os.path.join(os.path.dirname(__file__), os.path.pardir), # os.path.pardir, # ) # ) #) from megatron import get_args from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator Loading @@ -26,6 +35,8 @@ def main(): """ Main program """ #initialize_megatron(extra_args_provider=None, # args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) args = get_args() Loading @@ -42,6 +53,8 @@ def main(): Check README.md for example script """ #print_rank_0("Starting index builder!") index_builder = IndexBuilder() index_builder.build_and_save_index() print_rank_0("Build and save indices: done!") Loading Loading
megatron/indexer.py +4 −3 Original line number Diff line number Diff line Loading @@ -53,11 +53,12 @@ class IndexBuilder(object): args.only_context_model = only_context_model args.only_query_model = False model = get_model(biencoder_model_provider) #model = get_model(biencoder_model_provider) #model = get_model(lambda: biencoder_model_provider(only_context_model \ # = only_context_model, biencoder_shared_query_context_model = \ # self.biencoder_shared_query_context_model)) model = get_model(biencoder_model_provider(only_context_model \ = only_context_model, biencoder_shared_query_context_model = \ self.biencoder_shared_query_context_model)) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) Loading
megatron/model/biencoder_model.py +10 −10 Original line number Diff line number Diff line Loading @@ -15,20 +15,20 @@ from megatron.model.utils import init_method_normal from megatron.model.utils import scaled_init_method_normal from .module import MegatronModule #def biencoder_model_provider(only_query_model=False, # only_context_model=False, # biencoder_shared_query_context_model=False, # pre_process=True, #def biencoder_model_provider(pre_process=True, # post_process=True): def biencoder_model_provider(pre_process=True, def biencoder_model_provider(only_query_model=False, only_context_model=False, biencoder_shared_query_context_model=False, pre_process=True, post_process=True): """Build the model.""" args = get_args() #args = get_args() biencoder_shared_query_context_model = args.biencoder_shared_query_context_model only_context_model = args.only_context_model only_query_model = args.only_query_model #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model #only_context_model = args.only_context_model #only_query_model = args.only_query_model assert mpu.get_tensor_model_parallel_world_size() == 1 and \ mpu.get_pipeline_model_parallel_world_size() == 1, \ Loading
pretrain_ict.py +8 −8 Original line number Diff line number Diff line Loading @@ -33,15 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group def pretrain_ict_model_provider(): args = get_args() args.only_context_model = False args.only_query_model = False model = biencoder_model_provider() #model = biencoder_model_provider( # only_context_model=False, # only_query_model=False, # biencoder_shared_query_context_model=\ # args.biencoder_shared_query_context_model) #args.only_context_model = False #args.only_query_model = False #model = biencoder_model_provider() model = biencoder_model_provider( only_context_model=False, only_query_model=False, biencoder_shared_query_context_model=\ args.biencoder_shared_query_context_model) return model def get_group_world_size_rank(): Loading
tasks/main.py +1 −1 Original line number Diff line number Diff line Loading @@ -110,7 +110,7 @@ if __name__ == '__main__': from glue.finetune import main elif args.task in ['LAMBADA', 'WIKITEXT103']: from zeroshot_gpt.evaluate import main elif args.task in ['ICT-ZEROSHOT-NQ']: elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']: from orqa.evaluate_orqa import main elif args.task in ['RET-FINETUNE-NQ']: from orqa.supervised.finetune import main Loading
tasks/orqa/evaluate_orqa.py +13 −0 Original line number Diff line number Diff line Loading @@ -18,6 +18,15 @@ import os import sys #sys.path.append( # os.path.abspath( # os.path.join( # os.path.join(os.path.dirname(__file__), os.path.pardir), # os.path.pardir, # ) # ) #) from megatron import get_args from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator Loading @@ -26,6 +35,8 @@ def main(): """ Main program """ #initialize_megatron(extra_args_provider=None, # args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) args = get_args() Loading @@ -42,6 +53,8 @@ def main(): Check README.md for example script """ #print_rank_0("Starting index builder!") index_builder = IndexBuilder() index_builder.build_and_save_index() print_rank_0("Build and save indices: done!") Loading