Commit dfd428cb authored by Jared Casper
Browse files

Merge branch 'main_retriver_merge' into 'main'

ICT Retriever

See merge request ADLR/megatron-lm!235
parents 68233932 3f3ba5e7
Loading
Loading
Loading
Loading
+6 −8
Original line number Diff line number Diff line
@@ -317,12 +317,10 @@ def setup_model_and_optimizer(model_provider_func):
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    model = unwrap_model(model)
    for module in model:
        if args.iteration == 0 and hasattr(module,
                                           'init_state_dict_from_bert'):
            print("Initializing ICT from pretrained BERT model", flush=True)
            module.init_state_dict_from_bert()
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()