megatron/arguments.py  +6 −0

@@ -478,6 +478,12 @@ def _add_learning_rate_args(parser):
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
+    group.add_argument('--override-lr-new', action='store_true',
+                       help='Reset the values of the scheduler (learning rate, '
+                       'warmup iterations, minimum learning rate, maximum '
+                       'number of iterations, and decay style) from input '
+                       'arguments and ignore values from checkpoints. Note '
+                       'that all the above values will be reset.')
     group.add_argument('--override-lr-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
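A quick standalone check of how the new store_true flag behaves once parsed; a minimal sketch using bare argparse, with the group wiring from _add_learning_rate_args elided and the help text shortened:

import argparse

# same shape as the new flag above
parser = argparse.ArgumentParser()
parser.add_argument('--override-lr-new', action='store_true',
                    help='Reset scheduler values from input arguments.')

# dashes in the flag name become underscores on the namespace
assert parser.parse_args(['--override-lr-new']).override_lr_new
assert not parser.parse_args([]).override_lr_new   # defaults to False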
megatron/checkpointing.py  +5 −2

@@ -413,8 +413,11 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     if only_context_model:
         ret_state_dict.pop('query_model')

-    assert len(model) == 1
-    model[0].load_state_dict(ret_state_dict)
+    #print_rank_0(len(model))
+    #sys.exit()
+    #assert len(model) == 1
+    #model[0].load_state_dict(ret_state_dict)
+    model.load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
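This hunk changes the calling convention rather than the checkpoint format: load_biencoder_checkpoint previously received the one-element list that get_model() returns, and now receives the bare module that biencoder_model_provider() returns (see megatron/indexer.py below). A minimal sketch of the two conventions, with DummyModel as a hypothetical stand-in for the biencoder:

import torch

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

state = DummyModel().state_dict()

# before: get_model() wraps the module in a list of model chunks
model_as_list = [DummyModel()]
model_as_list[0].load_state_dict(state)

# after: the provider returns the bare module, so no unwrapping is needed
model = DummyModel()
model.load_state_dict(state)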
megatron/indexer.py  +32 −7

@@ -2,7 +2,7 @@
 import sys
 import torch
 import torch.distributed as dist

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import mpu
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset

@@ -25,6 +25,8 @@ class IndexBuilder(object):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
+        self.pre_process = True
+        self.post_process = True

         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint

@@ -47,15 +49,22 @@ class IndexBuilder(object):
         if self.biencoder_shared_query_context_model:
             only_context_model = False

-        model = get_model(lambda: biencoder_model_provider(only_context_model \
-            = only_context_model, biencoder_shared_query_context_model = \
-            self.biencoder_shared_query_context_model))
+        #model = get_model(lambda: biencoder_model_provider(only_context_model \
+        #    = only_context_model, biencoder_shared_query_context_model = \
+        #    self.biencoder_shared_query_context_model, \
+        #    pre_process=self.pre_process, post_process=self.post_process))
+        model = biencoder_model_provider(only_context_model \
+            = only_context_model, biencoder_shared_query_context_model = \
+            self.biencoder_shared_query_context_model, \
+            pre_process=self.pre_process, post_process=self.post_process)

         self.model = load_biencoder_checkpoint(model,
             only_context_model=only_context_model)

-        assert len(self.model) == 1
-        self.model[0].eval()
+        #assert len(self.model) == 1
+        #self.model[0].eval()
+        self.model.eval()

         self.dataset = get_open_retrieval_wiki_dataset()
         self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \

@@ -83,10 +92,12 @@ class IndexBuilder(object):
         distributed setting will be consolidated by the rank 0 process
         and saved as a final pickled BlockData.
         """
-        assert len(self.model) == 1
-        unwrapped_model = self.model[0]
+        #assert len(self.model) == 1
+        #unwrapped_model = self.model[0]
+        unwrapped_model = self.model

         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
+        print_rank_0("hasattr")

         while True:
             try:

@@ -97,12 +108,26 @@
             except (StopIteration, IndexError):
                 break

+            print_rank_0(context_tokens)
+            print_rank_0(context_mask)
+            print_rank_0(context_types)
+            #if torch.cuda.is_available():
+            #    print_rank_0("cuda available")
+            #print_rank_0(torch.cuda.current_device())
+            #print_rank_0(torch.cuda.get_device_name())
+            print_rank_0(next(unwrapped_model.parameters()).device)
+            print_rank_0(next(unwrapped_model.context_model.parameters()).device)
+            #print_rank_0("After get_open_retrieval_batch")

             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool
             context_logits = unwrapped_model.embed_text(
                 unwrapped_model.context_model, context_tokens, context_mask,
                 context_types)
+            sys.exit()

             context_logits = detach(context_logits)
             row_id = detach(row_id)
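On the TODO kept above ("can we add with torch.no_grad() to reduce memory usage"): yes, for inference-only embedding the forward pass can run under torch.no_grad(), so autograd records no graph and activations are freed eagerly. A minimal sketch of the pattern, with embed_batch as a hypothetical stand-in for the embed_text call:

import torch

def embed_batch(model, tokens):
    # no graph is recorded, so per-layer activations are freed immediately
    with torch.no_grad():
        return model(tokens)

model = torch.nn.Linear(8, 8).eval()
out = embed_batch(model, torch.randn(2, 8))
assert not out.requires_grad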
megatron/learning_rates.py  +17 −1

@@ -18,6 +18,7 @@
 import math

 from megatron import print_rank_0
+from megatron import get_args

 class AnnealingLR(object):
     """Anneals the learning rate."""

@@ -59,6 +60,7 @@ class AnnealingLR(object):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

+        #print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))
         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \

@@ -88,6 +90,20 @@ class AnnealingLR(object):
             raise Exception('{} decay style is not supported.'.format(
                 self.decay_style))

+        args = get_args()
+        if args.override_lr_new:
+            mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps)
+            mod_num_steps_ = mod_num_steps_ - self.warmup_steps
+            use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps)
+            should_use_lr = self.min_lr + coeff * delta_lr
+            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
+        else:
+            use_lr = self.min_lr + coeff * delta_lr
+            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
+
+        return use_lr
+
         return self.min_lr + coeff * delta_lr
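Two things worth noting in the last hunk. First, the new `return use_lr` makes the original trailing `return self.min_lr + coeff * delta_lr` unreachable. Second, the override branch scales delta_lr directly and never adds self.min_lr back, so just after warmup it yields max_lr − min_lr where the old path (kept as should_use_lr for comparison) yields max_lr. A standalone sketch of the two formulas; the helper names and sample numbers are hypothetical, delta_lr and the linear coeff follow the surrounding AnnealingLR code, and both helpers assume num_steps > warmup_steps, as the caller guarantees:

def override_lr_new(num_steps, warmup_steps, decay_steps, max_lr, min_lr):
    # the new branch: delta_lr scaled by a linear factor of the clipped step count
    delta_lr = max_lr - min_lr
    mod_num_steps = min(num_steps, decay_steps - warmup_steps) - warmup_steps
    return delta_lr * float(decay_steps - mod_num_steps) / float(decay_steps)

def linear_decay(num_steps, warmup_steps, decay_steps, max_lr, min_lr):
    # the pre-existing path for the linear decay style
    delta_lr = max_lr - min_lr
    coeff = 1.0 - float(num_steps - warmup_steps) / float(decay_steps - warmup_steps)
    return min_lr + coeff * delta_lr

for step in (2001, 150000, 299999):
    print(step,
          override_lr_new(step, 2000, 300000, 1.5e-4, 1e-5),
          linear_decay(step, 2000, 300000, 1.5e-4, 1e-5))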
megatron/model/biencoder_model.py  +4 −0

@@ -266,6 +266,10 @@ class PretrainedBertModel(MegatronModule):
         #extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)

+        print_rank_0(input_ids.device)
+        print_rank_0(position_ids.device)
+        print_rank_0(extended_attention_mask.device)
+        print_rank_0(tokentype_ids.device)

         lm_output = self.language_model(input_ids,
                                         position_ids,
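The four prints above chase a device mismatch between the batch tensors and the model weights. A compact equivalent of what they check, with check_devices as a hypothetical helper rather than part of the model:

import torch

def check_devices(model, *tensors):
    # report any input tensor that is not on the same device as the weights
    param_device = next(model.parameters()).device
    return [t.device for t in tensors if t.device != param_device]

model = torch.nn.Linear(4, 4)         # weights on CPU
x = torch.randn(1, 4)
assert check_devices(model, x) == []  # inputs and weights agree here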