Commit 3d96729a authored by Mostofa Patwary
Browse files

fixed an issue with ICT

parent 816fb890
Loading
Loading
Loading
Loading
+7 −8
Original line number Diff line number Diff line
@@ -317,12 +317,11 @@ def setup_model_and_optimizer(model_provider_func):
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    model = unwrap_model(model)
    for module in model:
        if args.iteration == 0 and hasattr(module,
    unwrapped_model = unwrap_model(model)
    if args.iteration == 0 and hasattr(unwrapped_model,
                                       'init_state_dict_from_bert'):
            print("Initializing ICT from pretrained BERT model", flush=True)
            module.init_state_dict_from_bert()
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model.init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()