megatron/module.py (+1 −0)

@@ -79,6 +79,7 @@ class PipelinedMegatronModule(MegatronModule):
                 args.padded_vocab_size, args.hidden_size,
                 init_method=init_method_normal(args.init_method_std))
             self.word_embeddings.weight.data.fill_(0)
+            self.word_embeddings.weight.shared = True
         # Ensure that first and last stages have the same initial parameter
         # values.
         if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
             torch.distributed.all_reduce(self.word_embeddings_weight().data,
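This one-line change is the hook for everything below: the word-embedding weight is duplicated across the first and last pipeline stages (input embedding plus output-layer weight tying), and tagging the tied copy as shared lets the rewritten grad-norm code count it exactly once. A minimal single-process sketch of that contract, using only plain PyTorch (the variable names are illustrative, not Megatron's):

import torch

# Two Parameter objects standing in for the tied word embedding that
# exists on both the first and the last pipeline stage.
first_stage_copy = torch.nn.Parameter(torch.zeros(8, 4))
last_stage_copy = torch.nn.Parameter(torch.zeros(8, 4))
last_stage_copy.shared = True   # the duplicate kept for weight tying

for p in (first_stage_copy, last_stage_copy):
    p.grad = torch.ones_like(p)

# The new filter keeps a parameter only if it is not marked shared
# (`not hasattr(p, 'shared') or not p.shared` in Megatron's code, i.e.
# `not getattr(p, 'shared', False)`), so the tied weight is counted once.
counted = [p for p in (first_stage_copy, last_stage_copy)
           if not getattr(p, 'shared', False)]
assert len(counted) == 1 and counted[0] is first_stage_copy

The attribute-based tag replaces the earlier string match on parameter names, which was fragile: any parameter whose name happened to contain "embedding" was skipped outside the first stage.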
megatron/mpu/grads.py (+35 −31)

@@ -72,7 +72,7 @@ def l2_grad_clipper(parameters, max_norm):
     return total_norm


-def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
+def clip_grad_norm(parameters, max_norm, norm_type=2):
     """Clips gradient norm of an iterable of parameters.

     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and

@@ -89,42 +89,43 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
     Returns:
         Total norm of the parameters (viewed as a single vector).
     """
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
-    if parameter_names is not None:
-        filtered_parameters = []
-        assert len(parameters) == len(parameter_names), \
-            'length of parameters and parameter_names should be the same'
-        for p, n in zip(parameters, parameter_names):
-            if p.grad is not None:
-                # TODO: Bit hacky; is there a cleaner way to do this?
-                # Count embedding layer only once (in first stage).
-                # Don't count the weights a second time in the last stage.
-                if "embedding" not in n or \
-                        is_pipeline_first_stage():
-                    filtered_parameters.append(p)
-        parameters = filtered_parameters
-    else:
-        parameters = list(filter(lambda p: p.grad is not None, parameters))
+
+    # Filter parameters based on:
+    #   - grad should not be none
+    #   - parameter should not be shared
+    #   - should not be a replica due to tensor model parallelism
+    filtered_parameters = []
+    for param in parameters:
+        grad_not_none = param.grad is not None
+        is_not_shared = not hasattr(param, 'shared') or not param.shared
+        is_not_tp_duplicate = param.tensor_model_parallel or \
+            (get_tensor_model_parallel_rank() == 0)
+        if grad_not_none and is_not_shared and is_not_tp_duplicate:
+            filtered_parameters.append(param)
+    parameters = filtered_parameters

     # Norm parameters.
     max_norm = float(max_norm)
     norm_type = float(norm_type)
     total_norm = 0

     # Calculate norm.
     if norm_type == inf:
-        total_norm = max(p.grad.data.abs().max() for p in parameters)
+        total_norm = max(param.grad.detach().abs().max()
+                         for param in parameters)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=get_model_parallel_group())
         total_norm = total_norm_cuda[0].item()
-        clip_coef = max_norm / (total_norm + 1e-6)
-        if clip_coef < 1:
-            for p in parameters:
-                p.grad.data.mul_(clip_coef)
     else:
-        total_norm = 0
-        for p in parameters:
-            if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
-                param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
+        for param in parameters:
+            param_norm = torch.norm(param.grad.detach(), norm_type)
             total_norm += param_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])

@@ -132,8 +133,11 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
                                      op=torch.distributed.ReduceOp.SUM,
                                      group=get_model_parallel_group())
         total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
+
+    # Scale.
     clip_coef = max_norm / (total_norm + 1e-6)
     if clip_coef < 1:
-        for p in parameters:
-            p.grad.data.mul_(clip_coef)
+        for param in parameters:
+            param.grad.detach().mul_(clip_coef)

     return total_norm
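Besides swapping name-based filtering for attribute-based filtering, the rewrite hoists the scaling step out of the two norm branches, so clipping happens in exactly one place. Below is a single-process sketch of the resulting filter-then-clip flow; clip_grad_norm_local is a hypothetical name, the model-parallel all-reduces are reduced to comments, and the tensor-parallel-replica test is dropped since one process has no replicas:

import torch

def clip_grad_norm_local(parameters, max_norm, norm_type=2.0):
    """Single-process analogue of the rewritten clip_grad_norm."""
    # Keep only parameters that (a) have a gradient and (b) are not a
    # `shared` duplicate.
    params = [p for p in parameters
              if p.grad is not None and not getattr(p, 'shared', False)]
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        total_norm = max(p.grad.detach().abs().max().item() for p in params)
        # Megatron all-reduces this with ReduceOp.MAX over the
        # model-parallel group before using it.
    else:
        total_norm = sum(
            torch.norm(p.grad.detach(), norm_type).item() ** norm_type
            for p in params)
        # Megatron all-reduces this partial sum with ReduceOp.SUM over the
        # model-parallel group, then takes the norm_type-th root.
        total_norm = total_norm ** (1.0 / norm_type)
    # Scale all gradients in place by a single coefficient.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in params:
            p.grad.detach().mul_(clip_coef)
    return total_norm

w = torch.nn.Parameter(torch.zeros(3, 3))
w.grad = torch.full((3, 3), 10.0)
clip_grad_norm_local([w], max_norm=1.0)   # w.grad now has L2 norm ~1.0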
megatron/optimizer/optimizer.py (+83 −12)

@@ -19,6 +19,7 @@
 from abc import ABC
 from abc import abstractmethod

 import torch
+from torch._six import inf

 from apex.multi_tensor_apply import multi_tensor_applier
 from apex.optimizers import FusedAdam as Adam

@@ -195,6 +196,77 @@ def _zero_grad_group_helper(group, set_to_none):
                 param.grad.zero_()


+def _clip_grad_norm(parameters, max_norm, norm_type=2):
+    """Clips gradient norm of an iterable of parameters.
+
+    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
+    added functionality to handle model parallel parameters. Note that
+    the gradients are modified in place.
+
+    Arguments:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'``
+            for infinity norm.
+
+    Returns:
+        Total norm of the parameters (viewed as a single vector).
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+
+    # Filter parameters based on:
+    #   - grad should not be none
+    #   - parameter should not be shared
+    #   - should not be a replica due to tensor model parallelism
+    filtered_parameters = []
+    for param in parameters:
+        grad_not_none = param.grad is not None
+        is_not_shared = not hasattr(param, 'shared') or not param.shared
+        is_not_tp_duplicate = param.tensor_model_parallel or \
+            (mpu.get_tensor_model_parallel_rank() == 0)
+        if grad_not_none and is_not_shared and is_not_tp_duplicate:
+            filtered_parameters.append(param)
+    parameters = filtered_parameters
+
+    # Norm parameters.
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    total_norm = 0.0
+
+    # Calculate norm.
+    if norm_type == inf:
+        total_norm = max(param.grad.detach().abs().max()
+                         for param in parameters)
+        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+        # Take max across all model-parallel GPUs.
+        torch.distributed.all_reduce(total_norm_cuda,
+                                     op=torch.distributed.ReduceOp.MAX,
+                                     group=mpu.get_model_parallel_group())
+        total_norm = total_norm_cuda[0].item()
+    else:
+        for param in parameters:
+            param_norm = torch.norm(param.grad.detach(), norm_type)
+            total_norm += param_norm.item() ** norm_type
+        # Sum across all model-parallel GPUs.
+        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+        torch.distributed.all_reduce(total_norm_cuda,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=mpu.get_model_parallel_group())
+        total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
+
+    # Scale.
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for param in parameters:
+            param.grad.detach().mul_(clip_coef)
+
+    return total_norm


 class MegatronOptimizer(ABC):

@@ -203,6 +275,13 @@ class MegatronOptimizer(ABC):
         self.optimizer = optimizer
         assert self.optimizer, 'no optimizer is provided.'

+    def clip_grad_norm(self, clip_grad):
+        params = []
+        for param_group in self.optimizer.param_groups:
+            for param in param_group['params']:
+                params.append(param)
+        _clip_grad_norm(params, clip_grad)
+
     @abstractmethod
     def zero_grad(self, set_to_none=True):
         pass

@@ -299,6 +378,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                     # Copy tensor model parallel attributes.
                     mpu.copy_tensor_model_parallel_attributes(master_param,
                                                               param)
+                    if hasattr(param, 'shared'):
+                        master_param.shared = param.shared

                     # Replace the optimizer params with the new fp32 copy.
                     param_group['params'][i] = master_param
                     fp32_from_fp16_params_this_group.append(master_param)

@@ -408,11 +489,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         # ==========================

         timers('optimizer-clip-master-grad').start()
-        fp32_params = []
-        for param_group in self.optimizer.param_groups:
-            for param in param_group['params']:
-                fp32_params.append(param)
-        mpu.clip_grad_norm(fp32_params, self.clip_grad)
+        self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-master-grad').stop()

         # ===================

@@ -473,13 +550,7 @@ class FP32Optimizer(MegatronOptimizer):
         # Clip gradients.
         if self.clip_grad > 0.0:
-            parameters = []
-            parameter_names = []
-            for parameter_name, parameter in self.model.named_parameters():
-                parameters.append(parameter)
-                parameter_names.append(parameter_name)
-            mpu.clip_grad_norm(parameters, self.clip_grad,
-                               parameter_names=parameter_names)
+            self.clip_grad_norm(self.clip_grad)

         # Update parameters.
         self.optimizer.step()
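With _clip_grad_norm living in the optimizer module, the base class owns the parameter gathering, and both call sites collapse to self.clip_grad_norm(self.clip_grad). Note also that the FP16 optimizer copies the `shared` attribute onto the fp32 master params, so the filter still works after the fp16-to-fp32 swap. A toy sketch of that shape (class name hypothetical; it reuses the clip_grad_norm_local helper sketched above in place of _clip_grad_norm):

import torch

class ToyMegatronOptimizer:
    """Mimics MegatronOptimizer.clip_grad_norm: gather every parameter
    from the wrapped optimizer's param_groups, then clip once."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def clip_grad_norm(self, clip_grad):
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        # Stands in for _clip_grad_norm; see the sketch above.
        return clip_grad_norm_local(params, clip_grad)

model = torch.nn.Linear(4, 2)
opt = ToyMegatronOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
model(torch.randn(8, 4)).sum().backward()
opt.clip_grad_norm(1.0)

Centralizing the gathering also removes the FP32 path's dependence on self.model.named_parameters(), which only existed to feed the now-deleted parameter_names argument.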