megatron/arguments.py (+5 −0)

@@ -183,6 +183,9 @@ def parse_args(extra_args_provider=None, defaults={},
     # Mixed precision checks.
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
+    if args.fp32_residual_connection:
+        assert args.fp16, \
+            'residual connection in fp32 only supports in fp16 mode.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
         assert args.checkpoint_activations, \

@@ -435,6 +438,8 @@ def _add_mixed_precision_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--fp32-residual-connection', action='store_true',
+                       help='Move residual connections to fp32.')
     group.add_argument('--apply-query-key-layer-scaling', action='store_true',
                        help='Scale Q * K^T by 1 / layer-number. If this flag '
                        'is set, then it will automatically set '
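As a quick, standalone illustration (plain argparse only, not Megatron's actual parse_args plumbing; the argv below is made up), the new flag and the fp16-only check combine like this:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true',
                    help='Run model in fp16 mode.')
parser.add_argument('--fp32-residual-connection', action='store_true',
                    help='Move residual connections to fp32.')

# e.g. launching with both flags set
args = parser.parse_args(['--fp16', '--fp32-residual-connection'])

if args.fp32_residual_connection:
    # Mirrors the new check in parse_args(): keeping the residual stream in
    # fp32 is only meaningful when the rest of the model runs in fp16.
    assert args.fp16, \
        'residual connection in fp32 only supports in fp16 mode.'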
megatron/model/__init__.py (+18 −0)

@@ -13,9 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+_LAYER_NORM = None
+
+
+def import_layernorm(fp32_residual_connection):
+    global _LAYER_NORM
+    if not _LAYER_NORM:
+        if fp32_residual_connection:
+            from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
+        else:
+            from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+        _LAYER_NORM = LayerNorm
+    return _LAYER_NORM
+
+
 from .distributed import *
 from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
 from .realm_model import ICTBertModel
 from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
 from .utils import get_params_for_weight_decay_optimization
 from .language_model import get_language_model
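A self-contained sketch of the memoization pattern behind import_layernorm, with stand-in classes replacing the fused Megatron/apex implementations: the first call fixes the LayerNorm class for the whole process, so the flag must be consistent across call sites.

_LAYER_NORM = None

class MixedFusedLayerNorm: ...   # stand-in for megatron.model.fused_layer_norm
class FusedLayerNorm: ...        # stand-in for apex.normalization.fused_layer_norm

def import_layernorm(fp32_residual_connection):
    """Return the LayerNorm class; the choice is made once and then cached."""
    global _LAYER_NORM
    if not _LAYER_NORM:
        _LAYER_NORM = (MixedFusedLayerNorm if fp32_residual_connection
                       else FusedLayerNorm)
    return _LAYER_NORM

# The first call decides the implementation...
assert import_layernorm(True) is MixedFusedLayerNorm
# ...and later calls return the cached class regardless of the argument.
assert import_layernorm(False) is MixedFusedLayerNorm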
megatron/model/bert_model.py (+2 −1)

@@ -21,7 +21,7 @@
 from megatron import get_args
 from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
-from megatron.model.transformer import LayerNorm
+from megatron.model import import_layernorm
 from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal

@@ -83,6 +83,7 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output
 
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+        LayerNorm = import_layernorm(args.fp32_residual_connection)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
megatron/model/transformer.py (+9 −2)

@@ -21,9 +21,9 @@
 import torch.nn.functional as F
 
 from megatron import get_args
 from megatron import mpu
-from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
 from megatron.checkpointing import get_checkpoint_version
+from megatron.model import import_layernorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import openai_gelu, erf_gelu

@@ -404,6 +404,7 @@ class ParallelTransformerLayer(MegatronModule):
             = args.apply_residual_connection_post_layernorm
 
         # Layernorm on the input data.
+        LayerNorm = import_layernorm(args.fp32_residual_connection)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

@@ -500,6 +501,8 @@ class ParallelTransformer(MegatronModule):
         super(ParallelTransformer, self).__init__()
         args = get_args()
 
+        self.fp32_residual_connection = args.fp32_residual_connection
+
         # Store activation checkpoiting flag.
         self.checkpoint_activations = args.checkpoint_activations
         self.checkpoint_num_layers = args.checkpoint_num_layers

@@ -520,6 +523,7 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
+            LayerNorm = import_layernorm(args.fp32_residual_connection)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)

@@ -564,6 +568,9 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_first_stage():
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
-            hidden_states = hidden_states.transpose(0, 1).contiguous()
+            if self.fp32_residual_connection:
+                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
+            else:
+                hidden_states = hidden_states.transpose(0, 1).contiguous()
 
         if self.checkpoint_activations:
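To see what keeping the residual stream in fp32 buys, here is a small PyTorch-only illustration (runs on CPU with a recent PyTorch; shapes, the layer count, and the update size are made up): per-layer updates that are too small to register in an fp16 accumulator are preserved when the stream is kept in fp32, which is what the .float() cast above enables on the first pipeline stage.

import torch

hidden_states = torch.randn(16, 2, 8)          # [s, b, h], illustrative sizes
delta = torch.full_like(hidden_states, 1e-4)   # small per-layer residual update

# Residual accumulation in fp16: for entries of magnitude >= 1 the update is
# below half the fp16 spacing (~1e-3), so every addition rounds away.
acc16 = hidden_states.half()
for _ in range(48):                            # pretend 48 transformer layers
    acc16 = acc16 + delta.half()

# Residual accumulation in fp32, as with --fp32-residual-connection.
acc32 = hidden_states.float()
for _ in range(48):
    acc32 = acc32 + delta

# Typically around 5e-3 here: the fp16 stream loses essentially the whole
# accumulated update for the larger-magnitude entries.
print((acc32 - acc16.float()).abs().max())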
megatron/model/utils.py (+6 −2)

@@ -19,8 +19,8 @@
 import math
 
 import torch
 
-from .transformer import LayerNorm
+from megatron import get_args
+from megatron.model import import_layernorm
 
 def init_method_normal(sigma):
     """Init method based on N(0, sigma)."""

@@ -65,6 +65,10 @@ def get_params_for_weight_decay_optimization(module):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
+    args = get_args()
+    LayerNorm = import_layernorm(args.fp32_residual_connection)
+
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
     for module_ in module.modules():
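For context, a simplified, self-contained sketch of the kind of grouping this function performs (the helper name is hypothetical and the real Megatron implementation differs in details; torch.nn.LayerNorm stands in for whatever class import_layernorm returns):

import torch

def split_params_for_weight_decay(module, LayerNorm=torch.nn.LayerNorm):
    """Layernorm parameters and biases get no weight decay; everything else does."""
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        if isinstance(module_, LayerNorm):
            # All parameters of a layernorm are exempt from weight decay.
            no_weight_decay_params['params'].extend(module_.parameters(recurse=False))
        else:
            for name, p in module_.named_parameters(recurse=False):
                (no_weight_decay_params if name == 'bias'
                 else weight_decay_params)['params'].append(p)
    return weight_decay_params, no_weight_decay_params

# Usage sketch: the two groups feed directly into an optimizer as param groups.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
decay, no_decay = split_params_for_weight_decay(model)
optimizer = torch.optim.AdamW([decay, no_decay], lr=1e-3, weight_decay=0.01)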