Commit ebc95c35 authored by Mostofa Patwary

Merge branch 'main_retriver_merge' into main_retriver_merge_ict_eval

parents 360fb6b7 a4b628ab
+1 −1
@@ -27,7 +27,7 @@ python pretrain_ict.py \
         --log-interval 100 \
         --eval-interval 1000 \
         --eval-iters 10 \
-        --report-topk-accuracies 1 5 10 20 100 \
+        --retriever-report-topk-accuracies 1 5 10 20 100 \
         --retriever-score-scaling \
         --load $CHECKPOINT_PATH \
         --save $CHECKPOINT_PATH \
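
Launch scripts that still pass the old `--report-topk-accuracies` spelling will now fail at parse time, since argparse rejects options that are no longer registered. A minimal sketch of the behavior with a bare parser (illustrative, outside Megatron):

```python
import argparse

# Hypothetical stand-in for the renamed Megatron argument.
parser = argparse.ArgumentParser()
parser.add_argument('--retriever-report-topk-accuracies', nargs='+',
                    type=int, default=[])

# The new spelling parses as expected.
args = parser.parse_args(
    '--retriever-report-topk-accuracies 1 5 10 20 100'.split())
assert args.retriever_report_topk_accuracies == [1, 5, 10, 20, 100]

# The old spelling is no longer registered, so argparse exits with
# "error: unrecognized arguments":
# parser.parse_args('--report-topk-accuracies 1 5'.split())
```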
+3 −7
@@ -649,16 +649,12 @@ def _add_biencoder_args(parser):
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                         'REALM (paper default: 128)')
-    group.add_argument('--projection-dim', type=int, default=0,
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                        help='Size of projection head used in biencoder (paper'
                         ' default: 128)')
-    group.add_argument('--shared-query-context-model', action='store_true',
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                         help='Whether to share the parameters of the query '
                         'and context models or not')
-    group.add_argument('--pool-type', type=str, default='cls-token',
-                       choices=['avg', 'cls-token', 'max'],
-                       help='different options are: avg | cls-token | max, '
-                        'default=cls-token')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,
@@ -677,7 +673,7 @@ def _add_biencoder_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', type=int,
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                         default=[], help="Which top-k accuracies to report "
                         "(e.g. '1 5 20')")
     group.add_argument('--retriever-score-scaling', action='store_true',
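
argparse converts dashes in option strings to underscores when deriving the attribute name, which is how the renamed flags surface as `args.biencoder_projection_dim` and `args.biencoder_shared_query_context_model` in the model code changed below. A throwaway parser makes the correspondence explicit:

```python
import argparse

parser = argparse.ArgumentParser()
# '--biencoder-projection-dim' is stored under the dest
# 'biencoder_projection_dim' (dashes become underscores).
parser.add_argument('--biencoder-projection-dim', type=int, default=0)
parser.add_argument('--biencoder-shared-query-context-model',
                    action='store_true')

args = parser.parse_args(['--biencoder-projection-dim', '128'])
assert args.biencoder_projection_dim == 128
assert args.biencoder_shared_query_context_model is False  # flag omitted
```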
+23 −29
@@ -17,7 +17,7 @@ from .module import MegatronModule

 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
-                             shared_query_context_model=False):
+                             biencoder_shared_query_context_model=False):
     """Build the model."""
     args = get_args()

@@ -31,10 +31,11 @@ def biencoder_model_provider(only_query_model=False,
     # the LM we initialize with has 2 tokentypes
     model = BiEncoderModel(
         num_tokentypes=2,
-        parallel_output=True,
+        parallel_output=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
-        shared_query_context_model=shared_query_context_model)
+        biencoder_shared_query_context_model=\
+            biencoder_shared_query_context_model)

     return model

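The `parallel_output=True` to `False` flip above is easy to miss next to the renames. In Megatron's BERT stack, `parallel_output=True` keeps the model's final output partitioned across tensor-model-parallel ranks, while `False` gathers it so each rank sees complete embedding vectors; the in-batch retrieval loss in `pretrain_ict.py` scores whole query vectors against whole context vectors. A minimal, parallelism-free sketch of that scoring (sizes and names illustrative):

```python
import torch
import torch.nn.functional as F

batch, dim = 4, 128                    # illustrative sizes
query_emb = torch.randn(batch, dim)    # stand-ins for pooled BERT outputs
context_emb = torch.randn(batch, dim)

# In-batch similarity: query i should score highest against context i,
# which requires each row to be a complete (gathered) embedding.
scores = torch.matmul(query_emb, context_emb.t())   # [batch, batch]
labels = torch.arange(batch)
loss = F.cross_entropy(scores, labels)
```
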
@@ -47,7 +48,7 @@ class BiEncoderModel(MegatronModule):
                  parallel_output=True,
                  only_query_model=False,
                  only_context_model=False,
-                 shared_query_context_model=False):
+                 biencoder_shared_query_context_model=False):
         super(BiEncoderModel, self).__init__()
         args = get_args()

@@ -55,13 +56,14 @@ class BiEncoderModel(MegatronModule):
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)

-        self.shared_query_context_model = shared_query_context_model
+        self.biencoder_shared_query_context_model = \
+            biencoder_shared_query_context_model
         assert not (only_context_model and only_query_model)
         self.use_context_model = not only_query_model
         self.use_query_model = not only_context_model
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim

-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model = PretrainedBertModel(**bert_kwargs)
             self._model_key = 'shared_model'
             self.query_model, self.context_model = self.model, self.model
@@ -109,7 +111,7 @@ class BiEncoderModel(MegatronModule):
         prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
                 self.model.state_dict_for_save_checkpoint(destination,
                                                           prefix,
@@ -129,7 +131,7 @@ class BiEncoderModel(MegatronModule):

     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             print_rank_0("Loading shared query-context model")
             self.model.load_state_dict(state_dict[self._model_key], \
                 strict=strict)
@@ -188,14 +190,14 @@ class BiEncoderModel(MegatronModule):
         # load the LM state dict into each model
         model_dict = state_dict['model']['language_model']

-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model.language_model.load_state_dict(model_dict)
             fix_query_key_value_ordering(self.model, checkpoint_version)
         else:
             if self.use_query_model:
                 self.query_model.language_model.load_state_dict(model_dict)
                 # give each model the same ict_head to begin with as well
-                if self.projection_dim > 0:
+                if self.biencoder_projection_dim > 0:
                     query_proj_state_dict = \
                         self.state_dict_for_save_checkpoint()\
                         [self._query_key]['projection_enc']
@@ -203,7 +205,8 @@ class BiEncoderModel(MegatronModule):

             if self.use_context_model:
                 self.context_model.language_model.load_state_dict(model_dict)
-                if self.query_model is not None and self.projection_dim > 0:
+                if self.query_model is not None and \
+                    self.biencoder_projection_dim > 0:
                     self.context_model.projection_enc.load_state_dict\
                         (query_proj_state_dict)
                 fix_query_key_value_ordering(self.context_model, checkpoint_version)
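
The condition wrapped above preserves the init-time weight sharing: both encoders load the same pretrained LM, and the context model's projection head is seeded from the query model's. A stripped-down sketch of that copy, with plain `torch.nn.Linear` stand-ins for the projection heads:

```python
import torch

query_proj = torch.nn.Linear(768, 128)    # illustrative hidden/projection dims
context_proj = torch.nn.Linear(768, 128)

# Seed the context head with the query head's weights so both encoders
# start from an identical projection, mirroring the load_state_dict above.
context_proj.load_state_dict(query_proj.state_dict())
```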
@@ -220,8 +223,7 @@ class PretrainedBertModel(MegatronModule):
         args = get_args()
         tokenizer = get_tokenizer()
         self.pad_id = tokenizer.pad
-        self.pool_type = args.pool_type
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(
@@ -234,9 +236,9 @@ class PretrainedBertModel(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method)

-        if args.projection_dim > 0:
+        if args.biencoder_projection_dim > 0:
             self.projection_enc = get_linear_layer(args.hidden_size,
-                                                   args.projection_dim,
+                                                   args.biencoder_projection_dim,
                                                    init_method)
             self._projection_enc_key = 'projection_enc'

@@ -254,21 +256,13 @@ class PretrainedBertModel(MegatronModule):
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
-        if self.pool_type == "cls-token":
-            pooled_output = lm_output[:, 0, :]
-        elif self.pool_type == "avg":    # Average Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, 0)
-            pooled_output = pooled_output.sum(1) / (pool_mask.size(1) \
-                - pool_mask.float().sum(1))
-        elif self.pool_type == "max":    # Max-Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, -1000)
-            pooled_output = torch.max(pooled_output, 1)[0]
+        pooled_output = lm_output[:, 0, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)

         # Output.
-        if self.projection_dim:
+        if self.biencoder_projection_dim:
             pooled_output = self.projection_enc(pooled_output)

         return pooled_output
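
For reference, the pooling variants deleted here (the path is now hard-wired to the [CLS] token) fit in one standalone function. This sketch reproduces the removed avg/max branches; the average divides by the per-example count of non-pad tokens, equivalent to the deleted `pool_mask.size(1) - pool_mask.float().sum(1)` denominator:

```python
import torch

def pool(lm_output, input_ids, pad_id, pool_type="cls-token"):
    """lm_output: [batch, seq, hidden]; input_ids: [batch, seq]."""
    pad_mask = (input_ids == pad_id).unsqueeze(2)     # [batch, seq, 1]
    if pool_type == "cls-token":
        return lm_output[:, 0, :]                     # [CLS] representation
    if pool_type == "avg":
        summed = lm_output.masked_fill(pad_mask, 0).sum(1)
        n_valid = (~pad_mask).float().sum(1)          # non-pad tokens per row
        return summed / n_valid
    if pool_type == "max":
        return lm_output.masked_fill(pad_mask, -1000).max(1)[0]
    raise ValueError(pool_type)
```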
@@ -283,7 +277,7 @@ class PretrainedBertModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
             destination, prefix, keep_vars)

-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
                 self.projection_enc.state_dict(destination, prefix, keep_vars)

@@ -295,7 +289,7 @@ class PretrainedBertModel(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)

-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             print_rank_0("loading projection head weights")
             self.projection_enc.load_state_dict(
                 state_dict[self._projection_enc_key], strict=strict)
+4 −3
@@ -36,7 +36,8 @@ def pretrain_ict_model_provider():
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
-                shared_query_context_model=args.shared_query_context_model)
+                biencoder_shared_query_context_model=\
+                    args.biencoder_shared_query_context_model)
     return model

def get_group_world_size_rank():
@@ -120,7 +121,7 @@ def forward_step(data_iterator, model, input_tensor):
         return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
             for i in range(global_batch_size)]) / global_batch_size])

-    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
+    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

     labels = torch.arange(global_batch_size).long().cuda()
     loss = F.nll_loss(softmax_scores, labels, reduction='mean')
@@ -131,7 +132,7 @@ def forward_step(data_iterator, model, input_tensor):

     # create stats_dict with retrieval loss and all specified top-k accuracies
     topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
-                        zip(args.report_topk_accuracies, reduced_losses[1:])}
+                        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
     stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
     return loss, stats_dict
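
The renamed `args.retriever_report_topk_accuracies` drives the same in-batch metric as before: the fraction of queries whose paired context lands in the top k by score. A self-contained sketch of that computation (illustrative names, CPU tensors) for checking it outside Megatron:

```python
import torch

def topk_accuracies(scores, ks):
    """scores: [batch, batch] similarities; query i's true context is i."""
    batch = scores.size(0)
    sorted_indices = torch.argsort(scores, dim=1, descending=True)
    return [
        sum(int(i in sorted_indices[i, :k]) for i in range(batch)) / batch
        for k in ks
    ]

scores = torch.randn(8, 8)
print(topk_accuracies(scores, ks=[1, 5]))  # values vary with random scores
```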