Commit 220637f9 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

DPR evaluation debugging

parent a8d172b3
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -478,6 +478,12 @@ def _add_learning_rate_args(parser):
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minumum value for learning rate. The scheduler'
                       'clip values below this threshold.')
    group.add_argument('--override-lr-new', action='store_true',
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style from input '
                       'arguments and ignore values from checkpoints. Note'
                       'that all the above values will be reset.')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
+5 −2
Original line number Diff line number Diff line
@@ -413,8 +413,11 @@ def load_biencoder_checkpoint(model, only_query_model=False,
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    #print_rank_0(len(model))
    #sys.exit()
    #assert len(model) == 1
    #model[0].load_state_dict(ret_state_dict)
    model.load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
+32 −7
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ import sys
import torch
import torch.distributed as dist

from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
@@ -25,6 +25,8 @@ class IndexBuilder(object):
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model
        self.pre_process = True
        self.post_process = True

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
@@ -47,15 +49,22 @@ class IndexBuilder(object):
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(lambda: biencoder_model_provider(only_context_model \
        #model = get_model(lambda: biencoder_model_provider(only_context_model \
        #    = only_context_model, biencoder_shared_query_context_model = \
        #    self.biencoder_shared_query_context_model, \
        #    pre_process=self.pre_process, post_process=self.post_process))

        model = biencoder_model_provider(only_context_model \
            = only_context_model, biencoder_shared_query_context_model = \
            self.biencoder_shared_query_context_model))
            self.biencoder_shared_query_context_model, \
            pre_process=self.pre_process, post_process=self.post_process)

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)

        assert len(self.model) == 1
        self.model[0].eval()
        #assert len(self.model) == 1
        #self.model[0].eval()
        self.model.eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
@@ -83,10 +92,12 @@ class IndexBuilder(object):
        distributed setting will be consolidated by the rank 0 process
        and saved as a final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]
        #assert len(self.model) == 1
        #unwrapped_model = self.model[0]
        unwrapped_model = self.model
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module
            print_rank_0("hasattr")

        while True:
            try:
@@ -97,12 +108,26 @@ class IndexBuilder(object):
            except (StopIteration, IndexError):
                break

            print_rank_0(context_tokens)
            print_rank_0(context_mask)
            print_rank_0(context_types)
            #if torch.cuda.is_available():
            #    print_rank_0("cuda available")
            #print_rank_0(torch.cuda.current_device())
            #print_rank_0(torch.cuda.get_device_name())
            print_rank_0(next(unwrapped_model.parameters()).device)
            print_rank_0(next(unwrapped_model.context_model.parameters()).device)
            #print_rank_0("After get_open_retrieval_batch")

            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)

            sys.exit()

            context_logits = detach(context_logits)
            row_id = detach(row_id)

+17 −1
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@
import math

from megatron import print_rank_0
from megatron import get_args

class AnnealingLR(object):
    """Anneals the learning rate."""
@@ -59,6 +60,7 @@ class AnnealingLR(object):
        """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        #print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))
        # Use linear warmup for the initial part.
        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
            return self.max_lr * float(self.num_steps) / \
@@ -88,6 +90,20 @@ class AnnealingLR(object):
            raise Exception('{} decay style is not supported.'.format(
                self.decay_style))

        args = get_args()

        if args.override_lr_new:
            mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps)
            mod_num_steps_ = mod_num_steps_ - self.warmup_steps
            use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps)
            should_use_lr = self.min_lr + coeff * delta_lr
            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
        else:
            use_lr = self.min_lr + coeff * delta_lr
            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps))

        return use_lr

        return self.min_lr + coeff * delta_lr


+4 −0
Original line number Diff line number Diff line
@@ -266,6 +266,10 @@ class PretrainedBertModel(MegatronModule):
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        print_rank_0(input_ids.device)
        print_rank_0(position_ids.device)
        print_rank_0(extended_attention_mask.device)
        print_rank_0(tokentype_ids.device)

        lm_output = self.language_model(input_ids,
                                        position_ids,
Loading