megatron/arguments.py  +0 −6

@@ -479,12 +479,6 @@ def _add_learning_rate_args(parser):
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
-    group.add_argument('--override-lr-new', action='store_true',
-                       help='Reset the values of the scheduler (learning rate,'
-                       'warmup iterations, minimum learning rate, maximum '
-                       'number of iterations, and decay style from input '
-                       'arguments and ignore values from checkpoints. Note'
-                       'that all the above values will be reset.')
     group.add_argument('--override-lr-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
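Note: the deleted --override-lr-new flag duplicated what --override-lr-scheduler already does, so only the latter survives. For context, a minimal standalone sketch of how a store_true override flag of this kind gates scheduler values (the script and the checkpoint_lr variable are illustrative, not part of this diff):

import argparse

# Illustrative sketch, not code from this diff: --override-lr-scheduler is a
# plain store_true switch, so downstream code only needs one boolean to decide
# whether command-line scheduler values win over values restored from a
# checkpoint.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('learning rate')
group.add_argument('--lr', type=float, default=1.0e-4)
group.add_argument('--override-lr-scheduler', action='store_true',
                   help='Use scheduler values from the command line and '
                        'ignore values stored in a checkpoint.')

args = parser.parse_args(['--override-lr-scheduler'])
checkpoint_lr = 5.0e-5  # stand-in for a value loaded from a checkpoint
lr = args.lr if args.override_lr_scheduler else checkpoint_lr
print(lr)  # 0.0001: the command-line value wins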
megatron/checkpointing.py  +0 −1

@@ -419,7 +419,6 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     assert len(model) == 1
     model[0].load_state_dict(ret_state_dict)
-    torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
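Note: torch.distributed.barrier() blocks every rank until all ranks reach the call, so the deleted line forced a global synchronization point between loading the state dict and the rank-0-only block. A minimal sketch of that pattern in isolation (a hypothetical two-process demo, not this repo's code):

import torch.distributed as dist

# Hypothetical demo of barrier semantics; run with:
#   torchrun --nproc_per_node=2 barrier_demo.py
def main():
    dist.init_process_group(backend='gloo')
    # Every rank blocks here until all ranks have arrived.
    dist.barrier()
    if dist.get_rank() == 0:
        # Work that only one rank should do once everyone is synchronized,
        # e.g. logging or writing a file.
        print('rank 0 proceeding after the barrier')
    dist.destroy_process_group()

if __name__ == '__main__':
    main()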
megatron/indexer.py  +5 −34

@@ -26,13 +26,10 @@ class IndexBuilder(object):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
-        #self.pre_process = True
-        #self.post_process = True

         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
         assert not (args.load and args.ict_load)
-        #self.using_realm_chkpt = args.ict_load is None

         self.log_interval = args.indexer_log_interval
         self.batch_size = args.indexer_batch_size

@@ -46,24 +43,13 @@ class IndexBuilder(object):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
-        #args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
-        #args.only_context_model = only_context_model
-        #args.only_query_model = False
-        #model = get_model(biencoder_model_provider)
-        model = get_model(get_model_provider(only_context_model=only_context_model,
-            biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
-        #model = get_model(lambda: biencoder_model_provider(only_context_model \
-        #model = get_model(lambda: biencoder_model_provider(only_context_model \
-        #    = only_context_model, biencoder_shared_query_context_model = \
-        #    self.biencoder_shared_query_context_model,
-        #    pre_process=True, post_process=True)
+        model = get_model(get_model_provider(only_context_model=\
+            only_context_model, biencoder_shared_query_context_model=\
+            self.biencoder_shared_query_context_model))

         self.model = load_biencoder_checkpoint(model,
             only_context_model=only_context_model)

@@ -103,12 +89,7 @@ class IndexBuilder(object):
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module

-        #counter = 0
-        #start_time = time.time()
-        #cur_time = start_time
         while True:
-            #start_time = time.time()
-            #t1 = time.time()
             try:
                 # batch also has query_tokens and query_pad_data
                 row_id, context_tokens, context_mask, context_types, \

@@ -117,8 +98,6 @@ class IndexBuilder(object):
             except (StopIteration, IndexError):
                 break

-            #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            #t2 = time.time()
             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool

@@ -128,18 +107,10 @@ class IndexBuilder(object):
             context_logits = detach(context_logits)
             row_id = detach(row_id)
-            #print_rank_0("embed text {}".format(cur_time - time.time()))
-            #t3 = time.time()
             self.evidence_embedder_obj.add_block_data(row_id, context_logits)
             self.track_and_report_progress(batch_size=len(row_id))
-            #print_rank_0("add block time {}".format(cur_time - time.time()))
-            #t4 = time.time()
-            #counter += 1
-            #if counter % 1000 == 0:
-            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
-            #    cur_time = time.time()

         # This process signals to finalize its shard and then synchronize with
         # the other processes
         self.evidence_embedder_obj.save_shard()
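Note: the loop above still carries a TODO about wrapping the forward pass in torch.no_grad(). Since index building is pure inference, that is the standard way to stop autograd from recording the graph and to cut activation memory; it would also make the detach() calls unnecessary. A minimal runnable sketch of the pattern (the encoder below is a stand-in, not the biencoder from this repo):

import torch
import torch.nn as nn

# Stand-in encoder; only the no_grad pattern is the point.
class DummyEncoder(nn.Module):
    def __init__(self, vocab=100, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)

    def embed_text(self, tokens, mask):
        hidden = self.emb(tokens)
        mask = mask.unsqueeze(-1).float()
        # Mean-pool over unmasked positions.
        return (hidden * mask).sum(1) / mask.sum(1).clamp(min=1.0)

model = DummyEncoder().eval()
tokens = torch.randint(0, 100, (4, 8))
mask = torch.ones(4, 8, dtype=torch.bool)

with torch.no_grad():
    context_logits = model.embed_text(tokens, mask)

# Tensors produced under no_grad already have requires_grad=False,
# so no detach() would be needed before storing them in BlockData.
print(context_logits.shape, context_logits.requires_grad)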
megatron/learning_rates.py  +0 −16

@@ -18,7 +18,6 @@
 import math

 from megatron import print_rank_0
-from megatron import get_args


 class AnnealingLR(object):
     """Anneals the learning rate."""

@@ -60,7 +59,6 @@ class AnnealingLR(object):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

-        #print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))

         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \

@@ -90,20 +88,6 @@ class AnnealingLR(object):
             raise Exception('{} decay style is not supported.'.format(
                 self.decay_style))

-        args = get_args()
-        if args.override_lr_new:
-            mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps)
-            mod_num_steps_ = mod_num_steps_ - self.warmup_steps
-            use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps)
-            should_use_lr = self.min_lr + coeff * delta_lr
-            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
-        else:
-            use_lr = self.min_lr + coeff * delta_lr
-            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
-        return use_lr
         return self.min_lr + coeff * delta_lr
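Note: with the debug block gone, get_lr reduces to linear warmup followed by a decay-style-dependent interpolation between max_lr and min_lr. A self-contained sketch of that schedule (coeff and delta_lr are not shown in these hunks; the sketch assumes the usual definitions delta_lr = max_lr - min_lr and, for cosine decay, coeff = 0.5 * (cos(pi * decay_ratio) + 1)):

import math

# Sketch of the cleaned-up schedule: linear warmup, then cosine decay
# from max_lr down to min_lr. The coeff/delta_lr definitions are assumed,
# not taken verbatim from this diff.
def get_lr(num_steps, warmup_steps, decay_steps, max_lr, min_lr):
    # Linear warmup for the initial part.
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    # Clamp into the decay window, measured after warmup.
    num_steps_ = min(num_steps, decay_steps) - warmup_steps
    decay_steps_ = decay_steps - warmup_steps
    decay_ratio = float(num_steps_) / float(decay_steps_)
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    delta_lr = max_lr - min_lr
    return min_lr + coeff * delta_lr

# Halfway through the decay window the cosine coefficient is 0.5:
print(get_lr(55000, 10000, 100000, 1.0e-4, 1.0e-6))  # ~5.05e-05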
tasks/finetune_utils.py  +2 −1

@@ -114,7 +114,8 @@ def _build_infinite_size_dataloader(dataloader):
             iterator = dataloader.__iter__()


-def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset,
+                                   task_collate_fn=None):
     """Traing and validation dataloaders."""
     args = get_args()