megatron/indexer.py  +18 −2

import sys
import time
import torch
import torch.distributed as dist

@@ -102,7 +103,12 @@ class IndexBuilder(object):
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        counter = 0
        start_time = time.time()
        cur_time = start_time
        while True:
            #start_time = time.time()
            t1 = time.time()
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \

@@ -111,6 +117,8 @@
            except (StopIteration, IndexError):
                break

            #print_rank_0("get batch time {}".format(cur_time - time.time()))
            t2 = time.time()
            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool

@@ -120,10 +128,18 @@
            context_logits = detach(context_logits)
            row_id = detach(row_id)

            #print_rank_0("embed text {}".format(cur_time - time.time()))
            t3 = time.time()
            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

            #print_rank_0("add block time {}".format(cur_time - time.time()))
            t4 = time.time()
            counter += 1
            if counter % 1000 == 0:
                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 - t3))
                cur_time = time.time()

        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
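Reviewer note: the added t1–t4 timestamps bracket the three stages of each iteration (batch fetch, embed_text forward, BlockData write), and every 1000 iterations the loop prints the cumulative wall time, the time spent on the last 1000 iterations, and a per-stage breakdown; print_rank_0 keeps the output to a single rank. A minimal, self-contained sketch of the same pattern, with hypothetical embed/add_block callables standing in for the model forward and the BlockData call:

    import time

    def timed_index_loop(batches, embed, add_block, report_every=1000):
        # Same instrumentation pattern as the diff: bracket each stage
        # with time.time() and report every `report_every` iterations.
        start = cur = time.time()
        for counter, batch in enumerate(batches, start=1):
            t1 = time.time()
            emb = embed(batch)        # model forward pass
            t2 = time.time()
            add_block(emb)            # write embeddings to the shard
            t3 = time.time()
            if counter % report_every == 0:
                now = time.time()
                print("total {:.1f}s  last {} iters {:.1f}s  "
                      "model {:.3f}s  block {:.3f}s".format(
                          now - start, report_every, now - cur,
                          t2 - t1, t3 - t2))
                cur = now

Note that the printed breakdown reflects only the final iteration of each window, not an average over it, which is still enough to spot which stage dominates.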
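On the in-code TODO about torch.no_grad(): wrapping the forward pass in it should indeed cut memory, since autograd then skips saving activations for a backward pass that indexing never runs. A minimal sketch, assuming embed_text takes the three context tensors from the batch (the exact signature is an assumption here, not taken from this diff):

    import torch

    def embed_batch(unwrapped_model, context_tokens, context_mask, context_types):
        # Inference-only embedding: under no_grad, autograd records
        # nothing, so no activation buffers are retained.
        # NOTE: the embed_text argument list below is assumed.
        with torch.no_grad():
            return unwrapped_model.embed_text(
                context_tokens, context_mask, context_types)

On newer PyTorch versions, torch.inference_mode() is a stricter alternative with the same usage.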