Commit bfc20ecf authored by Mostofa Patwary

fixed issue when initializing ICT from pretrained BERT model

parent 0295bb89
+15 −6
@@ -320,6 +320,8 @@ def setup_model_and_optimizer(model_provider_func):
                                       'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        unwrapped_model.init_state_dict_from_bert()
        if args.fp16:
            optimizer._model_params_to_master_params()

    return model, optimizer, lr_scheduler

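The two lines added in this hunk address the issue named in the commit message: when the ICT model's fp16 weights are overwritten with pretrained BERT weights, the optimizer's fp32 master copies still hold the old values, so the next step would clobber the freshly loaded weights. Below is a minimal stand-alone sketch of that pattern in plain PyTorch (not Megatron's FP16_Optimizer; the final copy loop plays the role of `_model_params_to_master_params`).

```python
import torch

model = torch.nn.Linear(4, 4).half()                                # fp16 "model" parameters
master = [p.detach().clone().float() for p in model.parameters()]   # fp32 master copies
opt = torch.optim.Adam(master, lr=1e-3)                             # optimizer steps on the masters

# Stand-in for unwrapped_model.init_state_dict_from_bert(): overwrite the
# fp16 model weights in place with externally loaded values.
with torch.no_grad():
    for p in model.parameters():
        p.copy_(torch.randn_like(p))

# Without this re-sync (the job of _model_params_to_master_params), the next
# optimizer step would copy the stale fp32 masters back over the loaded weights.
with torch.no_grad():
    for m, p in zip(master, model.parameters()):
        m.copy_(p.float())
```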
@@ -646,6 +648,7 @@ def train_step(forward_step_func, data_iterator,
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()
    grad_norm_local = None

    # Clipping gradients helps prevent the exploding gradient.
    timers('backward-clip-grad').start()
@@ -660,16 +663,16 @@ def train_step(forward_step_func, data_iterator,
            mpu.clip_grad_norm(parameters, args.clip_grad,
                               parameter_names=parameter_names)
        else:
            optimizer.clip_master_grads(args.clip_grad)
            grad_norm_local = optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()

    #print_rank_0("print-grad_norm_local {}".format(grad_norm_local))
    
    #print_rank_0("after backward")
    #print_grads(model)
    print_model(model)
    print_rank_0(params_global_norm(model))
    print_rank_0(params_grad_norm(model))

    #print_model(model)
    #print_rank_0(params_global_norm(model))
    #print_rank_0(params_grad_norm(model))

    # Update parameters.
    timers('optimizer').start()
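The change in this hunk keeps the value returned by `clip_master_grads` in `grad_norm_local` instead of discarding it; following the apex-style convention, that call wraps gradient-norm clipping and returns the total norm, which the (commented-out) debug print can then report. A plain-PyTorch sketch of the same idea:

```python
import torch

model = torch.nn.Linear(8, 8)
loss = model(torch.randn(2, 8)).pow(2).sum()
loss.backward()

# clip_grad_norm_ clips in place and returns the total (pre-clip) gradient norm,
# which is what grad_norm_local holds in the diff above.
grad_norm_local = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"grad norm before clipping: {grad_norm_local:.4f}")
```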
@@ -678,10 +681,12 @@ def train_step(forward_step_func, data_iterator,

    #print_rank_0("after optimizer")
    #print_model(model)
    print_rank_0(params_global_norm(model))
    #print_rank_0(params_global_norm(model))
    #print_rank_0(params_grad_norm(model))
    #sys.exit()
    
    #print_rank_0("print-optimizer.overflow {}".format(optimizer.overflow))

    # Update learning rate.
    skipped_iter = 0
    if not (args.fp16 and optimizer.overflow):
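For context on the surrounding logic: with fp16 dynamic loss scaling, an inf/nan in the gradients sets `optimizer.overflow`, and the parameter and learning-rate update is skipped for that iteration (counted via `skipped_iter`). A self-contained sketch of that pattern, using a hypothetical `grads_have_overflow` helper rather than Megatron's actual overflow check:

```python
import torch

def grads_have_overflow(params):
    """True if any gradient contains inf/nan -- what dynamic loss scaling checks."""
    return any(p.grad is not None and not torch.isfinite(p.grad).all() for p in params)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(2, 4)).sum().backward()

skipped_iter = 0
if not grads_have_overflow(model.parameters()):   # stands in for optimizer.overflow
    optimizer.step()                               # parameters (and the LR schedule) advance
else:
    skipped_iter = 1                               # skip the step; the loss scale would be reduced
```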
@@ -856,6 +861,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    # Iterations.
    iteration = args.iteration

    #print_rank_0("Check betas before iterations")
    #for group in optimizer.optimizer.param_groups:
    #    print_rank_0("betas {} lr {} weight_decay {} eps {}".format(group['betas'], group['lr'], group['weight_decay'], group['eps']))

    timers('interval time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
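The commented-out block above is a sanity check of the optimizer hyperparameters before the training loop starts. A hypothetical helper with the same effect, shown here against a stock `torch.optim.Adam` rather than Megatron's wrapped optimizer:

```python
import torch

def print_param_group_hparams(optimizer):
    """Dump per-group hyperparameters, mirroring the commented-out check above."""
    for i, group in enumerate(optimizer.param_groups):
        print(f"group {i}: betas {group['betas']} lr {group['lr']} "
              f"weight_decay {group['weight_decay']} eps {group['eps']}")

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999),
                             weight_decay=0.01, eps=1e-8)
print_param_group_hparams(optimizer)
```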
+6 −6
@@ -109,13 +109,13 @@ def forward_step(data_iterator, model, input_tensor):

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
    #global_batch_size = dist.get_world_size() * micro_batch_size
    #all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    #all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
    
    global_batch_size = micro_batch_size
    all_query_logits = query_logits
    all_context_logits = context_logits
    #global_batch_size = micro_batch_size
    #all_query_logits = query_logits
    #all_context_logits = context_logits

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
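This hunk re-enables the all-gather of query and context embeddings across data-parallel ranks, so the retrieval softmax scores `global_batch_size = world_size * micro_batch_size` candidates instead of only the local micro-batch. The sketch below shows a single-process version of the in-batch objective this feeds into, assuming (as is standard for ICT-style training) that the positive context for each query sits on the diagonal of the score matrix; the distributed gather is omitted.

```python
import torch
import torch.nn.functional as F

batch_size, hidden = 8, 128
query_logits = torch.randn(batch_size, hidden)      # stand-ins for the gathered embeddings
context_logits = torch.randn(batch_size, hidden)

# scores are inner products between query and context embeddings
retrieval_scores = torch.matmul(query_logits, context_logits.t())   # [batch, batch]
labels = torch.arange(batch_size)                    # positive pair sits on the diagonal
loss = F.cross_entropy(retrieval_scores, labels)
print(loss.item())
```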