megatron/arguments.py  +4 −0

@@ -122,6 +122,10 @@ def _add_network_size_args(parser):
                        action='store_true',
                        help='If set, use original BERT residula connection '
                        'ordering.')
+    group.add_argument('--openai-gelu', action='store_true',
+                       help='Use OpenAI\'s GeLU implementation. This option '
+                            'should not be used unless for backward '
+                            'compatibility reasons.')

     return parser
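As an aside on the flag itself: argparse's store_true action defaults the new option to False, so existing command lines keep torch's gelu and only an explicit --openai-gelu switches implementations. A minimal, standalone sketch (the parser and group below are illustrative, not the actual Megatron argument-parsing code):

    # Illustrative only: shows how a store_true flag like --openai-gelu behaves.
    import argparse

    parser = argparse.ArgumentParser(description='network-size args (sketch)')
    group = parser.add_argument_group(title='network size')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAI GeLU implementation; only for '
                            'backward compatibility.')

    args = parser.parse_args([])                  # default: openai_gelu == False
    assert args.openai_gelu is False
    args = parser.parse_args(['--openai-gelu'])   # flag present: openai_gelu == True
    assert args.openai_gelu is True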
megatron/model/bert_model.py  +13 −9

@@ -18,16 +18,15 @@
 import torch

 from megatron import get_args
+from megatron.model.language_model import parallel_lm_logits
+from megatron.model.language_model import get_language_model
+from megatron.model.transformer import LayerNorm
+from megatron.model.utils import openai_gelu
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal
+from megatron.model.utils import scaled_init_method_normal
 from megatron.module import MegatronModule
-from .language_model import parallel_lm_logits
-from .language_model import get_language_model
-from .transformer import LayerNorm
-from .utils import gelu
-from .utils import get_linear_layer
-from .utils import init_method_normal
-from .utils import scaled_init_method_normal


 def bert_attention_mask_func(attention_scores, attention_mask):
     attention_scores = attention_scores + attention_mask

@@ -82,6 +81,8 @@ class BertLMHead(MegatronModule):

         super(BertLMHead, self).__init__()

+        args = get_args()
+
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
         self.bias.partition_dim = 0

@@ -90,10 +91,13 @@ class BertLMHead(MegatronModule):
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.gelu = torch.nn.functional.gelu
+        if args.openai_gelu:
+            self.gelu = openai_gelu

     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
-        hidden_states = gelu(hidden_states)
+        hidden_states = self.gelu(hidden_states)
         hidden_states = self.layernorm(hidden_states)
         output = parallel_lm_logits(hidden_states,
                                     word_embeddings_weight,
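The LM head now binds its activation once at construction time: torch.nn.functional.gelu by default, the tanh approximation only when args.openai_gelu is set, so forward() stays branch-free. A stripped-down, hypothetical sketch of that pattern (not the real BertLMHead, which also applies LayerNorm and parallel logits):

    # Hypothetical, self-contained sketch of the activation-selection pattern above.
    import torch
    import torch.nn.functional as F


    def openai_gelu(x):
        # Same tanh approximation as megatron/model/utils.py.
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                           (1.0 + 0.044715 * x * x)))


    class TinyLMHead(torch.nn.Module):
        def __init__(self, hidden_size, use_openai_gelu=False):
            super().__init__()
            self.dense = torch.nn.Linear(hidden_size, hidden_size)
            # Bind the activation once; forward() does not branch on the flag.
            self.gelu = openai_gelu if use_openai_gelu else F.gelu

        def forward(self, hidden_states):
            return self.gelu(self.dense(hidden_states))


    head = TinyLMHead(hidden_size=8, use_openai_gelu=True)
    out = head(torch.randn(2, 8))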
megatron/model/language_model.py  +7 −2

@@ -21,9 +21,8 @@
 import torch.nn.functional as F

 from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import gelu
+from megatron.model.utils import openai_gelu
 from megatron.model.utils import get_linear_layer

@@ -47,6 +46,12 @@
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                        init_method, scaled_init_method):
     """Build language model and return along with the key to save."""
+    args = get_args()
+
+    # Use torch gelu unless otherwise forced.
+    gelu = F.gelu
+    if args.openai_gelu:
+        gelu = openai_gelu

     # Language model.
     language_model = TransformerLanguageModel(
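get_language_model resolves the activation the same way before building the model; the diff is cut off before TransformerLanguageModel's constructor arguments, so exactly how the callable is threaded through is not visible here. A sketch of the general factory pattern, with hypothetical names (activation_func, ToyTransformerLM, build_toy_lm are assumptions, not Megatron's API):

    # Sketch only: resolve the activation once in the factory, hand it down.
    import torch
    import torch.nn.functional as F


    def openai_gelu(x):
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                           (1.0 + 0.044715 * x * x)))


    class ToyTransformerLM(torch.nn.Module):
        def __init__(self, hidden_size, activation_func):
            super().__init__()
            self.dense = torch.nn.Linear(hidden_size, hidden_size)
            self.activation_func = activation_func

        def forward(self, x):
            return self.activation_func(self.dense(x))


    def build_toy_lm(hidden_size, use_openai_gelu=False):
        # The flag is consulted exactly once, at build time.
        gelu = openai_gelu if use_openai_gelu else F.gelu
        return ToyTransformerLM(hidden_size, activation_func=gelu)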
megatron/model/utils.py  +1 −3

@@ -54,9 +54,7 @@
 def gelu_impl(x):
     """OpenAI's gelu implementation."""
     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                        (1.0 + 0.044715 * x * x)))


-def gelu(x):
+def openai_gelu(x):
     return gelu_impl(x)
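The constant 0.7978845608028654 is sqrt(2/pi), so openai_gelu is the familiar tanh approximation of GeLU, while torch.nn.functional.gelu evaluates the exact erf form; the two agree closely but not bit-for-bit, which is why the flag exists for checkpoints trained with the old activation. A quick numerical check (a sketch that assumes only PyTorch):

    # Compare the tanh-approximation GeLU (as in gelu_impl above) with
    # PyTorch's exact erf-based F.gelu on a grid of inputs.
    import math
    import torch
    import torch.nn.functional as F


    def openai_gelu(x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * x *
                                           (1.0 + 0.044715 * x * x)))


    x = torch.linspace(-5.0, 5.0, steps=1001)
    max_diff = (openai_gelu(x) - F.gelu(x)).abs().max().item()
    print(f'max |tanh-approx - exact| on [-5, 5]: {max_diff:.2e}')

Newer PyTorch releases also expose the approximation directly as F.gelu(x, approximate='tanh'), which was not available when this change was made.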
megatron/mpu/grads.py  +48 −4

@@ -21,10 +21,47 @@
 import torch
 from torch._six import inf

+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
 from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_rank
+
+
+def l2_grad_clipper(parameters, max_norm):
+    """Efficient L2 norm gradient clipping."""
+    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    mp_rank_is_zero = (get_model_parallel_rank() == 0)
+    parameters = list(filter(
+        lambda p: (p.grad is not None) and
+                  (p.model_parallel or mp_rank_is_zero),
+        parameters))
+    norm, _ = multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        overflow_buf,
+        [parameters],
+        False  # no per-parameter norm
+    )
+    # Sum across all model parallel GPUs.
+    norm_2 = norm * norm
+    torch.distributed.all_reduce(norm_2,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=get_model_parallel_group())
+    total_norm = norm_2.item() ** 0.5
+    clip_coef = max_norm / (total_norm + 1e-6)
+    grads = [p.grad for p in parameters]
+    if clip_coef < 1:
+        multi_tensor_applier(
+            amp_C.multi_tensor_scale,
+            overflow_buf,
+            [grads, grads],
+            clip_coef)
+    return total_norm
+
+
 def clip_grad_norm(parameters, max_norm, norm_type=2):
     """Clips gradient norm of an iterable of parameters.

@@ -55,6 +92,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=get_model_parallel_group())
         total_norm = total_norm_cuda[0].item()
+        clip_coef = max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            for p in parameters:
+                p.grad.data.mul_(clip_coef)
+    elif norm_type == 2:
+        total_norm = l2_grad_clipper(parameters, max_norm)
     else:
         total_norm = 0
         for p in parameters:
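For readers without apex installed, the intent of l2_grad_clipper can be written in plain PyTorch: skip gradients that would be double-counted (parameters that are not model-parallel are counted only on model-parallel rank 0), all-reduce the squared local norm across the model-parallel group, and scale gradients in place when the total norm exceeds max_norm. The sketch below is an apex-free illustration of those steps, not the code in the diff; it assumes an initialized torch.distributed process group, CUDA gradients, and Megatron-style model_parallel attributes on the parameters. The apex multi_tensor_l2norm / multi_tensor_scale kernels above do the same work with far fewer kernel launches.

    # Apex-free reference for the L2 clipping path; names and signature are
    # illustrative, not the function in the diff.
    import torch
    import torch.distributed as dist


    def l2_grad_clip_reference(parameters, max_norm, model_parallel_group, mp_rank):
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        # Count replicated (non-model-parallel) params only on model-parallel
        # rank 0 so the all-reduce does not add them in multiple times.
        params = [p for p in parameters
                  if p.grad is not None and (getattr(p, 'model_parallel', False)
                                             or mp_rank == 0)]
        # Local sum of squared gradient norms (assumes gradients are on CUDA).
        local_sq = torch.zeros(1, device='cuda')
        for p in params:
            local_sq += p.grad.data.float().norm() ** 2
        # Sum across the model-parallel group, then take the square root.
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=model_parallel_group)
        total_norm = local_sq.item() ** 0.5
        # Scale gradients in place if the norm exceeds max_norm.
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for p in params:
                p.grad.data.mul_(clip_coef)
        return total_norm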