megatron/optimizer/optimizer.py  +11 −11

@@ -68,8 +68,8 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # - grad should not be none
     # - parameter should not be shared
     # - should not be a replica due to tensor model parallelism
-    params_with_grads = []
-    params_for_norm = []
+    grads = []
+    grads_for_norm = []
     for param in parameters:
         # Make sure the grads are in fp32
         assert param.grad.type() == 'torch.cuda.FloatTensor'
@@ -77,10 +77,11 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         is_not_shared = not hasattr(param, 'shared') or not param.shared
         is_not_tp_duplicate = param.tensor_model_parallel or \
             (mpu.get_tensor_model_parallel_rank() == 0)
+        grad = param.grad.detach()
         if grad_not_none:
-            params_with_grads.append(param)
+            grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            params_for_norm.append(param)
+            grads_for_norm.append(grad)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -89,8 +90,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
 
     # Calculate norm.
     if norm_type == inf:
-        total_norm = max(param.grad.detach().abs().max()
-                         for param in params_for_norm)
+        total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
@@ -99,9 +99,9 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item()
     else:
-        for param in params_for_norm:
-            param_norm = torch.norm(param.grad.detach(), norm_type)
-            total_norm += param_norm.item() ** norm_type
+        for grad in grads_for_norm:
+            grad_norm = torch.norm(grad, norm_type)
+            total_norm += grad_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         torch.distributed.all_reduce(total_norm_cuda,
@@ -112,8 +112,8 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
 
     # Scale.
     clip_coef = max_norm / (total_norm + 1.0e-6)
     if clip_coef < 1.0:
-        for param in params_with_grads:
-            param.grad.detach().mul_(clip_coef)
+        for grad in grads:
+            grad.mul_(clip_coef)
 
     return total_norm
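For context, below is a minimal single-process sketch of the clipping logic as it stands after this change, assuming CPU tensors, no shared parameters, and no tensor model parallelism; the model-parallel all-reduce steps from the patched function are stubbed out as comments, and the helper name clip_grad_norm_sketch plus the toy model are illustrative, not part of the patch.

    import torch

    def clip_grad_norm_sketch(parameters, max_norm, norm_type=2.0):
        # Collect detached grads once, mirroring the diff: `grads` holds every
        # non-None grad (these get scaled), while `grads_for_norm` would further
        # exclude shared params and tensor-parallel duplicates in Megatron.
        grads = [p.grad.detach() for p in parameters if p.grad is not None]
        grads_for_norm = grads  # no sharing / TP duplication in this sketch

        max_norm = float(max_norm)
        norm_type = float(norm_type)

        if norm_type == float('inf'):
            total_norm = max(g.abs().max().item() for g in grads_for_norm)
            # Megatron all-reduces this with op=MAX over the model-parallel group.
        else:
            # Sum |g|_p ** p locally, then take the p-th root of the total.
            total_norm = sum(torch.norm(g, norm_type).item() ** norm_type
                             for g in grads_for_norm)
            # Megatron all-reduces this with op=SUM over the model-parallel group.
            total_norm = total_norm ** (1.0 / norm_type)

        # Scale every grad in place when the norm exceeds the budget.
        clip_coef = max_norm / (total_norm + 1.0e-6)
        if clip_coef < 1.0:
            for g in grads:
                g.mul_(clip_coef)
        return total_norm

    # Usage: clip a toy layer's grads to norm 1.0.
    layer = torch.nn.Linear(4, 4)
    layer(torch.randn(2, 4)).sum().backward()
    print(clip_grad_norm_sketch(layer.parameters(), max_norm=1.0))

Note that mul_ on the detached tensors scales param.grad in place, since detach() shares storage with the original grad; this is why the patch can collect the detached grads once up front instead of re-detaching inside the scaling loop, as the old code did.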