Commit e68ea218 authored by mohammad

further refactor, matching old results

parent b0a3fdfe
+11 −11
@@ -68,8 +68,8 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     #   - grad should not be none
     #   - parameter should not be shared
     #   - should not be a replica due to tensor model parallelism
-    params_with_grads = []
-    params_for_norm = []
+    grads = []
+    grads_for_norm = []
     for param in parameters:
         # Make sure the grads are in fp32
         assert param.grad.type() == 'torch.cuda.FloatTensor'
@@ -77,10 +77,11 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         is_not_shared = not hasattr(param, 'shared') or not param.shared
         is_not_tp_duplicate = param.tensor_model_parallel or \
                               (mpu.get_tensor_model_parallel_rank() == 0)
+        grad = param.grad.detach()
         if grad_not_none:
-            params_with_grads.append(param)
+            grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            params_for_norm.append(param)
+            grads_for_norm.append(grad)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -89,8 +90,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
 
     # Calculate norm.
     if norm_type == inf:
-        total_norm = max(param.grad.detach().abs().max()
-                         for param in params_for_norm)
+        total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
@@ -99,9 +99,9 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item()
 
     else:
-        for param in params_for_norm:
-            param_norm = torch.norm(param.grad.detach(), norm_type)
-            total_norm += param_norm.item() ** norm_type
+        for grad in grads_for_norm:
+            grad_norm = torch.norm(grad, norm_type)
+            total_norm += grad_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         torch.distributed.all_reduce(total_norm_cuda,
@@ -112,8 +112,8 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # Scale.
     clip_coef = max_norm / (total_norm + 1.0e-6)
     if clip_coef < 1.0:
-        for param in params_with_grads:
-            param.grad.detach().mul_(clip_coef)
+        for grad in grads:
+            grad.mul_(clip_coef)
 
     return total_norm
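
For readers following the change, here is a rough sketch of how the refactored _clip_grad_norm reads once the hunks above are applied. It is assembled from the hunks plus assumptions about the unchanged lines the diff does not show: the `from megatron import mpu` import, the `mpu.get_tensor_model_parallel_group()` group handle, the `grad_not_none = param.grad is not None` check, the ReduceOp.MAX / ReduceOp.SUM reductions, and the final `** (1.0 / norm_type)` root are all filled in from context, and the assert/detach are nested under the grad_not_none check so the sketch runs when a parameter has no grad. Treat it as an approximation, not the file's exact contents.

import math

import torch

# Assumption: Megatron-LM style module layout; the exact import path and the
# process-group helper used below may differ in this repository.
from megatron import mpu


def _clip_grad_norm(parameters, max_norm, norm_type=2):
    # Assumed prologue: accept a single tensor or an iterable of parameters.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Grads to scale (all grads) and grads that count toward the norm
    # (not shared, not a tensor-model-parallel replica).
    grads = []
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None  # assumed unchanged line
        is_not_shared = not hasattr(param, 'shared') or not param.shared
        is_not_tp_duplicate = param.tensor_model_parallel or \
                              (mpu.get_tensor_model_parallel_rank() == 0)
        if grad_not_none:
            grad = param.grad.detach()
            # Make sure the grads are in fp32.
            assert grad.type() == 'torch.cuda.FloatTensor'
            grads.append(grad)
            if is_not_shared and is_not_tp_duplicate:
                grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0

    # Calculate norm.
    if norm_type == math.inf:
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all tensor-model-parallel GPUs.
        torch.distributed.all_reduce(
            total_norm_cuda,
            op=torch.distributed.ReduceOp.MAX,
            group=mpu.get_tensor_model_parallel_group())  # assumed helper
        total_norm = total_norm_cuda[0].item()
    else:
        for grad in grads_for_norm:
            grad_norm = torch.norm(grad, norm_type)
            total_norm += grad_norm.item() ** norm_type
        # Sum across all tensor-model-parallel GPUs, then take the root.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        torch.distributed.all_reduce(
            total_norm_cuda,
            op=torch.distributed.ReduceOp.SUM,
            group=mpu.get_tensor_model_parallel_group())  # assumed helper
        total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)

    # Scale every gradient in place if the combined norm exceeds max_norm.
    clip_coef = max_norm / (total_norm + 1.0e-6)
    if clip_coef < 1.0:
        for grad in grads:
            grad.mul_(clip_coef)

    return total_norm

The point of the refactor, as the hunks show, is to detach each gradient once and keep the resulting tensors in the grads / grads_for_norm lists, so the norm computation and the in-place clipping loop operate on plain tensors instead of calling param.grad.detach() repeatedly; the produced norm should match the old results, per the commit message.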