Commit d9b1c681 authored by Jared Casper

Merge branch 'ninja_compilation_fix' into 'main'

refactored the fused kernels build

See merge request ADLR/megatron-lm!250
parents 876096d5 0d5188c1
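The refactor collapses the three per-kernel loaders (load_scaled_upper_triang_masked_softmax_fusion_kernel, load_scaled_masked_softmax_fusion_kernel, load_fused_mix_prec_layer_norm_kernel) into a single fused_kernels.load(args) entry point and moves kernel selection out of parse_args. A minimal sketch of a caller, assuming the args namespace produced by parse_args (the actual call site is not part of this diff):

# Hypothetical call site (assumed; not shown in this merge request).
from megatron import fused_kernels

def build_fused_kernels(args):
    # load() checks args.masked_softmax_fusion and args.fp32_residual_connection
    # and compiles only the kernels that are enabled; ninja/nvcc output is
    # verbose only on rank 0 (args.rank == 0).
    fused_kernels.load(args)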
+0 −26
@@ -19,7 +19,6 @@ import argparse
import os

import torch
from megatron import fused_kernels

def parse_args(extra_args_provider=None, defaults={},
               ignore_unknown_args=False):
@@ -227,31 +226,6 @@ def parse_args(extra_args_provider=None, defaults={},
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    # custom kernel constraints check
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size

    # constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0

    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
        print('WARNING: constraints for invoking optimized'
            ' fused softmax kernel are not met. We default back to unfused'
            ' kernel invocations.')

    # Load scaled_masked_softmax_fusion_kernels
    if args.masked_softmax_fusion:
        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
        fused_kernels.load_scaled_masked_softmax_fusion_kernel()

    # Load mixed precision fused layer norm.
    if args.fp32_residual_connection:
        fused_kernels.load_fused_mix_prec_layer_norm_kernel()

    _print_args(args)
    return args
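For reference, the custom-kernel constraint deleted from parse_args above can be restated as a standalone predicate; the sample values in the last two lines are illustrative assumptions, not taken from this merge request:

def fused_softmax_constraint(seq_length, num_attention_heads,
                             tensor_model_parallel_size, micro_batch_size):
    # Same condition as the deleted check: sequence length in (16, 2048] and
    # divisible by 4, and the per-tensor-parallel-partition attention batch
    # size divisible by 4.
    attn_batch_size = (num_attention_heads / tensor_model_parallel_size) \
        * micro_batch_size
    return seq_length > 16 and seq_length <= 2048 and \
        seq_length % 4 == 0 and attn_batch_size % 4 == 0

# Example (assumed values): seq_length=1024, 16 heads, tensor-parallel size 2,
# micro batch size 4 -> attn_batch_size = 32.0, so the constraint holds.
assert fused_softmax_constraint(1024, 16, 2, 4)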

+71 −87
@@ -13,114 +13,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib
import subprocess
import os

from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating
# different compilation commands (with different order of architectures)
# and leading to recompilation of fused kernels.
# set it to empty string to avoid recompilation
# and assign arch flags explicitly in extra_cuda_cflags below
# Setting this param to a list has a problem of generating different
# compilation commands (with different order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicitly in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""

def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                         universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor

def create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")

def load_scaled_upper_triang_masked_softmax_fusion_kernel():
def load(args):

    # Check, if CUDA11 is installed for compute capability 8.0
    # Check if CUDA 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(
        cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / 'build'
    _create_build_dir(buildpath)

    create_build_dir(buildpath)

    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
        name='scaled_upper_triang_masked_softmax_cuda',
        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=['-O3',],
            extra_cuda_cflags=['-O3',
                               '-gencode', 'arch=compute_70,code=sm_70',
                           '-U__CUDA_NO_HALF_OPERATORS__',
                           '-U__CUDA_NO_HALF_CONVERSIONS__',
                           '--expt-relaxed-constexpr',
                           '--expt-extended-lambda',
                           '--use_fast_math'] + cc_flag)
                               '--use_fast_math'] + extra_cuda_flags + cc_flag,
            verbose=(args.rank == 0)
        )

def load_scaled_masked_softmax_fusion_kernel():

    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')
    # ==============
    # Fused softmax.
    # ==============

    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / 'build'
    if args.masked_softmax_fusion:
        extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
                            '-U__CUDA_NO_HALF_CONVERSIONS__',
                            '--expt-relaxed-constexpr',
                            '--expt-extended-lambda']
        
    create_build_dir(buildpath)
        # Upper triangular softmax.
        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
        scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
            "scaled_upper_triang_masked_softmax_cuda",
            sources, extra_cuda_flags)

    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
        name='scaled_masked_softmax_cuda',
        # Masked softmax.
        sources=[srcpath / 'scaled_masked_softmax.cpp',
                 srcpath / 'scaled_masked_softmax_cuda.cu'],
        build_directory=buildpath,
        extra_cflags=['-O3',],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',
                           '-U__CUDA_NO_HALF_OPERATORS__',
                           '-U__CUDA_NO_HALF_CONVERSIONS__',
                           '--expt-relaxed-constexpr',
                           '--expt-extended-lambda',
                           '--use_fast_math'] + cc_flag)
                 srcpath / 'scaled_masked_softmax_cuda.cu']
        scaled_masked_softmax_cuda = _cpp_extention_load_helper(
            "scaled_masked_softmax_cuda", sources, extra_cuda_flags)

    # =================================
    # Mixed precision fused layer norm.
    # =================================

def load_fused_mix_prec_layer_norm_kernel():
    if args.fp32_residual_connection:
        extra_cuda_flags = ['-maxrregcount=50']
        sources=[srcpath / 'layer_norm_cuda.cpp',
                 srcpath / 'layer_norm_cuda_kernel.cu']
        fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
            "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)

    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / 'build'
def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                         universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    create_build_dir(buildpath)
    return raw_output, bare_metal_major, bare_metal_minor

    fused_mix_prec_layer_norm_cuda = cpp_extension.load(
        name='fused_mix_prec_layer_norm_cuda',
        sources=[srcpath / 'layer_norm_cuda.cpp',
                 srcpath / 'layer_norm_cuda_kernel.cu'],
        build_directory=buildpath,
        extra_cflags=['-O3'],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',
                           '-maxrregcount=50',
                           '--use_fast_math'] + cc_flag)

def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
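For reference, a worked example of the version parsing performed by _get_cuda_bare_metal_version above; the nvcc banner text is an assumption (typical nvcc -V output), not taken from this merge request:

# Illustrative nvcc -V banner for a CUDA 11.0 toolkit (assumed).
raw_output = ("nvcc: NVIDIA (R) Cuda compiler driver\n"
              "Cuda compilation tools, release 11.0, V11.0.194\n")
output = raw_output.split()
release = output[output.index("release") + 1].split(".")  # "11.0," -> ["11", "0,"]
bare_metal_major = release[0]      # "11"
bare_metal_minor = release[1][0]   # "0"
# Since int("11") >= 11, load() appends '-gencode arch=compute_80,code=sm_80'
# to the compute_70 flags.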
+1 −41
@@ -26,11 +26,7 @@
namespace {
void compute_n1_n2(
    at::Tensor input,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    int& n1,
    int& n2)
{
@@ -47,11 +43,7 @@ void compute_n1_n2(
}

void check_args(
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor gamma,
    at::Tensor beta
    )
@@ -62,11 +54,7 @@ void check_args(

void check_args(
    at::Tensor input,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    int& n1,
    int& n2
    )
@@ -102,11 +90,7 @@ void check_args(

void check_args(
    at::Tensor input,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
@@ -125,26 +109,18 @@ void cuda_layer_norm(
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm(
    at::Tensor input,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    double epsilon) {
  CHECK_INPUT(input);
  int n1,n2;
@@ -158,11 +134,7 @@ std::vector<at::Tensor> layer_norm(
}
std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
@@ -186,11 +158,7 @@ void cuda_layer_norm_gradient(
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
@@ -204,11 +172,7 @@ at::Tensor layer_norm_gradient(
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
@@ -227,11 +191,7 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
+0 −1
@@ -19,7 +19,6 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
+0 −1
@@ -19,7 +19,6 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"