examples/pretrain_ict.sh  (+1 −1)

@@ -27,7 +27,7 @@ python pretrain_ict.py \
     --log-interval 100 \
     --eval-interval 1000 \
     --eval-iters 10 \
-    --report-topk-accuracies 1 5 10 20 100 \
+    --retriever-report-topk-accuracies 1 5 10 20 100 \
     --retriever-score-scaling \
     --load $CHECKPOINT_PATH \
     --save $CHECKPOINT_PATH \
megatron/arguments.py  (+3 −7)

@@ -649,16 +649,12 @@ def _add_biencoder_args(parser):
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
-    group.add_argument('--projection-dim', type=int, default=0,
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                        help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
-    group.add_argument('--shared-query-context-model', action='store_true',
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                        help='Whether to share the parameters of the query '
                        'and context models or not')
-    group.add_argument('--pool-type', type=str, default='cls-token',
-                       choices=['avg', 'cls-token', 'max'],
-                       help='different options are: avg | cls-token | max, '
-                       'default=cls-token')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,

@@ -677,7 +673,7 @@ def _add_biencoder_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', type=int,
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
     group.add_argument('--retriever-score-scaling', action='store_true',
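Note: the rename only changes the flag names (as also reflected in the launch script above), not their semantics. A minimal standalone argparse sketch, not Megatron's actual parser, showing how the renamed flags surface on the parsed namespace:

# Minimal standalone sketch mirroring the add_argument calls in the hunks
# above; this is illustrative, not Megatron's _add_biencoder_args.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--biencoder-projection-dim', type=int, default=0)
parser.add_argument('--biencoder-shared-query-context-model',
                    action='store_true')
parser.add_argument('--retriever-report-topk-accuracies', nargs='+',
                    type=int, default=[])

args = parser.parse_args([
    '--biencoder-projection-dim', '128',
    '--retriever-report-topk-accuracies', '1', '5', '20',
])

# Dashes become underscores on the namespace, so downstream code reads
# args.biencoder_projection_dim, args.retriever_report_topk_accuracies, etc.
assert args.biencoder_projection_dim == 128
assert args.retriever_report_topk_accuracies == [1, 5, 20]
assert args.biencoder_shared_query_context_model is False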
megatron/model/biencoder_model.py  (+23 −29)

@@ -17,7 +17,7 @@ from .module import MegatronModule

 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
-                             shared_query_context_model=False):
+                             biencoder_shared_query_context_model=False):
     """Build the model."""

     args = get_args()

@@ -31,10 +31,11 @@ def biencoder_model_provider(only_query_model=False,
     # the LM we initialize with has 2 tokentypes
     model = BiEncoderModel(
         num_tokentypes=2,
-        parallel_output=True,
+        parallel_output=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
-        shared_query_context_model=shared_query_context_model)
+        biencoder_shared_query_context_model=\
+            biencoder_shared_query_context_model)

     return model

@@ -47,7 +48,7 @@ class BiEncoderModel(MegatronModule):
                  parallel_output=True,
                  only_query_model=False,
                  only_context_model=False,
-                 shared_query_context_model=False):
+                 biencoder_shared_query_context_model=False):
         super(BiEncoderModel, self).__init__()

         args = get_args()

@@ -55,13 +56,14 @@ class BiEncoderModel(MegatronModule):
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)

-        self.shared_query_context_model = shared_query_context_model
+        self.biencoder_shared_query_context_model = \
+            biencoder_shared_query_context_model
         assert not (only_context_model and only_query_model)
         self.use_context_model = not only_query_model
         self.use_query_model = not only_context_model
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim

-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model = PretrainedBertModel(**bert_kwargs)
             self._model_key = 'shared_model'
             self.query_model, self.context_model = self.model, self.model

@@ -109,7 +111,7 @@ class BiEncoderModel(MegatronModule):
                                        prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
                 self.model.state_dict_for_save_checkpoint(destination,
                                                           prefix,

@@ -129,7 +131,7 @@ class BiEncoderModel(MegatronModule):
     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             print_rank_0("Loading shared query-context model")
             self.model.load_state_dict(state_dict[self._model_key], \
                 strict=strict)

@@ -188,14 +190,14 @@ class BiEncoderModel(MegatronModule):
         # load the LM state dict into each model
         model_dict = state_dict['model']['language_model']

-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model.language_model.load_state_dict(model_dict)
             fix_query_key_value_ordering(self.model, checkpoint_version)
         else:
             if self.use_query_model:
                 self.query_model.language_model.load_state_dict(model_dict)
                 # give each model the same ict_head to begin with as well
-                if self.projection_dim > 0:
+                if self.biencoder_projection_dim > 0:
                     query_proj_state_dict = \
                         self.state_dict_for_save_checkpoint()\
                         [self._query_key]['projection_enc']

@@ -203,7 +205,8 @@ class BiEncoderModel(MegatronModule):
             if self.use_context_model:
                 self.context_model.language_model.load_state_dict(model_dict)
-                if self.query_model is not None and self.projection_dim > 0:
+                if self.query_model is not None and \
+                    self.biencoder_projection_dim > 0:
                     self.context_model.projection_enc.load_state_dict\
                         (query_proj_state_dict)
                 fix_query_key_value_ordering(self.context_model, checkpoint_version)

@@ -220,8 +223,7 @@ class PretrainedBertModel(MegatronModule):
         args = get_args()
         tokenizer = get_tokenizer()
         self.pad_id = tokenizer.pad
-        self.pool_type = args.pool_type
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(

@@ -234,9 +236,9 @@ class PretrainedBertModel(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method)

-        if args.projection_dim > 0:
+        if args.biencoder_projection_dim > 0:
             self.projection_enc = get_linear_layer(args.hidden_size,
-                                                   args.projection_dim,
+                                                   args.biencoder_projection_dim,
                                                    init_method)
             self._projection_enc_key = 'projection_enc'

@@ -254,21 +256,13 @@ class PretrainedBertModel(MegatronModule):
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
-        if self.pool_type == "cls-token":
-            pooled_output = lm_output[:, 0, :]
-        elif self.pool_type == "avg":
-            # Average Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, 0)
-            pooled_output = pooled_output.sum(1) / (pool_mask.size(1) \
-                - pool_mask.float().sum(1))
-        elif self.pool_type == "max":
-            # Max-Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, -1000)
-            pooled_output = torch.max(pooled_output, 1)[0]
+        pooled_output = lm_output[:, 0, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)

         # Output.
-        if self.projection_dim:
+        if self.biencoder_projection_dim:
             pooled_output = self.projection_enc(pooled_output)

         return pooled_output

@@ -283,7 +277,7 @@ class PretrainedBertModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)

-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
                 self.projection_enc.state_dict(destination, prefix, keep_vars)

@@ -295,7 +289,7 @@ class PretrainedBertModel(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)

-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             print_rank_0("loading projection head weights")
             self.projection_enc.load_state_dict(
                 state_dict[self._projection_enc_key], strict=strict)
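Note: with --pool-type removed, PretrainedBertModel.forward always pools the [CLS] token and then applies the optional projection head. A simplified sketch of that retained path, assuming lm_output is laid out [batch, seq, hidden]; names and shapes are illustrative, not the exact Megatron module:

# Sketch of the pooling path kept by this change; illustrative only.
import torch

def pool_and_project(lm_output, projection_enc=None):
    # Always take the representation of the [CLS] token (position 0).
    pooled_output = lm_output[:, 0, :]
    # Keep the LM's dtype (e.g. float16 under mixed precision).
    pooled_output = pooled_output.to(lm_output.dtype)
    # Optional projection head, enabled when --biencoder-projection-dim > 0.
    if projection_enc is not None:
        pooled_output = projection_enc(pooled_output)
    return pooled_output

lm_output = torch.randn(4, 128, 768)            # [batch, seq, hidden]
proj = torch.nn.Linear(768, 128)                # hidden -> projection dim
print(pool_and_project(lm_output, proj).shape)  # torch.Size([4, 128])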
pretrain_ict.py  (+4 −3)

@@ -36,7 +36,8 @@ def pretrain_ict_model_provider():
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
-                shared_query_context_model=args.shared_query_context_model)
+                biencoder_shared_query_context_model=\
+                    args.biencoder_shared_query_context_model)

     return model

 def get_group_world_size_rank():

@@ -120,7 +121,7 @@ def forward_step(data_iterator, model, input_tensor):
         return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
             for i in range(global_batch_size)]) / global_batch_size])

-    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
+    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

     labels = torch.arange(global_batch_size).long().cuda()
     loss = F.nll_loss(softmax_scores, labels, reduction='mean')

@@ -131,7 +132,7 @@ def forward_step(data_iterator, model, input_tensor):
     # create stats_dict with retrieval loss and all specified top-k accuracies
     topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
-        zip(args.report_topk_accuracies, reduced_losses[1:])}
+        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
     stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)

     return loss, stats_dict
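Note: forward_step scores each query against every in-batch context, so the correct label for row i is column i, and the reported top-k accuracies count how often the true context lands among the k highest scores. A self-contained CPU sketch of that metric, without the model-parallel gather/all-reduce Megatron performs:

# Hedged sketch of the in-batch retrieval loss and top-k accuracy; plain
# CPU code, not the distributed forward_step in pretrain_ict.py.
import torch
import torch.nn.functional as F

def retrieval_loss_and_topk(retrieval_scores, topk=(1, 5, 20)):
    batch_size = retrieval_scores.size(0)
    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    labels = torch.arange(batch_size)
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')

    # Rank contexts per query; count hits of the true (diagonal) context.
    _, sorted_indices = torch.sort(softmax_scores, descending=True, dim=1)
    accs = {}
    for k in topk:
        hits = sum(int(i in sorted_indices[i, :k]) for i in range(batch_size))
        accs['top{}_acc'.format(k)] = 100.0 * hits / batch_size
    return loss, accs

scores = torch.randn(32, 32)   # query-vs-context similarity matrix
print(retrieval_loss_and_topk(scores, topk=(1, 5)))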