Commit 4ff2c963 authored by Mohammad's avatar Mohammad
Browse files

fixed a bug in l2 grad clip

parent fb5b2b36
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ def l2_grad_clipper(parameters, max_norm):
    parameters_with_grads = list(filter(
        lambda p: p.grad is not None, parameters))
    # Filter parameters for norm calculations.
    mp_rank_is_zero = (get_model_parallel_rank() == 0)
    parameters_for_norm = list(filter(
        lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
    # Calculate L2 norm.