Commit 25293807 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

additional cleaning

parent 2eaf6c79
Loading
Loading
Loading
Loading
+0 −6
Original line number Diff line number Diff line
@@ -479,12 +479,6 @@ def _add_learning_rate_args(parser):
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minumum value for learning rate. The scheduler'
                       'clip values below this threshold.')
    group.add_argument('--override-lr-new', action='store_true',
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style from input '
                       'arguments and ignore values from checkpoints. Note'
                       'that all the above values will be reset.')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
+0 −1
Original line number Diff line number Diff line
@@ -419,7 +419,6 @@ def load_biencoder_checkpoint(model, only_query_model=False,

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)

    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
+5 −34
Original line number Diff line number Diff line
@@ -26,13 +26,10 @@ class IndexBuilder(object):
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model
        #self.pre_process = True
        #self.post_process = True

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)
        #self.using_realm_chkpt = args.ict_load is None

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size
@@ -46,24 +43,13 @@ class IndexBuilder(object):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        #args = get_args()
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        #args.only_context_model = only_context_model
        #args.only_query_model = False

        #model = get_model(biencoder_model_provider)

        model = get_model(get_model_provider(only_context_model=only_context_model,
            biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))

        #model = get_model(lambda: biencoder_model_provider(only_context_model \
        #model = get_model(lambda: biencoder_model_provider(only_context_model \
        #    = only_context_model, biencoder_shared_query_context_model = \
        #    self.biencoder_shared_query_context_model,
        #    pre_process=True, post_process=True)
        model = get_model(get_model_provider(only_context_model=\
            only_context_model, biencoder_shared_query_context_model=\
            self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)
@@ -103,12 +89,7 @@ class IndexBuilder(object):
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        #counter = 0
        #start_time = time.time()
        #cur_time = start_time
        while True:
            #start_time = time.time()
            #t1 = time.time()
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
@@ -117,8 +98,6 @@ class IndexBuilder(object):
            except (StopIteration, IndexError):
                break

            #print_rank_0("get batch time {}".format(cur_time - time.time()))
            #t2 = time.time()
            # TODO: can we wrap this in torch.no_grad() to reduce memory usage?
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
@@ -128,18 +107,10 @@ class IndexBuilder(object):

            context_logits = detach(context_logits)
            row_id = detach(row_id)
            #print_rank_0("embed text {}".format(cur_time - time.time()))
            #t3 = time.time()

            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))
            #print_rank_0("add block time {}".format(cur_time - time.time()))
            #t4 = time.time()
            #counter += 1
            #if counter % 1000 == 0:
            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
            #    cur_time = time.time()

        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
+0 −16
Original line number Diff line number Diff line
@@ -18,7 +18,6 @@
import math

from megatron import print_rank_0
from megatron import get_args

class AnnealingLR(object):
    """Anneals the learning rate."""
@@ -60,7 +59,6 @@ class AnnealingLR(object):
        """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        #print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))
        # Use linear warmup for the initial part.
        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
            return self.max_lr * float(self.num_steps) / \
@@ -90,20 +88,6 @@ class AnnealingLR(object):
            raise Exception('{} decay style is not supported.'.format(
                self.decay_style))

        args = get_args()

        if args.override_lr_new:
            mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps)
            mod_num_steps_ = mod_num_steps_ - self.warmup_steps
            use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps)
            should_use_lr = self.min_lr + coeff * delta_lr
            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
        else:
            use_lr = self.min_lr + coeff * delta_lr
            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps))

        return use_lr

        return self.min_lr + coeff * delta_lr


+2 −1
Original line number Diff line number Diff line
@@ -114,7 +114,8 @@ def _build_infinite_size_dataloader(dataloader):
            iterator = dataloader.__iter__()


def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None):
def _build_train_valid_dataloaders(train_dataset, valid_dataset, 
    task_collate_fn=None):
    """Traing and validation dataloaders."""
    args = get_args()

Loading