Commit 8e44d619 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

renaming the folders

parent 25293807
Loading
Loading
Loading
Loading
+2 −4
Original line number Diff line number Diff line
@@ -33,15 +33,13 @@ from megatron.utils import average_losses_across_data_parallel_group

def pretrain_ict_model_provider():
    args = get_args()
    #args.only_context_model = False
    #args.only_query_model = False
    #model = biencoder_model_provider()

    model = biencoder_model_provider(
                only_context_model=False,
                only_query_model=False,
                biencoder_shared_query_context_model=\
                args.biencoder_shared_query_context_model)

    return model

def get_group_world_size_rank():
+5 −4
Original line number Diff line number Diff line
@@ -18,13 +18,14 @@ import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from tasks.orqa.natural_questions.nq import get_nq_dataset
from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
from tasks.orqa.unsupervised.nq import get_nq_dataset
from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
from tasks.orqa.unsupervised.nq import process_nq_batch
from tasks.orqa.unsupervised.qa_utils import calculate_matches


class ORQAEvaluator(object):
    def __init__(self):
+1 −1
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict

import regex as re
from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer

logger = logging.getLogger(__name__)