Commit fb5b2b36 authored by Mohammad

fixed a bug in l2 grad clip

parent eb0a8bf0
+13 −9
@@ -32,16 +32,20 @@ def l2_grad_clipper(parameters, max_norm):
     """Efficient L2 norm gradient clipping."""
 
     overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
     # Make sure we have an iterable.
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
     mp_rank_is_zero = (get_model_parallel_rank() == 0)
-    parameters = list(filter(lambda p: (p.grad is not None) and
-                             (p.model_parallel or mp_rank_is_zero),
-                             parameters))
+    # Filter parameters with gradients.
+    parameters_with_grads = list(filter(
+        lambda p: p.grad is not None, parameters))
+    # Filter parameters for norm calculations.
+    parameters_for_norm = list(filter(
+        lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
     # Calculate L2 norm.
     norm, _ = multi_tensor_applier(
         amp_C.multi_tensor_l2norm,
         overflow_buf,
-        [parameters],
+        [parameters_for_norm],
         False # no per-parameter norm
     )
     # Sum across all model parallel GPUs.
@@ -50,10 +54,10 @@ def l2_grad_clipper(parameters, max_norm):
                                  op=torch.distributed.ReduceOp.SUM,
                                  group=get_model_parallel_group())
     total_norm = norm_2.item() ** 0.5
-
-    clip_coef = max_norm / (total_norm + 1e-6)
-    grads = [p.grad for p in parameters]
-    if clip_coef < 1:
+    # Scale to get max_norm.
+    clip_coef = float(max_norm) / (total_norm + 1.0e-6)
+    grads = [p.grad for p in parameters_with_grads]
+    if clip_coef < 1.0:
         multi_tensor_applier(
             amp_C.multi_tensor_scale,
             overflow_buf,
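What the fix does: in the old code the model-parallel / rank-0 filter was applied to `parameters` itself, so the later `grads = [p.grad for p in parameters]` scaled only that filtered subset, and gradients of non-model-parallel parameters on model-parallel ranks other than zero were never clipped. The new code keeps two lists: `parameters_for_norm`, so shared parameters are not double-counted when the squared norms are summed across the model-parallel group, and `parameters_with_grads`, which is everything that actually gets scaled. The sketch below restates that corrected logic using plain PyTorch tensor ops instead of apex's `multi_tensor_applier` kernels and without the model-parallel all-reduce; the function name `l2_grad_clip_sketch` and the `count_in_norm` predicate (standing in for `p.model_parallel or mp_rank_is_zero`) are illustrative stand-ins, not part of the repository.

import torch


def l2_grad_clip_sketch(parameters, max_norm, count_in_norm=lambda p: True):
    """Single-process sketch of the corrected clipping logic."""
    # Make sure we have an iterable.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Parameters whose gradients will be scaled.
    parameters_with_grads = [p for p in parameters if p.grad is not None]
    # Parameters that contribute to the norm (stand-in for the
    # model_parallel / rank-0 filter in the commit).
    parameters_for_norm = [p for p in parameters_with_grads if count_in_norm(p)]
    # Sum of squared L2 norms over the norm subset only; in the commit this
    # partial sum is all-reduced across the model parallel group before the
    # square root is taken.
    total_norm = 0.0
    for p in parameters_for_norm:
        total_norm += p.grad.data.norm(2).item() ** 2
    total_norm = total_norm ** 0.5
    # Scale every gradient, not just the norm subset (this is the bug fix).
    clip_coef = float(max_norm) / (total_norm + 1.0e-6)
    if clip_coef < 1.0:
        for p in parameters_with_grads:
            p.grad.data.mul_(clip_coef)
    return total_norm

A quick check of the clipping behaviour, again purely illustrative:

param = torch.nn.Parameter(torch.ones(4))
param.grad = torch.full((4,), 10.0)
print(l2_grad_clip_sketch([param], max_norm=1.0))  # 20.0 (pre-clip norm)
print(param.grad.norm(2).item())                   # ~1.0 after clipping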