Commit 345f5d0d authored by mohammad
Browse files

moved grad scaling to multi-tensor-apply

parent e68ea218
Loading
Loading
Loading
Loading
+8 −5
Original line number Diff line number Diff line
@@ -107,13 +107,16 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
        total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)

    # Scale.
    clip_coef = max_norm / (total_norm + 1.0e-6)
    if clip_coef < 1.0:
        for grad in grads:
            grad.mul_(clip_coef)
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)

    return total_norm