Commit 3f3ba5e7 authored by Mostofa Patwary
Browse files

addressed the comments

parent 3d96729a
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -317,11 +317,10 @@ def setup_model_and_optimizer(model_provider_func):
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    unwrapped_model = unwrap_model(model)
    if args.iteration == 0 and hasattr(unwrapped_model,
                                       'init_state_dict_from_bert'):
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model.init_state_dict_from_bert()
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()