Files changed (2): megatron/optimizer/optimizer.py (+29 −2), tasks/finetune_utils.py (+1 −2)
megatron/optimizer/optimizer.py  +29 −2

@@ -76,6 +76,10 @@ class MegatronOptimizer(ABC):
     def step(self):
         pass
 
+    @abstractmethod
+    def reload_model_params(self):
+        pass
+
     @abstractmethod
     def state_dict(self):
         pass

@@ -243,8 +247,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return found_inf_flag
 
-    def _copy_master_params_to_model_params(self):
-        # Only needed for the fp16 params.
+    def _get_model_and_master_params_data_fp16(self):
         model_data = []
         master_data = []
         for model_group, master_group in zip(self.fp16_groups,

@@ -252,6 +255,12 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             for model_param, master_param in zip(model_group, master_group):
                 model_data.append(model_param.data)
                 master_data.append(master_param.data)
+        return model_data, master_data
+
+
+    def _copy_master_params_to_model_params(self):
+        # Only needed for the fp16 params.
+        model_data, master_data = self._get_model_and_master_params_data_fp16()
         self._dummy_overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
         multi_tensor_applier(amp_C.multi_tensor_scale,

@@ -259,6 +268,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                              [master_data, model_data],
                              1.0)
 
+
+    def _copy_model_params_to_master_params(self):
+        # Only needed for the fp16 params.
+        model_data, master_data = self._get_model_and_master_params_data_fp16()
+        self._dummy_overflow_buf.fill_(0)
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             self._dummy_overflow_buf,
+                             [model_data, master_data],
+                             1.0)
+
+
+    def reload_model_params(self):
+        self._copy_model_params_to_master_params()
+
     @torch.no_grad()
     def step(self):

@@ -388,6 +411,10 @@ class FP32Optimizer(MegatronOptimizer):
         return True
 
+    def reload_model_params(self):
+        pass
+
     def state_dict(self):
         return self.optimizer.state_dict()
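For readers without apex at hand, the copy that reload_model_params() triggers on the fp16 optimizer can be sketched in plain PyTorch; the fused amp_C.multi_tensor_scale call above performs the same scale-by-1.0 copy in fused kernels. This is an illustrative sketch, not code from the repository, and the parameter names are placeholders for the optimizer's fp16 and fp32-master parameter groups.

import torch

def copy_model_params_to_master_params(fp16_groups, fp32_master_groups):
    # Illustrative equivalent of FP16OptimizerWithFP16Params.reload_model_params():
    # push the (possibly freshly loaded) fp16 model weights into the fp32 master
    # copies so the next optimizer step starts from the loaded values.
    with torch.no_grad():
        for model_group, master_group in zip(fp16_groups, fp32_master_groups):
            for model_param, master_param in zip(model_group, master_group):
                # copy_ casts fp16 -> fp32; same effect as multi_tensor_scale with 1.0.
                master_param.data.copy_(model_param.data)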
tasks/finetune_utils.py  +1 −2

@@ -256,8 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider,
         args.load = original_load
         # This is critical when only model is loaded. We should make sure
         # master parameters are also updated.
-        if args.fp16:
-            optimizer._model_params_to_master_params()
+        optimizer.reload_model_params()
     timers('pretrained checkpoint').stop()
 
     # Print setup timing.
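Why the unconditional call matters: when only the model weights are reloaded from a checkpoint, the fp16 optimizer's fp32 master copies still hold the old values, and the next step() copies those masters back into the model, discarding the loaded weights; reload_model_params() is a real refresh for the fp16 optimizer and a no-op for FP32Optimizer, so the args.fp16 check is no longer needed. The snippet below is a toy, self-contained illustration of that failure mode and of the refresh that prevents it; it is not repository code.

import torch

# Stale fp32 master copy vs. freshly loaded fp16 model weights.
master = torch.zeros(3, dtype=torch.float32)                # pre-finetune values
model = torch.tensor([1., 2., 3.], dtype=torch.float16)    # weights loaded from checkpoint

master.copy_(model)   # what reload_model_params() does: fp16 -> fp32 refresh
model.copy_(master)   # what step() later does: master -> model copy
assert model.tolist() == [1., 2., 3.]   # loaded weights survive only because of the refresh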