megatron/training.py (+6 −8)

```diff
@@ -317,12 +317,10 @@ def setup_model_and_optimizer(model_provider_func):
         assert args.DDP_impl == 'local'

     # get model without FP16 and/or TorchDDP wrappers
-    model = unwrap_model(model)
-    for module in model:
-        if args.iteration == 0 and hasattr(module, 'init_state_dict_from_bert'):
-            print("Initializing ICT from pretrained BERT model", flush=True)
-            module.init_state_dict_from_bert()
+    if args.iteration == 0 and len(unwrapped_model) == 1 \
+        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
+        print_rank_0("Initializing ICT from pretrained BERT model")
+        unwrapped_model[0].init_state_dict_from_bert()
         if args.fp16:
             optimizer.reload_model_params()
```
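The new guard only back-initializes ICT from a pretrained BERT checkpoint when the unwrapped model is a single chunk (i.e. no pipeline split), replacing the old loop over every module, and it swaps the raw `print(..., flush=True)` for `print_rank_0` so the message is logged once rather than by every rank. For context, here is a minimal sketch of the kind of `unwrap_model` helper this code relies on; the signature and the wrapper classes it strips are assumptions for illustration, not copied from the PR (Megatron-LM's actual helper takes the wrapper classes, e.g. its DDP and FP16 wrappers, as an argument).

```python
from torch.nn.parallel import DistributedDataParallel as torchDDP


def unwrap_model(model, module_instances=(torchDDP,)):
    """Peel wrapper modules off each model chunk (sketch, not Megatron's exact helper).

    `model` may be a single nn.Module or a list of per-pipeline-stage
    chunks; each wrapper is assumed to expose the wrapped model as
    `.module` (as DDP and FP16 wrappers conventionally do).
    """
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped = []
    for chunk in model:
        while isinstance(chunk, module_instances):
            chunk = chunk.module  # descend through one wrapper layer
        unwrapped.append(chunk)
    return unwrapped if return_list else unwrapped[0]
```

With a helper of this shape, the guard in the diff reads naturally: `unwrap_model(model)` yields a list with one entry per pipeline stage, so `len(unwrapped_model) == 1` holds only when the whole network sits in `unwrapped_model[0]`, the only case where an `init_state_dict_from_bert` method could apply to the full model.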