Commit 84a5997a authored by Boris Fomitchev

Merge remote-tracking branch 'upstream/master' into onnx-erf

parents bb7c638f 07ebf714
+6 −0
@@ -39,6 +39,12 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # This is a temporary WAR (workaround) so that simple cases, such as
    # pytest calling initialize_megatron twice with the same args, do not fail.
    # A clean factory-style init still needs to be implemented.
    if mpu.model_parallel_is_initialized():
        return
    
    
    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
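
The early return added above makes initialize_megatron idempotent, so repeated calls (as in the pytest case the comment mentions) become no-ops. A minimal sketch of the same guard pattern, using a hypothetical module-level flag in place of mpu.model_parallel_is_initialized():

_initialized = False  # hypothetical flag, not in the original

def initialize(extra_args_provider=None):
    """Initialize global state; safe to call more than once."""
    global _initialized
    if _initialized:
        # Second and later calls return immediately instead of
        # re-initializing (and possibly corrupting) global state.
        return
    # ... one-time setup: parse args, build tokenizer, start timers ...
    _initialized = True
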
+1 −2
@@ -16,12 +16,11 @@
"""Transformer."""

import math

import torch
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm

from megatron import get_args
from megatron import mpu
from megatron.mpu import LayerNorm
from megatron.module import MegatronModule


+1 −0
@@ -32,6 +32,7 @@ from .initialize import get_model_parallel_world_size
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized

from .layers import LayerNorm
from .layers import ColumnParallelLinear
from .layers import ParallelEmbedding
from .layers import RowParallelLinear
+6 −2
@@ -21,9 +21,13 @@
import torch
from torch._six import inf

try:
    from apex.multi_tensor_apply import multi_tensor_applier
    import amp_C

except Exception as e:
    print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')

from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank

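The try/except above lets the module import cleanly when APEX is absent instead of failing outright. A minimal sketch of how a caller might branch on the outcome, using a hypothetical HAVE_APEX flag and l2_norm helper (neither is in the diff); the multi_tensor_applier call mirrors APEX's fused L2-norm usage:

import torch

try:
    from apex.multi_tensor_apply import multi_tensor_applier
    import amp_C
    HAVE_APEX = True
except ImportError:
    HAVE_APEX = False  # degrade gracefully; callers check the flag

def l2_norm(tensors):
    # Hypothetical helper: fused multi-tensor kernel when APEX is present,
    # plain PyTorch otherwise.
    if HAVE_APEX:
        overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
        norm, _ = multi_tensor_applier(
            amp_C.multi_tensor_l2norm, overflow_buf, [list(tensors)], False)
        return norm
    return torch.norm(torch.stack([t.norm() for t in tensors]))
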
+8 −1
@@ -25,7 +25,14 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
    # Probe FusedLayerNorm from APEX: constructing one instance here makes a
    # broken install raise now, at import time, instead of at first use.
    _ = LayerNorm(8, eps=1e-5)

except Exception as e:
    print('WARNING: APEX is not installed, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
    from torch.nn import LayerNorm

from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
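
Note the probe on the line `_ = LayerNorm(8, eps=1e-5)`: importing APEX can succeed while its compiled extension is broken (for example, built against a different CUDA version), so constructing one instance forces such a failure early and triggers the torch.nn fallback. Either way, downstream code sees a single name; a short usage sketch (the call site is an assumption, not part of the diff):

import torch
from megatron.mpu import LayerNorm  # fused APEX kernel or torch.nn.LayerNorm

norm = LayerNorm(1024, eps=1e-5)    # both implementations share this signature
y = norm(torch.randn(4, 16, 1024))  # layer norm over the last dimension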