Commit 56e16cba authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'vijay/fused_kernel_compiliation' into 'main'

Avoid non deterministic arch order in compilation

See merge request ADLR/megatron-lm!146
parents d80433e1 3dcd7dc0
Loading
Loading
Loading
Loading
+12 −2
Original line number Diff line number Diff line
@@ -15,8 +15,16 @@

import pathlib
import subprocess
import os
from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating
# different compilation commands (with diferent order of architectures)
# and leading to recompilation of fused kernels.
# set it to empty string to avoid recompilation
# and assign arch flags explicity in extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""

def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], 
                                         universal_newlines=True)
@@ -43,6 +51,7 @@ def load_scaled_upper_triang_masked_softmax_fusion_kernel():
        name='scaled_upper_triang_masked_softmax_cuda', 
        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', 
                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
        build_directory=srcpath / 'build',
        extra_cflags=['-O3',],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',
@@ -67,6 +76,7 @@ def load_scaled_masked_softmax_fusion_kernel():
        name='scaled_masked_softmax_cuda', 
        sources=[srcpath / 'scaled_masked_softmax.cpp', 
                 srcpath / 'scaled_masked_softmax_cuda.cu'],
        build_directory=srcpath / 'build',
        extra_cflags=['-O3',],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',