Commit 2eaf6c79 authored by Mostofa Patwary
Browse files

cleaning the code

parent 7a0710ec
Loading
Loading
Loading
Loading
+16 −16
Original line number Diff line number Diff line
@@ -26,8 +26,8 @@ class IndexBuilder(object):
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model
        self.pre_process = True
        self.post_process = True
        #self.pre_process = True
        #self.post_process = True

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
@@ -46,7 +46,7 @@ class IndexBuilder(object):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        args = get_args()
        #args = get_args()
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False
@@ -103,12 +103,12 @@ class IndexBuilder(object):
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        counter = 0
        start_time = time.time()
        cur_time = start_time
        #counter = 0
        #start_time = time.time()
        #cur_time = start_time
        while True:
            #start_time = time.time()
            t1 = time.time()
            #t1 = time.time()
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
@@ -118,7 +118,7 @@ class IndexBuilder(object):
                break

            #print_rank_0("get batch time {}".format(cur_time - time.time()))
            t2 = time.time()
            #t2 = time.time()
            # TODO: can we wrap this in `with torch.no_grad():` to reduce memory usage?
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
@@ -129,17 +129,17 @@ class IndexBuilder(object):
            context_logits = detach(context_logits)
            row_id = detach(row_id)
            #print_rank_0("embed text {}".format(cur_time - time.time()))
            t3 = time.time()
            #t3 = time.time()
 
            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))
            #print_rank_0("add block time {}".format(cur_time - time.time()))
            t4 = time.time()
            counter += 1
            if counter % 1000 == 0:
                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
                cur_time = time.time()
            #t4 = time.time()
            #counter += 1
            #if counter % 1000 == 0:
            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
            #    cur_time = time.time()
        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
+8 −17
Original line number Diff line number Diff line
@@ -33,21 +33,12 @@ def get_model_provider(only_query_model=False, only_context_model=False,
    return model_provider



#def biencoder_model_provider(pre_process=True, 
#                             post_process=True):
 
def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Build the model."""
    #args = get_args()

    #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
    #only_context_model = args.only_context_model
    #only_query_model = args.only_query_model

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
+1 −1
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@
import torch
import torch.nn.functional as F

from megatron import get_args, print_rank_0
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
+2 −6
Original line number Diff line number Diff line
@@ -16,7 +16,6 @@
"""Finetune utilities."""

from functools import partial
import sys

import torch

@@ -226,9 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

            #if iteration == 600:
            #    sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+8 −27
Original line number Diff line number Diff line
@@ -15,18 +15,6 @@

"""Main tasks functionality."""

import os
import sys

#sys.path.append(
#    os.path.abspath(
#        os.path.join(
#            os.path.join(os.path.dirname(__file__), os.path.pardir),
#            os.path.pardir,
#        )
#    )
#)

from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
@@ -35,30 +23,23 @@ def main():
    """
    Main program
    """
    #initialize_megatron(extra_args_provider=None,
    #                    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

    args = get_args()

    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
    - Include all args needed for initial model specification

    Other key args:
        --block-data-path: path to write to
        --ict-load or --realm-load: path to checkpoint with which to embed
        --data-path and --titles-data-path: paths for dataset
        --indexer-log-interval: reporting interval
        --indexer-batch-size: size specific for indexer jobs

    Check README.md for example script
    """
    Create a BlockData data structure by running an IndexBuilder over an
    ICT Dataset, and then evaluate on the NQ task.
    """

    #print_rank_0("Starting index builder!")
    print_rank_0("Starting index builder!")

    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
    print_rank_0("Build and save indices: done!")


    print_rank_0("Starting evaluations!")

    # Set up the model and evaluator
    evaluator = ORQAEvaluator()

Loading