Loading pretrain_ict.py +2 −4 Original line number Diff line number Diff line Loading @@ -33,15 +33,13 @@ from megatron.utils import average_losses_across_data_parallel_group def pretrain_ict_model_provider(): args = get_args() #args.only_context_model = False #args.only_query_model = False #model = biencoder_model_provider() model = biencoder_model_provider( only_context_model=False, only_query_model=False, biencoder_shared_query_context_model=\ args.biencoder_shared_query_context_model) return model def get_group_world_size_rank(): Loading tasks/orqa/evaluate_utils.py +5 −4 Original line number Diff line number Diff line Loading @@ -18,13 +18,14 @@ import torch from megatron import get_args, print_rank_0 from megatron.checkpointing import load_biencoder_checkpoint from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset from tasks.orqa.natural_questions.nq import get_nq_dataset from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader from tasks.orqa.natural_questions.nq import process_nq_batch from tasks.orqa.natural_questions.qa_utils import calculate_matches from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model from tasks.orqa.unsupervised.nq import get_nq_dataset from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader from tasks.orqa.unsupervised.nq import process_nq_batch from tasks.orqa.unsupervised.qa_utils import calculate_matches class ORQAEvaluator(object): def __init__(self): Loading tasks/orqa/natural_questions/nq.py→tasks/orqa/unsupervised/nq.py +0 −0 File moved. View file tasks/orqa/natural_questions/qa_utils.py→tasks/orqa/unsupervised/qa_utils.py +1 −1 Original line number Diff line number Diff line Loading @@ -22,7 +22,7 @@ from multiprocessing import Pool as ProcessPool from typing import Tuple, List, Dict import regex as re from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer logger = logging.getLogger(__name__) Loading tasks/orqa/natural_questions/tokenizers.py→tasks/orqa/unsupervised/tokenizers.py +0 −0 File moved. View file Loading
pretrain_ict.py +2 −4 Original line number Diff line number Diff line Loading @@ -33,15 +33,13 @@ from megatron.utils import average_losses_across_data_parallel_group def pretrain_ict_model_provider(): args = get_args() #args.only_context_model = False #args.only_query_model = False #model = biencoder_model_provider() model = biencoder_model_provider( only_context_model=False, only_query_model=False, biencoder_shared_query_context_model=\ args.biencoder_shared_query_context_model) return model def get_group_world_size_rank(): Loading
tasks/orqa/evaluate_utils.py +5 −4 Original line number Diff line number Diff line Loading @@ -18,13 +18,14 @@ import torch from megatron import get_args, print_rank_0 from megatron.checkpointing import load_biencoder_checkpoint from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset from tasks.orqa.natural_questions.nq import get_nq_dataset from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader from tasks.orqa.natural_questions.nq import process_nq_batch from tasks.orqa.natural_questions.qa_utils import calculate_matches from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model from tasks.orqa.unsupervised.nq import get_nq_dataset from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader from tasks.orqa.unsupervised.nq import process_nq_batch from tasks.orqa.unsupervised.qa_utils import calculate_matches class ORQAEvaluator(object): def __init__(self): Loading
tasks/orqa/natural_questions/qa_utils.py→tasks/orqa/unsupervised/qa_utils.py +1 −1 Original line number Diff line number Diff line Loading @@ -22,7 +22,7 @@ from multiprocessing import Pool as ProcessPool from typing import Tuple, List, Dict import regex as re from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer logger = logging.getLogger(__name__) Loading
tasks/orqa/natural_questions/tokenizers.py→tasks/orqa/unsupervised/tokenizers.py +0 −0 File moved. View file