megatron/optimizer/optimizer.py (+8 −5)

```diff
@@ -107,13 +107,16 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     torch.distributed.all_reduce(total_norm_cuda,
                                  op=torch.distributed.ReduceOp.SUM,
                                  group=mpu.get_model_parallel_group())
-    total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
+    total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)

     # Scale.
-    clip_coef = max_norm / (total_norm + 1.0e-6)
-    if clip_coef < 1.0:
-        for grad in grads:
-            grad.mul_(clip_coef)
+    clip_coeff = max_norm / (total_norm + 1.0e-6)
+    if clip_coeff < 1.0:
+        dummy_overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             dummy_overflow_buf,
+                             [grads, grads],
+                             clip_coeff)

     return total_norm
```
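For context, a minimal standalone sketch of the fused scaling path the diff switches to. It assumes NVIDIA Apex is installed (it provides `amp_C` and `multi_tensor_applier`); the wrapper name `scale_grads_if_needed` is hypothetical and not part of the PR.

```python
# A sketch of the fused in-place gradient scaling used in the diff,
# assuming NVIDIA Apex is available (provides amp_C and multi_tensor_applier).
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier


def scale_grads_if_needed(grads, max_norm, total_norm):
    """Scale all gradients in place when total_norm exceeds max_norm.

    `grads` is a list of CUDA gradient tensors; `total_norm` is the
    already-reduced global gradient norm.
    """
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        # Flag buffer the fused kernel can use to record inf/nan;
        # it is not inspected here, hence "dummy".
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        # A single fused kernel multiplies every tensor in the first list
        # by clip_coeff and writes the result into the second list;
        # passing [grads, grads] makes the scaling in place, replacing the
        # old per-tensor grad.mul_(clip_coef) loop.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)
    return total_norm
```

The design change: instead of launching one `.mul_()` CUDA kernel per gradient tensor, `amp_C.multi_tensor_scale` processes the whole list in a fused, chunked kernel, cutting kernel-launch overhead when a model has hundreds or thousands of parameter tensors.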