Commit 03a2903f authored by Sangkug Lym's avatar Sangkug Lym
Browse files

add a kernel import guard for persistent layer norm

parent 4a3213f1
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -23,7 +23,11 @@ from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except:
    HAVE_PERSIST_LAYER_NORM = False

global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None
@@ -76,7 +80,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
            24576, 25600, 30720, 32768, 40960, 49152, 65536]
        if normalized_shape not in persist_ln_hidden_sizes:
        if normalized_shape not in persist_ln_hidden_sizes or \
                not HAVE_PERSIST_LAYER_NORM:
            no_persist_layer_norm = True

        if isinstance(normalized_shape, numbers.Integral):