Commit 42d21122 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'build_fix' into 'main'

Create build directory for fused_kernels before building.

See merge request ADLR/megatron-lm!149
parents 54282071 54ded172
Loading
Loading
Loading
Loading
+21 −7
Original line number Diff line number Diff line
@@ -36,6 +36,12 @@ def get_cuda_bare_metal_version(cuda_dir):

    return raw_output, bare_metal_major, bare_metal_minor

def create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")

def load_scaled_upper_triang_masked_softmax_fusion_kernel():

@@ -47,11 +53,15 @@ def load_scaled_upper_triang_masked_softmax_fusion_kernel():
        cc_flag.append('arch=compute_80,code=sm_80')

    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / 'build'

    create_build_dir(buildpath)

    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
        name='scaled_upper_triang_masked_softmax_cuda',
        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
        build_directory=srcpath / 'build',
        build_directory=buildpath,
        extra_cflags=['-O3',],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',
@@ -72,11 +82,15 @@ def load_scaled_masked_softmax_fusion_kernel():
        cc_flag.append('arch=compute_80,code=sm_80')

    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / 'build'

    create_build_dir(buildpath)

    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
        name='scaled_masked_softmax_cuda',
        sources=[srcpath / 'scaled_masked_softmax.cpp',
                 srcpath / 'scaled_masked_softmax_cuda.cu'],
        build_directory=srcpath / 'build',
        build_directory=buildpath,
        extra_cflags=['-O3',],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',