Commit 6191ff59 authored by Mohammad

Found a bug in the L2 norm calculation

parent b84d7a90
@@ -76,7 +76,6 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
                               (mpu.get_tensor_model_parallel_rank() == 0)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
             filtered_parameters.append(param)
-    parameters = filtered_parameters

     # Norm parameters.
     max_norm = float(max_norm)
@@ -86,7 +85,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # Calculate norm.
     if norm_type == inf:
         total_norm = max(param.grad.detach().abs().max()
-                         for param in parameters)
+                         for param in filtered_parameters)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
@@ -95,7 +94,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item()

     else:
-        for param in parameters:
+        for param in filtered_parameters:
             param_norm = torch.norm(param.grad.detach(), norm_type)
             total_norm += param_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
@@ -107,7 +106,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):

     # Scale.
     clip_coef = max_norm / (total_norm + 1e-6)
-    if clip_coef < 1:
+    if clip_coef < 1.0:
         for param in parameters:
             param.grad.detach().mul_(clip_coef)
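
After this change, the norm computation reads from filtered_parameters (gradients that are not None, not shared, and not tensor-parallel duplicates), while the final scaling loop still walks the original parameters list. Below is a minimal, self-contained sketch of that pattern for the L2 branch in plain PyTorch; the filtering predicate and the model-parallel all-reduce from the real function are omitted, and the name clip_grad_norm_sketch is illustrative, not from the repository.

import torch


def clip_grad_norm_sketch(parameters, filtered_parameters, max_norm, norm_type=2):
    # Accumulate the norm over the filtered parameters only, so duplicated
    # tensor-parallel copies are not counted more than once.
    max_norm = float(max_norm)
    total_norm = 0.0
    for param in filtered_parameters:
        param_norm = torch.norm(param.grad.detach(), norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1.0 / norm_type)

    # Apply the clip coefficient to every parameter, filtered or not,
    # so no local gradient copy is left unscaled.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for param in parameters:
            param.grad.detach().mul_(clip_coef)
    return total_norm

Keeping the two loops over different lists appears to be the point of the fix: the norm is measured once per unique gradient, but the resulting scale factor is applied to every parameter held locally.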