Commit dca47cfb authored by Mostofa Patwary
Browse files

debugging DPR

parent f64977fd
Loading
Loading
Loading
Loading
+18 −2
Original line number Diff line number Diff line
import sys
import time
import torch
import torch.distributed as dist

@@ -102,7 +103,12 @@ class IndexBuilder(object):
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        counter = 0
        start_time = time.time()
        cur_time = start_time
        while True:
            #start_time = time.time()
            t1 = time.time()
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
@@ -111,6 +117,8 @@ class IndexBuilder(object):
            except (StopIteration, IndexError):
                break

            #print_rank_0("get batch time {}".format(cur_time - time.time()))
            t2 = time.time()
            # TODO: can we wrap the embedding step in `with torch.no_grad():` to reduce memory usage?
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
@@ -120,10 +128,18 @@ class IndexBuilder(object):

            context_logits = detach(context_logits)
            row_id = detach(row_id)
            #print_rank_0("embed text {}".format(cur_time - time.time()))
            t3 = time.time()
 
            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

            #print_rank_0("add block time {}".format(cur_time - time.time()))
            t4 = time.time()
            counter += 1
            if counter % 1000 == 0:
                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
                cur_time = time.time()
        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()