Commit 62632d39 authored by mohammad, committed by Deepak Narayanan

Add the option for fp32 residual connection (fp32 residual connection machinery still needs to be added)
parent 9b174da8
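
As context for the flag introduced here, a minimal standalone sketch of what an fp32 residual connection means in an otherwise fp16 model (names are illustrative; per the message above, the actual machinery is not part of this commit):

import torch

def block_with_fp32_residual(hidden_states_fp32, block_fp16):
    # Apply one fp16 block while keeping the residual stream in fp32,
    # so rounding error does not accumulate across layers in half precision.
    out = block_fp16(hidden_states_fp32.half())   # compute in fp16
    return hidden_states_fp32 + out.float()       # accumulate in fp32

# Toy usage (fp16 compute generally assumes a GPU).
if torch.cuda.is_available():
    block = torch.nn.Linear(8, 8).half().cuda()
    x = torch.randn(2, 8, device='cuda')          # residual stream in fp32
    print(block_with_fp32_residual(x, block).dtype)   # torch.float32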
+5 −0
@@ -183,6 +183,9 @@ def parse_args(extra_args_provider=None, defaults={},
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16, \
            'residual connection in fp32 is only supported in fp16 mode.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
@@ -435,6 +438,8 @@ def _add_mixed_precision_args(parser):

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
                       help='Scale Q * K^T by 1 / layer-number. If this flag '
                       'is set, then it will automatically set '
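
For reference, a self-contained sketch of how the new flag pairs with --fp16 and what the added check enforces (standalone argparse, not the real Megatron-LM parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true',
                    help='Run model in fp16 mode.')
parser.add_argument('--fp32-residual-connection', action='store_true',
                    help='Move residual connections to fp32.')

args = parser.parse_args(['--fp16', '--fp32-residual-connection'])
if args.fp32_residual_connection:
    assert args.fp16, \
        'residual connection in fp32 is only supported in fp16 mode.'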
+18 −0
@@ -13,9 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

_LAYER_NORM = None


def import_layernorm(fp32_residual_connection):
    """Select, import, and cache the LayerNorm implementation to use."""
    global _LAYER_NORM
    if not _LAYER_NORM:
        if fp32_residual_connection:
            from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
        else:
            from apex.normalization.fused_layer_norm import FusedLayerNorm \
                as LayerNorm
        _LAYER_NORM = LayerNorm
    return _LAYER_NORM


from .distributed import *
from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from .realm_model import ICTBertModel
from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from .utils import get_params_for_weight_decay_optimization
from .language_model import get_language_model

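A usage sketch for the new helper, assuming Megatron-LM and Apex are installed (class names mirror the diff above; the hidden size is illustrative):

from megatron.model import import_layernorm

# The first call decides which implementation is used and caches it.
LayerNorm = import_layernorm(fp32_residual_connection=True)
final_layernorm = LayerNorm(1024, eps=1e-5)

# Later calls return the cached class regardless of the flag they pass,
# so the choice is effectively process-wide after the first call.
assert import_layernorm(fp32_residual_connection=False) is LayerNorm
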
+2 −1
@@ -21,7 +21,7 @@ from megatron import get_args
from megatron import mpu
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model import import_layernorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
@@ -83,6 +83,7 @@ class BertLMHead(MegatronModule):
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
+9 −2
@@ -21,9 +21,9 @@ import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from megatron.mpu import LayerNorm
from megatron.module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu
@@ -404,6 +404,7 @@ class ParallelTransformerLayer(MegatronModule):
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)
@@ -500,6 +501,8 @@ class ParallelTransformer(MegatronModule):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpoiting flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers
@@ -520,6 +523,7 @@ class ParallelTransformer(MegatronModule):

        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            LayerNorm = import_layernorm(args.fp32_residual_connection)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)
@@ -564,6 +568,9 @@ class ParallelTransformer(MegatronModule):

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
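
Shown standalone, the shape and dtype effect of the new branch above (toy sizes; only torch is required):

import torch

fp32_residual_connection = True

# Embedding output in the usual [b, s, h] layout, computed in fp16.
hidden_states = torch.randn(2, 4, 8).half()

# [b s h] --> [s b h]; optionally promote the residual stream to fp32.
if fp32_residual_connection:
    hidden_states = hidden_states.transpose(0, 1).contiguous().float()
else:
    hidden_states = hidden_states.transpose(0, 1).contiguous()

print(hidden_states.shape, hidden_states.dtype)
# torch.Size([4, 2, 8]) torch.float32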
+6 −2
@@ -19,8 +19,8 @@ import math

import torch

from .transformer import LayerNorm

from megatron import get_args
from megatron.model import import_layernorm

def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""
@@ -65,6 +65,10 @@ def get_params_for_weight_decay_optimization(module):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """

    args = get_args()
    LayerNorm = import_layernorm(args.fp32_residual_connection)
    
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
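
The hunk above only shows the start of get_params_for_weight_decay_optimization. As a standalone sketch of the grouping idea (using torch.nn.LayerNorm as a stand-in for the class returned by import_layernorm, and a toy model; not the function body from this diff):

import torch

LayerNorm = torch.nn.LayerNorm
model = torch.nn.Sequential(torch.nn.Linear(8, 8), LayerNorm(8))

weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in model.modules():
    if isinstance(module_, LayerNorm):
        # Every layernorm parameter (weight and bias) skips weight decay.
        no_weight_decay_params['params'].extend(module_.parameters(recurse=False))
    else:
        for name, p in module_.named_parameters(recurse=False):
            target = no_weight_decay_params if name == 'bias' else weight_decay_params
            target['params'].append(p)

optimizer = torch.optim.AdamW(
    [weight_decay_params, no_weight_decay_params], lr=1e-4, weight_decay=0.01)
print(len(weight_decay_params['params']), len(no_weight_decay_params['params']))
# 1 3  (Linear.weight decays; Linear.bias and both LayerNorm params do not)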