Commit 7381754e authored by Jared Casper

Merge branch 'optimize_refactor_reload_model_params' into 'optimize_refactor'

Optimize refactor reload model params

See merge request ADLR/megatron-lm!199
parents 43529f78 9eedf896
+88 −60
@@ -45,6 +45,18 @@ def _zero_grad_group_helper(group, set_to_none):
                param.grad.zero_()


+def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
+    """Use multi-tensor-applier to copy values from one list to another."""
+    if overflow_buf:
+        overflow_buf.fill_(0)
+    else:
+        overflow_buf = torch.cuda.IntTensor([0])
+    # Scaling with factor `1.0` is equivalent to copy.
+    multi_tensor_applier(amp_C.multi_tensor_scale,
+                         overflow_buf,
+                         [this, that],
+                         1.0)
+

class MegatronOptimizer(ABC):

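Aside, not part of the patch: multi_tensor_applier with amp_C.multi_tensor_scale and a factor of 1.0 is Apex's fused way of copying a list of tensors into a matching list. A rough plain-PyTorch sketch of the same semantics (no fused kernel, no overflow buffer; the names below are illustrative only):

import torch

def copy_this_to_that_reference(this, that):
    # Reference behaviour of _multi_tensor_copy_this_to_that: copy each tensor
    # in `this` into the matching tensor in `that`, which is what
    # multi_tensor_scale does with a scale factor of 1.0, minus the single
    # fused kernel launch and the overflow tracking.
    assert len(this) == len(that)
    for src, dst in zip(this, that):
        dst.copy_(src)  # copy_ also handles fp16 <-> fp32 conversion

# Illustrative usage: sync fp16 model values into their fp32 "main" copies.
model_vals = [torch.ones(4, dtype=torch.float16)]
main_vals = [torch.zeros(4, dtype=torch.float32)]
copy_this_to_that_reference(model_vals, main_vals)
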
@@ -76,6 +88,10 @@ class MegatronOptimizer(ABC):
    def step(self):
        pass

+    @abstractmethod
+    def reload_model_params(self):
+        pass
+
    @abstractmethod
    def state_dict(self):
        pass
@@ -123,7 +139,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # ======================
-        # master parameter stuff
+        # main parameter stuff
        # ======================

        # Three groups of parameters:
@@ -147,20 +163,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                    if param.type() == 'torch.cuda.HalfTensor':
                        fp16_params_this_group.append(param)
                        # Create a copy
-                        master_param = param.detach().clone().float()
+                        main_param = param.detach().clone().float()
                        # Store grads
-                        master_param.requires_grad = True
+                        main_param.requires_grad = True
                        # Copy tensor model parallel attributes.
-                        mpu.copy_tensor_model_parallel_attributes(master_param,
+                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
                        if hasattr(param, 'shared'):
-                            master_param.shared = param.shared
+                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
-                        param_group['params'][i] = master_param
-                        fp32_from_fp16_params_this_group.append(master_param)
-                        # Reset existing state dict key to the new master param.
+                        param_group['params'][i] = main_param
+                        fp32_from_fp16_params_this_group.append(main_param)
+                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
-                            self.optimizer.state[master_param] \
+                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)

                    # fp32 params.
@@ -196,43 +212,39 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        return self.grad_scaler.scale


-    def _copy_model_grads_to_master_grads(self):
+    def _copy_model_grads_to_main_grads(self):
        # This only needs to be done for the fp16 group.
        model_grads = []
-        master_grads = []
-        for model_group, master_group in zip(self.fp16_groups,
+        main_grads = []
+        for model_group, main_group in zip(self.fp16_groups,
                                           self.fp32_from_fp16_groups):
-            for model_param, master_param in zip(model_group, master_group):
+            for model_param, main_param in zip(model_group, main_group):
                if model_param.grad is not None:
-                    if master_param.grad is None:
-                        master_param.grad = torch.empty_like(master_param)
+                    if main_param.grad is None:
+                        main_param.grad = torch.empty_like(main_param)
                    model_grads.append(model_param.grad.data)
-                    master_grads.append(master_param.grad.data)
-        self._dummy_overflow_buf.fill_(0)
-        # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale,
-                             self._dummy_overflow_buf,
-                             [model_grads, master_grads],
-                             1.0)
+                    main_grads.append(main_param.grad.data)
+        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
+                                        overflow_buf=self._dummy_overflow_buf)


-    def _unscale_master_grads_and_check_for_nan(self):
-        master_grads = []
+    def _unscale_main_grads_and_check_for_nan(self):
+        main_grads = []
        # fp32 params fromm fp16 ones.
-        for master_group in self.fp32_from_fp16_groups:
-            for master_param in master_group:
-                if master_param.grad is not None:
-                    master_grads.append(master_param.grad.data)
+        for main_group in self.fp32_from_fp16_groups:
+            for main_param in main_group:
+                if main_param.grad is not None:
+                    main_grads.append(main_param.grad.data)
        # Append fp32 parameters.
-        for master_group in self.fp32_from_fp32_groups:
-            for master_param in master_group:
-                if master_param.grad is not None:
-                    master_grads.append(master_param.grad.data)
+        for main_group in self.fp32_from_fp32_groups:
+            for main_param in main_group:
+                if main_param.grad is not None:
+                    main_grads.append(main_param.grad.data)
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
-            master_grads, self.found_inf, self.grad_scaler.inv_scale)
+            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
@@ -243,21 +255,33 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        return found_inf_flag


-    def _copy_master_params_to_model_params(self):
-        # Only needed for the fp16 params.
+    def _get_model_and_main_params_data_fp16(self):
        model_data = []
-        master_data = []
-        for model_group, master_group in zip(self.fp16_groups,
+        main_data = []
+        for model_group, main_group in zip(self.fp16_groups,
                                           self.fp32_from_fp16_groups):
-            for model_param, master_param in zip(model_group, master_group):
+            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
-                master_data.append(master_param.data)
-        self._dummy_overflow_buf.fill_(0)
-        # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale,
-                             self._dummy_overflow_buf,
-                             [master_data, model_data],
-                             1.0)
+                main_data.append(main_param.data)
+        return model_data, main_data


+    def _copy_main_params_to_model_params(self):
+        # Only needed for the fp16 params.
+        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
+                                        overflow_buf=self._dummy_overflow_buf)
+
+
+    def _copy_model_params_to_main_params(self):
+        # Only needed for the fp16 params.
+        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
+                                        overflow_buf=self._dummy_overflow_buf)
+
+
+    def reload_model_params(self):
+        self._copy_model_params_to_main_params()
+
+
    @torch.no_grad()
@@ -266,17 +290,17 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        timers = get_timers()

        # ==================================================
-        # Copy gradients from model params to master params.
+        # Copy gradients from model params to main params.
        # ==================================================
-        timers('optimizer-copy-to-master-grad').start()
-        self._copy_model_grads_to_master_grads()
-        timers('optimizer-copy-to-master-grad').stop()
+        timers('optimizer-copy-to-main-grad').start()
+        self._copy_model_grads_to_main_grads()
+        timers('optimizer-copy-to-main-grad').stop()

        # ==============================
        # Unscale and check for inf/nan.
        # ==============================
        timers('optimizer-unscale-and-check-inf').start()
-        found_inf_flag = self._unscale_master_grads_and_check_for_nan()
+        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
        timers('optimizer-unscale-and-check-inf').stop()

        # ==================================
@@ -292,11 +316,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
            return False

        # ==========================
-        # Clip the master gradients.
+        # Clip the main gradients.
        # ==========================
-        timers('optimizer-clip-master-grad').start()
+        timers('optimizer-clip-main-grad').start()
        self.clip_grad_norm(self.clip_grad)
-        timers('optimizer-clip-master-grad').stop()
+        timers('optimizer-clip-main-grad').stop()

        # ===================
        # Step the optimizer.
@@ -304,11 +328,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        self.optimizer.step()

        # =================================
-        # Update params from master params.
+        # Update params from main params.
        # =================================
-        timers('optimizer-copy-master-to-model-params').start()
-        self._copy_master_params_to_model_params()
-        timers('optimizer-copy-master-to-model-params').stop()
+        timers('optimizer-copy-main-to-model-params').start()
+        self._copy_main_params_to_model_params()
+        timers('optimizer-copy-main-to-model-params').stop()

        # ==================
        # Successful update.
@@ -340,7 +364,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        else:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

-        # Copy data for the master params.
+        # Copy data for the main params.
        fp32_from_fp16_params_key = 'fp32_from_fp16_params'
        if fp32_from_fp16_params_key not in state_dict:
            fp32_from_fp16_params_key = 'fp32_from_fp16'
@@ -388,6 +412,10 @@ class FP32Optimizer(MegatronOptimizer):
        return True


+    def reload_model_params(self):
+        pass
+
+
    def state_dict(self):
        return self.optimizer.state_dict()

+3 −3
@@ -677,10 +677,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
-    add_to_logging('optimizer-copy-to-master-grad')
+    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
-    add_to_logging('optimizer-clip-master-grad')
-    add_to_logging('optimizer-copy-master-to-model-params')
+    add_to_logging('optimizer-clip-main-grad')
+    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

+2 −3
@@ -255,9 +255,8 @@ def finetune(train_valid_datasets_provider, model_provider,
        _ = load_checkpoint(model, None, None)
        args.load = original_load
        # This is critical when only model is loaded. We should make sure
-        # master parameters are also updated.
-        if args.fp16:
-            optimizer._model_params_to_master_params()
+        # main parameters are also updated.
+        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
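
The comment above ("This is critical when only model is loaded") is the motivation for the new reload_model_params hook: the fp16 optimizer keeps detached fp32 clones of the model params, so loading pretrained weights into the model does not update those clones, while FP32Optimizer has nothing to re-sync and implements the hook as a no-op. A tiny standalone illustration of the staleness being guarded against (plain PyTorch, illustrative tensors, not Megatron code):

import torch

# The fp32 "main" param is a detached clone, so it does not follow in-place
# updates made to the fp16 model param when a checkpoint is loaded.
model_param = torch.ones(3, dtype=torch.float16)
main_param = model_param.detach().clone().float()     # what the optimizer stores

model_param.copy_(torch.full((3,), 2.0, dtype=torch.float16))  # "checkpoint load"
assert torch.all(main_param == 1.0)   # stale: still the old weights

main_param.copy_(model_param)         # what reload_model_params() triggers
assert torch.all(main_param == 2.0)   # back in sync before the next optimizer.step()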