Commit b77d9062 authored by mohammad's avatar mohammad
Browse files

done with refactoring and checks

parent 9a010310
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@
# limitations under the License.

from apex.optimizers import FusedAdam as Adam

from megatron import get_args
from megatron.model import import_layernorm

+78 −131
Original line number Diff line number Diff line
@@ -19,13 +19,15 @@ from abc import ABC
from abc import abstractmethod

import torch
from torch._six import inf

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0

from .clip_grads import clip_grad_norm_fp32


def _zero_grad_group_helper(group, set_to_none):
@@ -43,95 +45,6 @@ def _zero_grad_group_helper(group, set_to_none):
                param.grad.zero_()


def _clip_grad_norm(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    grads = []
    grads_for_norm = []
    for param in parameters:
        # Make sure the grads are in fp32
        assert param.grad.type() == 'torch.cuda.FloatTensor'
        grad_not_none = param.grad is not None
        is_not_shared = not hasattr(param, 'shared') or not param.shared
        is_not_tp_duplicate = param.tensor_model_parallel or \
                              (mpu.get_tensor_model_parallel_rank() == 0)
        grad = param.grad.detach()
        if grad_not_none:
            grads.append(grad)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            grad_norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_for_norm],
                False # no per-parameter norm
            )
            total_norm = grad_norm ** norm_type

        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm ** norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=mpu.get_model_parallel_group())
        total_norm = total_norm.item() ** (1.0 / norm_type)

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)

    return total_norm



class MegatronOptimizer(ABC):

@@ -145,7 +58,7 @@ class MegatronOptimizer(ABC):
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        _clip_grad_norm(params, clip_grad)
        clip_grad_norm_fp32(params, clip_grad)

    @abstractmethod
    def zero_grad(self, set_to_none=True):
@@ -283,16 +196,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        return self.grad_scaler.scale


    @torch.no_grad()
    def step(self):

        timers = get_timers()

        # ==================================================
        # Copy gradients from model params to master params.
        # ==================================================

        timers('optimizer-copy-to-master-grad').start()
    def _copy_model_grads_to_master_grads(self):
        # This only needs to be done for the fp16 group.
        model_grads = []
        master_grads = []
@@ -302,26 +206,28 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                if model_param.grad is not None:
                    if master_param.grad is None:
                        master_param.grad = torch.empty_like(master_param)
                    model_grads.append(model_param.grad)
                    master_grads.append(master_param.grad)
                    model_grads.append(model_param.grad.data)
                    master_grads.append(master_param.grad.data)
        self._dummy_overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             self._dummy_overflow_buf,
                             [model_grads, master_grads],
                             1.0)
        timers('optimizer-copy-to-master-grad').stop()

        # ==============================
        # Unscale and check for inf/nan.
        # ==============================

        timers('optimizer-unscale-and-check-inf').start()
    def _unscale_master_grads_and_check_for_nan(self):
        master_grads = []
        # fp32 params fromm fp16 ones.
        for master_group in self.fp32_from_fp16_groups:
            for master_param in master_group:
                if master_param.grad is not None:
                    master_grads.append(master_param.grad.data)
        # Append fp32 parameters.
        for master_group in self.fp32_from_fp32_groups:
            for master_param in master_group:
                if master_param.grad is not None:
                    master_grads.append(master_param.grad)
                    master_grads.append(master_param.grad.data)
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan
@@ -331,13 +237,52 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)
        return found_inf_flag


    def _copy_master_params_to_model_params(self):
        # Only needed for the fp16 params.
        model_data = []
        master_data = []
        for model_group, master_group in zip(self.fp16_groups,
                                             self.fp32_from_fp16_groups):
            for model_param, master_param in zip(model_group, master_group):
                model_data.append(model_param.data)
                master_data.append(master_param.data)
        self._dummy_overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             self._dummy_overflow_buf,
                             [master_data, model_data],
                             1.0)


    @torch.no_grad()
    def step(self):

        timers = get_timers()

        # ==================================================
        # Copy gradients from model params to master params.
        # ==================================================
        timers('optimizer-copy-to-master-grad').start()
        self._copy_model_grads_to_master_grads()
        timers('optimizer-copy-to-master-grad').stop()

        # ==============================
        # Unscale and check for inf/nan.
        # ==============================
        timers('optimizer-unscale-and-check-inf').start()
        found_inf_flag = self._unscale_master_grads_and_check_for_nan()
        timers('optimizer-unscale-and-check-inf').stop()

        # ==================================
        # We are done with scaling gradients
        # so we can update the loss scale.
        # ==================================
        found_inf_flag = (self.found_inf.item() > 0)
        self.grad_scaler.update(found_inf_flag)

        # =====================================
@@ -349,7 +294,6 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        # ==========================
        # Clip the master gradients.
        # ==========================

        timers('optimizer-clip-master-grad').start()
        self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-master-grad').stop()
@@ -357,30 +301,18 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        # ===================
        # Step the optimizer.
        # ===================

        self.optimizer.step()

        # =================================
        # Update params from master params.
        # =================================

        timers('optimizer-copy-master-to-model-params').start()
        # Only needed for the fp16 params.
        model_data = []
        master_data = []
        for model_group, master_group in zip(self.fp16_groups,
                                             self.fp32_from_fp16_groups):
            for model_param, master_param in zip(model_group, master_group):
                model_data.append(model_param.data)
                master_data.append(master_param.data)
        self._dummy_overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             self._dummy_overflow_buf,
                             [master_data, model_data],
                             1.0)
        self._copy_master_params_to_model_params()
        timers('optimizer-copy-master-to-model-params').stop()

        # ==================
        # Successful update.
        # ==================
        return True


@@ -393,13 +325,28 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):


    def load_state_dict(self, state_dict):
        # Defer to the class to load.
        self.optimizer.load_state_dict(state_dict['optimizer'])
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

        # Copy data for the master params.
        fp32_from_fp16_params_key = 'fp32_from_fp16_params'
        if fp32_from_fp16_params_key not in state_dict:
            fp32_from_fp16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_fp16_groups,
                state_dict['fp32_from_fp16_params']):
                state_dict[fp32_from_fp16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)