Commit f9267205 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

fixing model evaluation of retriver

parent 6d03d7af
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -53,11 +53,12 @@ class IndexBuilder(object):
        args.only_context_model = only_context_model
        args.only_query_model = False

        model = get_model(biencoder_model_provider)
        #model = get_model(biencoder_model_provider)

        #model = get_model(lambda: biencoder_model_provider(only_context_model \
        #    = only_context_model, biencoder_shared_query_context_model = \
        #    self.biencoder_shared_query_context_model))
        model = get_model(biencoder_model_provider(only_context_model \
            = only_context_model, biencoder_shared_query_context_model = \
            self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)
+10 −10
Original line number Diff line number Diff line
@@ -15,20 +15,20 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule

#def biencoder_model_provider(only_query_model=False,
#                             only_context_model=False,
#                             biencoder_shared_query_context_model=False,
#                             pre_process=True,
#def biencoder_model_provider(pre_process=True, 
#                             post_process=True):
 
def biencoder_model_provider(pre_process=True, 
def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Build the model."""
    args = get_args()
    #args = get_args()

    biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
    only_context_model = args.only_context_model
    only_query_model = args.only_query_model
    #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
    #only_context_model = args.only_context_model
    #only_query_model = args.only_query_model

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
+8 −8
Original line number Diff line number Diff line
@@ -33,15 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group

def pretrain_ict_model_provider():
    args = get_args()
    args.only_context_model = False
    args.only_query_model = False
    model = biencoder_model_provider()
 
    #model = biencoder_model_provider(
    #            only_context_model=False,
    #            only_query_model=False,
    #            biencoder_shared_query_context_model=\
    #                args.biencoder_shared_query_context_model)
    #args.only_context_model = False
    #args.only_query_model = False
    #model = biencoder_model_provider()
 
    model = biencoder_model_provider(
                only_context_model=False,
                only_query_model=False,
                biencoder_shared_query_context_model=\
                    args.biencoder_shared_query_context_model)
    return model

def get_group_world_size_rank():
+1 −1
Original line number Diff line number Diff line
@@ -110,7 +110,7 @@ if __name__ == '__main__':
        from glue.finetune import main
    elif args.task in ['LAMBADA', 'WIKITEXT103']:
        from zeroshot_gpt.evaluate import main
    elif args.task in ['ICT-ZEROSHOT-NQ']:
    elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
        from orqa.evaluate_orqa import main
    elif args.task in ['RET-FINETUNE-NQ']:
        from orqa.supervised.finetune import main
+13 −0
Original line number Diff line number Diff line
@@ -18,6 +18,15 @@
import os
import sys

#sys.path.append(
#    os.path.abspath(
#        os.path.join(
#            os.path.join(os.path.dirname(__file__), os.path.pardir),
#            os.path.pardir,
#        )
#    )
#)

from megatron import get_args
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
@@ -26,6 +35,8 @@ def main():
    """
    Main program
    """
    #initialize_megatron(extra_args_provider=None,
    #                    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

    args = get_args()

@@ -42,6 +53,8 @@ def main():
    Check README.md for example script
    """

    #print_rank_0("Starting index builder!")

    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
    print_rank_0("Build and save indices: done!")
Loading