megatron/model/fused_layer_norm.py  +7 −2

@@ -23,7 +23,11 @@
 from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
-from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+    HAVE_PERSIST_LAYER_NORM = True
+except:
+    HAVE_PERSIST_LAYER_NORM = False
 
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None

@@ -76,7 +80,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
         persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
             5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
             24576, 25600, 30720, 32768, 40960, 49152, 65536]
-        if normalized_shape not in persist_ln_hidden_sizes:
+        if normalized_shape not in persist_ln_hidden_sizes or \
+                not HAVE_PERSIST_LAYER_NORM:
             no_persist_layer_norm = True
 
         if isinstance(normalized_shape, numbers.Integral):
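The pattern here is a guarded optional import: probe for apex's persistent layer-norm kernel once at module load, record the outcome in a module-level flag, and fall back to the non-persistent path whenever the kernel is missing or the hidden size is not one the persistent kernel supports. Below is a minimal standalone sketch of that pattern, not Megatron's actual module: use_persist_layer_norm and layer_norm are hypothetical helpers for illustration (Megatron wires the same logic into MixedFusedLayerNorm), and the FastLayerNormFN.apply(input, weight, bias, eps) call follows apex's autograd-function convention; treat it as an assumption if your apex version differs.

    # Sketch: optional-kernel probe plus supported-size check, mirroring the diff.
    import torch
    import torch.nn.functional as F

    try:
        # Persistent fused layer-norm kernel from apex (optional dependency).
        from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
        HAVE_PERSIST_LAYER_NORM = True
    except ImportError:  # the diff uses a bare `except:`; ImportError is the idiomatic narrowing
        HAVE_PERSIST_LAYER_NORM = False

    # Hidden sizes the persistent kernel supports, copied from the diff.
    PERSIST_LN_HIDDEN_SIZES = {
        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240,
        12288, 12800, 15360, 16384, 18432, 20480, 24576, 25600, 30720,
        32768, 40960, 49152, 65536,
    }

    def use_persist_layer_norm(hidden_size):
        # Mirrors the new condition: persist only if the kernel imported
        # *and* the hidden size is one the kernel supports.
        return HAVE_PERSIST_LAYER_NORM and hidden_size in PERSIST_LN_HIDDEN_SIZES

    def layer_norm(x, weight, bias, eps=1e-5):
        if use_persist_layer_norm(x.shape[-1]):
            return FastLayerNormFN.apply(x, weight, bias, eps)
        # Portable fallback: PyTorch's built-in layer norm.
        return F.layer_norm(x, (x.shape[-1],), weight, bias, eps)

Folding not HAVE_PERSIST_LAYER_NORM into the existing size check, rather than failing at import time, keeps the module usable on installations whose apex build lacks the contrib kernel.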