Commit 3860e995 authored by Jared Casper

Merge branch 'slym/pyt21.08_nvfuser' into 'main'

Use nvfuser with pytorch >= 1.10

See merge request ADLR/megatron-lm!314
parents 89e8d27e dba2506d
megatron/initialize.py +25 −0
@@ -64,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
             print('> setting random seeds to {} ...'.format(args.seed))
         _set_random_seed(args.seed)
 
+    # Set pytorch JIT layer fusion options.
+    _set_jit_fusion_options()
+
     args = get_args()
     if  args.lazy_mpu_init:
         args.use_cpu_initialization=True
@@ -222,3 +225,25 @@ def write_args_to_tensorboard():
             writer.add_text(arg, str(getattr(args, arg)),
                             global_step=args.iteration)
 
+
+def _set_jit_fusion_options():
+    """Set PyTorch JIT layer fusion options."""
+    # flags required to enable jit fusion kernels
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        # nvfuser
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._debug_set_autodiff_subgraph_inlining(False)
+    else:
+        # legacy pytorch fuser
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
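
For context, a minimal sketch (not part of this commit) of the kind of pointwise chain these flags let nvfuser fuse into a single CUDA kernel. It assumes PyTorch >= 1.10 with a CUDA device; bias_gelu here is an illustrative stand-in modeled on Megatron's bias-gelu pattern:

import torch

@torch.jit.script
def bias_gelu(bias, y):
    # elementwise add + tanh-approximated gelu: one fusable pointwise chain
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

y = torch.randn(1024, 1024, device='cuda')
bias = torch.randn(1024, device='cuda')
# the profiling executor only fuses after a few warmup calls
for _ in range(3):
    out = bias_gelu(bias, y)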
megatron/model/fused_bias_gelu.py +0 −4
@@ -15,10 +15,6 @@

 import torch
 
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
 
 ###### BIAS GELU FUSION/ NO AUTOGRAD ################
 # 1/sqrt(2*pi)-> 0.3989423
megatron/model/transformer.py +0 −5
@@ -27,11 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
 
-# flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
 
 """ We use the following notation throughout this file:
      h: hidden size