Loading megatron/optimizer/optimizer.py +17 −6 Original line number Diff line number Diff line Loading @@ -98,16 +98,27 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2): group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() else: if norm_type == 2.0: dummy_overflow_buf = torch.cuda.IntTensor([0]) grad_norm, _ = multi_tensor_applier( amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads_for_norm], False # no per-parameter norm ) total_norm = grad_norm ** norm_type else: for grad in grads_for_norm: grad_norm = torch.norm(grad, norm_type) total_norm += grad_norm.item() ** norm_type total_norm += grad_norm ** norm_type # Sum across all model-parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) torch.distributed.all_reduce(total_norm_cuda, torch.distributed.all_reduce(total_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) total_norm = total_norm.item() ** (1.0 / norm_type) # Scale. clip_coeff = max_norm / (total_norm + 1.0e-6) Loading Loading
megatron/optimizer/optimizer.py +17 −6 Original line number Diff line number Diff line Loading @@ -98,16 +98,27 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2): group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() else: if norm_type == 2.0: dummy_overflow_buf = torch.cuda.IntTensor([0]) grad_norm, _ = multi_tensor_applier( amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads_for_norm], False # no per-parameter norm ) total_norm = grad_norm ** norm_type else: for grad in grads_for_norm: grad_norm = torch.norm(grad, norm_type) total_norm += grad_norm.item() ** norm_type total_norm += grad_norm ** norm_type # Sum across all model-parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) torch.distributed.all_reduce(total_norm_cuda, torch.distributed.all_reduce(total_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) total_norm = total_norm.item() ** (1.0 / norm_type) # Scale. clip_coeff = max_norm / (total_norm + 1.0e-6) Loading