Commit 160ba680 authored by mohammad

Added reload_model_params for finetuning

parent 43529f78
+29 −2
@@ -76,6 +76,10 @@ class MegatronOptimizer(ABC):
    def step(self):
        pass

    @abstractmethod
    def reload_model_params(self):
        pass

    @abstractmethod
    def state_dict(self):
        pass
@@ -243,8 +247,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        return found_inf_flag


    def _copy_master_params_to_model_params(self):
        # Only needed for the fp16 params.
    def _get_model_and_master_params_data_fp16(self):
        model_data = []
        master_data = []
        for model_group, master_group in zip(self.fp16_groups,
@@ -252,6 +255,12 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
            for model_param, master_param in zip(model_group, master_group):
                model_data.append(model_param.data)
                master_data.append(master_param.data)
        return model_data, master_data


    def _copy_master_params_to_model_params(self):
        # Only needed for the fp16 params.
        model_data, master_data = self._get_model_and_master_params_data_fp16()
        self._dummy_overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
@@ -259,6 +268,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                             [master_data, model_data],
                             1.0)

    def _copy_model_params_to_master_params(self):
        # Only needed for the fp16 params.
        model_data, master_data = self._get_model_and_master_params_data_fp16()
        self._dummy_overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             self._dummy_overflow_buf,
                             [model_data, master_data],
                             1.0)


    def reload_model_params(self):
        self._copy_model_params_to_master_params()
                

    @torch.no_grad()
    def step(self):
@@ -388,6 +411,10 @@ class FP32Optimizer(MegatronOptimizer):
        return True


    def reload_model_params(self):
        pass


    def state_dict(self):
        return self.optimizer.state_dict()
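
The refactor above routes both copy directions through the shared _get_model_and_master_params_data_fp16() helper and performs the actual copy with apex's multi_tensor_applier, calling amp_C.multi_tensor_scale with a scale factor of 1.0, which acts as a fused element-wise copy. As a rough sketch only, the same copy without the apex kernel could be written in plain PyTorch as below; copy_tensors_ is a hypothetical helper and is not part of this commit.

def copy_tensors_(dst_list, src_list):
    # In-place copy of every source tensor into the matching destination
    # tensor; equivalent to multi_tensor_scale with a scale of 1.0, minus
    # the fused kernel and the overflow buffer.
    for dst, src in zip(dst_list, src_list):
        dst.copy_(src)

# reload_model_params() direction (model -> master): refresh the fp32 master
# copies from the freshly loaded fp16 model parameters, e.g.
#   model_data, master_data = self._get_model_and_master_params_data_fp16()
#   copy_tensors_(master_data, model_data)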

+1 −2
@@ -256,8 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider,
        args.load = original_load
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16:
            optimizer._model_params_to_master_params()
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
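
With the abstract reload_model_params() in place, the call site above no longer needs an fp16-specific branch: FP16OptimizerWithFP16Params re-copies the just-loaded fp16 weights into its fp32 master copies, while FP32Optimizer treats the call as a no-op. A minimal sketch of the surrounding finetuning flow, where load_checkpoint, original_load, and the pretrained checkpoint path are assumed from the enclosing Megatron code and are not defined here:

# Temporarily point args.load at the pretrained checkpoint, then restore it.
args.load = pretrained_checkpoint_path   # assumed name for the pretrained path
load_checkpoint(model, optimizer, lr_scheduler)
args.load = original_load

# The checkpoint may restore only the model weights, so resynchronize the
# optimizer's master parameters; safe for both fp16 and fp32 optimizers.
optimizer.reload_model_params()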