megatron/checkpointing.py +3 −5

@@ -413,11 +413,9 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     if only_context_model:
         ret_state_dict.pop('query_model')
 
-    #print_rank_0(len(model))
-    #sys.exit()
-    #assert len(model) == 1
-    #model[0].load_state_dict(ret_state_dict)
-    model.load_state_dict(ret_state_dict)
+    assert len(model) == 1
+    model[0].load_state_dict(ret_state_dict)
 
     torch.distributed.barrier()
     if mpu.get_data_parallel_rank() == 0:
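Megatron's get_model returns the model wrapped in a list (one entry per virtual-pipeline chunk), which is why this hunk restores the assert len(model) == 1 / model[0] pattern instead of treating model as a bare module. A minimal sketch of the load path this implies — the module paths and the flag-setting step follow the rest of this diff and are assumptions, not verified against the repo:

from megatron import get_args
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model

def build_and_load_context_encoder():
    # Per this diff, the provider reads its flags from global args,
    # so they must be set before get_model() invokes it.
    args = get_args()
    args.only_context_model = True
    args.only_query_model = False

    # get_model returns a list of model chunks, even when there is
    # exactly one chunk -- hence the assert and the [0] indexing.
    model = get_model(biencoder_model_provider)
    model = load_biencoder_checkpoint(model, only_context_model=True)
    assert len(model) == 1
    model[0].eval()   # operate on the chunk, not the list
    return model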
megatron/indexer.py +12 −27

@@ -45,26 +45,25 @@ class IndexBuilder(object):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
+        args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
 
-        model = get_model(lambda: biencoder_model_provider(only_context_model \
-            = only_context_model, biencoder_shared_query_context_model = \
-            self.biencoder_shared_query_context_model, \
-            pre_process=self.pre_process, post_process=self.post_process))
+        args.only_context_model = only_context_model
+        args.only_query_model = False
+        model = get_model(biencoder_model_provider)
-        #model = biencoder_model_provider(only_context_model \
+        #model = get_model(lambda: biencoder_model_provider(only_context_model \
         #    = only_context_model, biencoder_shared_query_context_model = \
         #    self.biencoder_shared_query_context_model, \
-        #    pre_process=self.pre_process, post_process=self.post_process)
+        #    self.biencoder_shared_query_context_model))
 
         self.model = load_biencoder_checkpoint(model,
             only_context_model=only_context_model)
 
-        #assert len(self.model) == 1
-        #self.model[0].eval()
-        self.model.eval()
+        assert len(self.model) == 1
+        self.model[0].eval()
 
         self.dataset = get_open_retrieval_wiki_dataset()
         self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \

@@ -92,12 +91,11 @@ class IndexBuilder(object):
         distributed setting will be consolidated by the rank 0 process
         and saved as a final pickled BlockData.
         """
-        #assert len(self.model) == 1
-        #unwrapped_model = self.model[0]
-        unwrapped_model = self.model
+        assert len(self.model) == 1
+        unwrapped_model = self.model[0]
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
-        print_rank_0("hasattr")
 
         while True:
             try:

@@ -108,17 +106,6 @@ class IndexBuilder(object):
             except (StopIteration, IndexError):
                 break
 
-            print_rank_0(context_tokens)
-            print_rank_0(context_mask)
-            print_rank_0(context_types)
-            #if torch.cuda.is_available():
-            #    print_rank_0("cuda available")
-            #print_rank_0(torch.cuda.current_device())
-            #print_rank_0(torch.cuda.get_device_name())
-            print_rank_0(next(unwrapped_model.parameters()).device)
-            print_rank_0(next(unwrapped_model.context_model.parameters()).device)
-            #print_rank_0("After get_open_retrieval_batch")
-
             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool

@@ -126,8 +113,6 @@ class IndexBuilder(object):
                 unwrapped_model.context_model, context_tokens, context_mask,
                 context_types)
-            sys.exit()
-
             context_logits = detach(context_logits)
             row_id = detach(row_id)
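Dropping the lambda works because get_args() returns the process-global argparse namespace: attributes written in load_attributes are still visible when get_model later calls the provider with no extra arguments. A self-contained sketch of that flow, with get_args/get_model reduced to stand-in stubs for illustration:

from types import SimpleNamespace

_GLOBAL_ARGS = SimpleNamespace()        # stand-in for Megatron's global args

def get_args():
    return _GLOBAL_ARGS                 # every caller sees the same object

def get_model(provider):                # stub: the real get_model also wraps
    return [provider()]                 # the module in a list (and in DDP)

def biencoder_model_provider():
    args = get_args()
    # The flags were set by the caller before get_model() ran.
    return 'context-only' if args.only_context_model else 'biencoder'

args = get_args()
args.only_context_model = True
args.only_query_model = False
print(get_model(biencoder_model_provider))   # -> ['context-only']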
megatron/model/biencoder_model.py +11 −9

@@ -15,14 +15,21 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
 
-def biencoder_model_provider(only_query_model=False,
-                             only_context_model=False,
-                             biencoder_shared_query_context_model=False,
-                             pre_process=True,
+#def biencoder_model_provider(only_query_model=False,
+#                             only_context_model=False,
+#                             biencoder_shared_query_context_model=False,
+#                             pre_process=True,
+#                             post_process=True):
+def biencoder_model_provider(pre_process=True,
                              post_process=True):
     """Build the model."""
     args = get_args()
+    biencoder_shared_query_context_model = \
+        args.biencoder_shared_query_context_model
+    only_context_model = args.only_context_model
+    only_query_model = args.only_query_model
 
     assert mpu.get_tensor_model_parallel_world_size() == 1 and \
         mpu.get_pipeline_model_parallel_world_size() == 1, \
         "Model parallel size > 1 not supported for ICT"

@@ -266,11 +273,6 @@ class PretrainedBertModel(MegatronModule):
         #extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)
 
-        print_rank_0(input_ids.device)
-        print_rank_0(position_ids.device)
-        print_rank_0(extended_attention_mask.device)
-        print_rank_0(tokentype_ids.device)
-
         lm_output = self.language_model(input_ids,
                                         position_ids,
                                         extended_attention_mask,
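The new signature gives biencoder_model_provider an implicit contract: every caller must have set args.only_context_model and args.only_query_model before the call (the shared-model flag already lives on args as a regular command-line argument). A defensive variant — not part of this diff, just a sketch of how that contract could fail fast:

from megatron import get_args

def biencoder_model_provider(pre_process=True, post_process=True):
    """Build the biencoder, reading its flags from global args."""
    args = get_args()
    # Fail fast if a caller forgot to populate the attributes this
    # provider now expects on the global args namespace.
    for flag in ('only_context_model', 'only_query_model',
                 'biencoder_shared_query_context_model'):
        assert hasattr(args, flag), \
            f'set args.{flag} before calling biencoder_model_provider()'
    only_context_model = args.only_context_model
    only_query_model = args.only_query_model
    # ... proceed to build the model exactly as the diff does.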
megatron/model/language_model.py +0 −5

@@ -338,11 +338,6 @@ class TransformerLanguageModel(MegatronModule):
                 get_key_value=False, pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
-        print_rank_0("before self.embedding")
-        print_rank_0(enc_input_ids.device)
-        print_rank_0(enc_position_ids.device)
-        print_rank_0(tokentype_ids.device)
-
         # Embeddings.
         if self.pre_process:
             embedding_output = self.embedding(enc_input_ids, enc_position_ids,
pretrain_ict.py +9 −5

@@ -33,11 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group
 
 def pretrain_ict_model_provider():
     args = get_args()
-    model = biencoder_model_provider(
-                only_context_model=False,
-                only_query_model=False,
-                biencoder_shared_query_context_model=\
-                args.biencoder_shared_query_context_model)
+    args.only_context_model = False
+    args.only_query_model = False
+    model = biencoder_model_provider()
+    #model = biencoder_model_provider(
+    #            only_context_model=False,
+    #            only_query_model=False,
+    #            biencoder_shared_query_context_model=\
+    #            args.biencoder_shared_query_context_model)
 
     return model
 
 def get_group_world_size_rank():
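pretrain_ict.py and indexer.py now share the same set-flags-then-build pattern. If that pattern spreads, a small helper could keep the call sites consistent; configured_biencoder below is hypothetical, not part of this PR:

from megatron import get_args
from megatron.model.biencoder_model import biencoder_model_provider

def configured_biencoder(only_context_model=False, only_query_model=False):
    """Set the biencoder flags on global args, then build the model."""
    args = get_args()
    args.only_context_model = only_context_model
    args.only_query_model = only_query_model
    return biencoder_model_provider()

def pretrain_ict_model_provider():
    # Pretraining uses both towers, so both flags stay False.
    return configured_biencoder()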