Loading megatron/indexer.py +16 −16 Original line number Diff line number Diff line Loading @@ -26,8 +26,8 @@ class IndexBuilder(object): self.evidence_embedder_obj = None self.biencoder_shared_query_context_model = \ args.biencoder_shared_query_context_model self.pre_process = True self.post_process = True #self.pre_process = True #self.post_process = True # need to know whether we're using a REALM checkpoint (args.load) # or ICT checkpoint Loading @@ -46,7 +46,7 @@ class IndexBuilder(object): """ Load the necessary attributes: model, dataloader and empty BlockData """ args = get_args() #args = get_args() only_context_model = True if self.biencoder_shared_query_context_model: only_context_model = False Loading Loading @@ -103,12 +103,12 @@ class IndexBuilder(object): while not hasattr(unwrapped_model, 'embed_text'): unwrapped_model = unwrapped_model.module counter = 0 start_time = time.time() cur_time = start_time #counter = 0 #start_time = time.time() #cur_time = start_time while True: #start_time = time.time() t1 = time.time() #t1 = time.time() try: # batch also has query_tokens and query_pad_data row_id, context_tokens, context_mask, context_types, \ Loading @@ -118,7 +118,7 @@ class IndexBuilder(object): break #print_rank_0("get batch time {}".format(cur_time - time.time())) t2 = time.time() #t2 = time.time() # TODO: can we add with torch.no_grad() to reduce memory usage # detach, separate fields and add to BlockData assert context_mask.dtype == torch.bool Loading @@ -129,17 +129,17 @@ class IndexBuilder(object): context_logits = detach(context_logits) row_id = detach(row_id) #print_rank_0("embed text {}".format(cur_time - time.time())) t3 = time.time() #t3 = time.time() self.evidence_embedder_obj.add_block_data(row_id, context_logits) self.track_and_report_progress(batch_size=len(row_id)) #print_rank_0("add block time {}".format(cur_time - time.time())) t4 = time.time() counter += 1 if counter % 1000 == 0: print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time)) print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3)) cur_time = time.time() #t4 = time.time() #counter += 1 #if counter % 1000 == 0: # print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time)) # print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3)) # cur_time = time.time() # This process signals to finalize its shard and then synchronize with # the other processes self.evidence_embedder_obj.save_shard() Loading megatron/model/biencoder_model.py +8 −17 Original line number Diff line number Diff line Loading @@ -33,21 +33,12 @@ def get_model_provider(only_query_model=False, only_context_model=False, return model_provider #def biencoder_model_provider(pre_process=True, # post_process=True): def biencoder_model_provider(only_query_model=False, only_context_model=False, biencoder_shared_query_context_model=False, pre_process=True, post_process=True): """Build the model.""" #args = get_args() #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model #only_context_model = args.only_context_model #only_query_model = args.only_query_model assert mpu.get_tensor_model_parallel_world_size() == 1 and \ mpu.get_pipeline_model_parallel_world_size() == 1, \ Loading megatron/model/language_model.py +1 −1 Original line number Diff line number Diff line Loading @@ -18,7 +18,7 @@ import torch import torch.nn.functional as F from megatron import get_args, print_rank_0 from megatron import get_args from megatron import mpu from .module import MegatronModule from megatron.model.enums import LayerType, AttnMaskType Loading tasks/finetune_utils.py +2 −6 Original line number Diff line number Diff line Loading @@ -16,7 +16,6 @@ """Finetune utilities.""" from functools import partial import sys import torch Loading Loading @@ -226,9 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step, valid_dataloader, model, iteration, False) #if iteration == 600: # sys.exit() # Checkpointing at the end of each epoch. if args.save: save_checkpoint(iteration, model, optimizer, lr_scheduler) Loading tasks/orqa/evaluate_orqa.py +8 −27 Original line number Diff line number Diff line Loading @@ -15,18 +15,6 @@ """Main tasks functionality.""" import os import sys #sys.path.append( # os.path.abspath( # os.path.join( # os.path.join(os.path.dirname(__file__), os.path.pardir), # os.path.pardir, # ) # ) #) from megatron import get_args, print_rank_0 from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator Loading @@ -35,30 +23,23 @@ def main(): """ Main program """ #initialize_megatron(extra_args_provider=None, # args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) args = get_args() """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset - Include all args needed for initial model specification Other key args: --block-data-path: path to write to --ict-load or --realm-load: path to checkpoint with which to embed --data-path and --titles-data-path: paths for dataset --indexer-log-interval: reporting interval --indexer-batch-size: size specific for indexer jobs Check README.md for example script """ Create a BlockData data structure by running an IndexBuilder over an ICT Dataset and then evaluate on NQ task """ #print_rank_0("Starting index builder!") print_rank_0("Starting index builder!") index_builder = IndexBuilder() index_builder.build_and_save_index() print_rank_0("Build and save indices: done!") print_rank_0("Starting evaluations!") # Set up the model and evaluator evaluator = ORQAEvaluator() Loading Loading
megatron/indexer.py +16 −16 Original line number Diff line number Diff line Loading @@ -26,8 +26,8 @@ class IndexBuilder(object): self.evidence_embedder_obj = None self.biencoder_shared_query_context_model = \ args.biencoder_shared_query_context_model self.pre_process = True self.post_process = True #self.pre_process = True #self.post_process = True # need to know whether we're using a REALM checkpoint (args.load) # or ICT checkpoint Loading @@ -46,7 +46,7 @@ class IndexBuilder(object): """ Load the necessary attributes: model, dataloader and empty BlockData """ args = get_args() #args = get_args() only_context_model = True if self.biencoder_shared_query_context_model: only_context_model = False Loading Loading @@ -103,12 +103,12 @@ class IndexBuilder(object): while not hasattr(unwrapped_model, 'embed_text'): unwrapped_model = unwrapped_model.module counter = 0 start_time = time.time() cur_time = start_time #counter = 0 #start_time = time.time() #cur_time = start_time while True: #start_time = time.time() t1 = time.time() #t1 = time.time() try: # batch also has query_tokens and query_pad_data row_id, context_tokens, context_mask, context_types, \ Loading @@ -118,7 +118,7 @@ class IndexBuilder(object): break #print_rank_0("get batch time {}".format(cur_time - time.time())) t2 = time.time() #t2 = time.time() # TODO: can we add with torch.no_grad() to reduce memory usage # detach, separate fields and add to BlockData assert context_mask.dtype == torch.bool Loading @@ -129,17 +129,17 @@ class IndexBuilder(object): context_logits = detach(context_logits) row_id = detach(row_id) #print_rank_0("embed text {}".format(cur_time - time.time())) t3 = time.time() #t3 = time.time() self.evidence_embedder_obj.add_block_data(row_id, context_logits) self.track_and_report_progress(batch_size=len(row_id)) #print_rank_0("add block time {}".format(cur_time - time.time())) t4 = time.time() counter += 1 if counter % 1000 == 0: print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time)) print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3)) cur_time = time.time() #t4 = time.time() #counter += 1 #if counter % 1000 == 0: # print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time)) # print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3)) # cur_time = time.time() # This process signals to finalize its shard and then synchronize with # the other processes self.evidence_embedder_obj.save_shard() Loading
megatron/model/biencoder_model.py +8 −17 Original line number Diff line number Diff line Loading @@ -33,21 +33,12 @@ def get_model_provider(only_query_model=False, only_context_model=False, return model_provider #def biencoder_model_provider(pre_process=True, # post_process=True): def biencoder_model_provider(only_query_model=False, only_context_model=False, biencoder_shared_query_context_model=False, pre_process=True, post_process=True): """Build the model.""" #args = get_args() #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model #only_context_model = args.only_context_model #only_query_model = args.only_query_model assert mpu.get_tensor_model_parallel_world_size() == 1 and \ mpu.get_pipeline_model_parallel_world_size() == 1, \ Loading
megatron/model/language_model.py +1 −1 Original line number Diff line number Diff line Loading @@ -18,7 +18,7 @@ import torch import torch.nn.functional as F from megatron import get_args, print_rank_0 from megatron import get_args from megatron import mpu from .module import MegatronModule from megatron.model.enums import LayerType, AttnMaskType Loading
tasks/finetune_utils.py +2 −6 Original line number Diff line number Diff line Loading @@ -16,7 +16,6 @@ """Finetune utilities.""" from functools import partial import sys import torch Loading Loading @@ -226,9 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step, valid_dataloader, model, iteration, False) #if iteration == 600: # sys.exit() # Checkpointing at the end of each epoch. if args.save: save_checkpoint(iteration, model, optimizer, lr_scheduler) Loading
tasks/orqa/evaluate_orqa.py +8 −27 Original line number Diff line number Diff line Loading @@ -15,18 +15,6 @@ """Main tasks functionality.""" import os import sys #sys.path.append( # os.path.abspath( # os.path.join( # os.path.join(os.path.dirname(__file__), os.path.pardir), # os.path.pardir, # ) # ) #) from megatron import get_args, print_rank_0 from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator Loading @@ -35,30 +23,23 @@ def main(): """ Main program """ #initialize_megatron(extra_args_provider=None, # args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) args = get_args() """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset - Include all args needed for initial model specification Other key args: --block-data-path: path to write to --ict-load or --realm-load: path to checkpoint with which to embed --data-path and --titles-data-path: paths for dataset --indexer-log-interval: reporting interval --indexer-batch-size: size specific for indexer jobs Check README.md for example script """ Create a BlockData data structure by running an IndexBuilder over an ICT Dataset and then evaluate on NQ task """ #print_rank_0("Starting index builder!") print_rank_0("Starting index builder!") index_builder = IndexBuilder() index_builder.build_and_save_index() print_rank_0("Build and save indices: done!") print_rank_0("Starting evaluations!") # Set up the model and evaluator evaluator = ORQAEvaluator() Loading