megatron/training.py (+15 −6)

@@ -320,6 +320,8 @@ def setup_model_and_optimizer(model_provider_func):
                       'init_state_dict_from_bert'):
             print("Initializing ICT from pretrained BERT model", flush=True)
             unwrapped_model.init_state_dict_from_bert()
+    if args.fp16:
+        optimizer._model_params_to_master_params()
 
     return model, optimizer, lr_scheduler

@@ -646,6 +648,7 @@ def train_step(forward_step_func, data_iterator,
     if args.fp16:
         optimizer.update_master_grads()
     timers('backward-master-grad').stop()
+    grad_norm_local = None
 
     # Clipping gradients helps prevent the exploding gradient.
     timers('backward-clip-grad').start()

@@ -660,16 +663,16 @@ def train_step(forward_step_func, data_iterator,
             mpu.clip_grad_norm(parameters, args.clip_grad,
                                parameter_names=parameter_names)
         else:
-            optimizer.clip_master_grads(args.clip_grad)
+            grad_norm_local = optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()
+    #print_rank_0("print-grad_norm_local {}".format(grad_norm_local))
 
     #print_rank_0("after backward")
     #print_grads(model)
-    print_model(model)
-    print_rank_0(params_global_norm(model))
-    print_rank_0(params_grad_norm(model))
+    #print_model(model)
+    #print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
 
     # Update parameters.
     timers('optimizer').start()

@@ -678,10 +681,12 @@ def train_step(forward_step_func, data_iterator,
     #print_rank_0("after optimizer")
     #print_model(model)
-    print_rank_0(params_global_norm(model))
+    #print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
+    #sys.exit()
+    #print_rank_0("print-optimizer.overflow {}".format(optimizer.overflow))
 
     # Update learning rate.
     skipped_iter = 0
     if not (args.fp16 and optimizer.overflow):

@@ -856,6 +861,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
 
     # Iterations.
     iteration = args.iteration
+    #print_rank_0("Check betas before iterations")
+    #for group in optimizer.optimizer.param_groups:
+    #    print_rank_0("betas {} lr {} weight_decay {} eps {}".format(group['betas'], group['lr'], group['weight_decay'], group['eps']))
 
     timers('interval time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
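Two functional changes in this file are easy to miss among the commented-out debug prints. First, when --fp16 is enabled, the optimizer is asked to refresh its fp32 master parameters right after the ICT model is initialized from the pretrained BERT checkpoint; otherwise the first optimizer step would overwrite the freshly loaded fp16 weights with stale master copies. Second, the return value of clip_master_grads (the gradient norm in apex-style FP16 optimizers) is now captured in grad_norm_local instead of being discarded. Below is a minimal sketch of what the master-parameter sync amounts to, assuming an apex/Megatron-style FP16 optimizer that keeps fp16_groups and matching fp32_from_fp16_groups; the names and layout are illustrative, not the actual implementation behind _model_params_to_master_params().

# Sketch only: refresh the optimizer's fp32 master copies from the current
# fp16 model weights after the model has been (re)initialized.
# Group names (fp16_groups, fp32_from_fp16_groups) are illustrative.
def model_params_to_master_params(fp16_groups, fp32_from_fp16_groups):
    for model_group, master_group in zip(fp16_groups, fp32_from_fp16_groups):
        for model_param, master_param in zip(model_group, master_group):
            # copy_ casts the half-precision weights up to fp32 in place
            master_param.data.copy_(model_param.data)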
pretrain_ict.py (+6 −6)

@@ -109,13 +109,13 @@ def forward_step(data_iterator, model, input_tensor):
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
-    #global_batch_size = dist.get_world_size() * micro_batch_size
-    #all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    #all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
+    global_batch_size = dist.get_world_size() * micro_batch_size
+    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
 
-    global_batch_size = micro_batch_size
-    all_query_logits = query_logits
-    all_context_logits = context_logits
+    #global_batch_size = micro_batch_size
+    #all_query_logits = query_logits
+    #all_context_logits = context_logits
 
     # scores are inner products between query and context embeddings
     retrieval_scores = torch.matmul(all_query_logits,
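This change re-enables gathering the query and context embeddings across the data-parallel group before computing retrieval scores, so each rank's softmax sees global_batch_size = world_size * micro_batch_size candidates (in-batch negatives drawn from every rank) instead of only its local micro-batch. Below is a minimal sketch of the usual pattern behind an op like AllgatherFromDataParallelRegion: an autograd.Function whose forward all-gathers along the batch dimension and whose backward hands each rank the gradient slice for the shard it contributed. This is illustrative only, assumes the default process group, and may differ from the actual Megatron implementation.

import torch
import torch.distributed as dist


class AllgatherWithGrad(torch.autograd.Function):
    """Differentiable all-gather sketch (illustrative, not the Megatron op)."""

    @staticmethod
    def forward(ctx, tensor):
        world_size = dist.get_world_size()
        ctx.batch_size = tensor.shape[0]
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor.contiguous())
        # Concatenate along the batch dimension: (world_size * B, hidden).
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank keeps only the gradient slice for the shard it contributed.
        start = dist.get_rank() * ctx.batch_size
        return grad_output[start:start + ctx.batch_size]


# Usage in the spirit of forward_step (illustrative):
#   all_query_logits = AllgatherWithGrad.apply(query_logits)
#   all_context_logits = AllgatherWithGrad.apply(context_logits)
#   retrieval_scores = torch.matmul(all_query_logits, all_context_logits.t())
#   targets = torch.arange(retrieval_scores.shape[0], device=retrieval_scores.device)
#   loss = torch.nn.functional.cross_entropy(retrieval_scores, targets)

With the gathered logits, retrieval_scores is a (global_batch_size x global_batch_size) matrix whose diagonal entries correspond to the positive query/context pairs, which is why the cross-entropy targets above are simply the row indices.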