Commit 03d28809 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

Merge branch 'l2_grad_clip_fix' into 'master'

Reverting l2 grad optimization

See merge request ADLR/megatron-lm!74
parents 3c709cba d218f9cf
Loading
Loading
Loading
Loading
+15 −10
Original line number Diff line number Diff line
@@ -32,16 +32,21 @@ def l2_grad_clipper(parameters, max_norm):
    """Efficient L2 norm gradient clipping."""

    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
    # Make sure we have an iterable.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Filter parameters with gradients.
    parameters_with_grads = list(filter(
        lambda p: p.grad is not None, parameters))
    # Filter parameters for norm calculations.
    mp_rank_is_zero = (get_model_parallel_rank() == 0)
    parameters = list(filter(lambda p: (p.grad is not None) and
                             (p.model_parallel or mp_rank_is_zero),
                             parameters))
    parameters_for_norm = list(filter(
        lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
    # Calculate L2 norm.
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        overflow_buf,
        [parameters],
        [parameters_for_norm],
        False # no per-parameter norm
    )
    # Sum across all model parallel GPUs.
@@ -50,10 +55,10 @@ def l2_grad_clipper(parameters, max_norm):
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=get_model_parallel_group())
    total_norm = norm_2.item() ** 0.5

    clip_coef = max_norm / (total_norm + 1e-6)
    grads = [p.grad for p in parameters]
    if clip_coef < 1:
    # Scale to get max_norm.
    clip_coef = float(max_norm) / (total_norm + 1.0e-6)
    grads = [p.grad for p in parameters_with_grads]
    if clip_coef < 1.0:
        multi_tensor_applier(
            amp_C.multi_tensor_scale,
            overflow_buf,
@@ -96,8 +101,8 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
        if clip_coef < 1:
            for p in parameters:
                p.grad.data.mul_(clip_coef)
    elif norm_type == 2:
        total_norm = l2_grad_clipper(parameters, max_norm)
    #elif norm_type == 2:
    #    total_norm = l2_grad_clipper(parameters, max_norm)

    else:
        total_norm = 0