Commit 91ee60df authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'vijay/softmax_fusion' into 'main'

Various speed optimizations.

See merge request ADLR/megatron-lm!124
parents 12518332 51a2e6b0
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ import argparse
import os

import torch

from megatron import fused_kernels

def parse_args(extra_args_provider=None, defaults={},
               ignore_unknown_args=False):
@@ -118,6 +118,10 @@ def parse_args(extra_args_provider=None, defaults={},
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    # load scaled_upper_triang_masked_softmax_fusion kernel
    if args.scaled_upper_triang_masked_softmax_fusion:
        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()

    _print_args(args)
    return args

@@ -221,6 +225,14 @@ def _add_training_args(parser):
                       'by this value.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
                       action='store_true',
                       help='Enable fusion of query_key_value_scaling '
                       'time (upper diagonal) masking, softmax.')
    group.add_argument('--bias-gelu-fusion', action='store_true',
                        help='Enable bias and gelu fusion.')
    group.add_argument('--bias-dropout-fusion', action='store_true',
                       help='Enable bias and dropout fusion.')

    return parser

+53 −0
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
import subprocess
from torch.utils import cpp_extension

def load_scaled_upper_triang_masked_softmax_fusion_kernel():
    """JIT-compile and load the fused scaled upper-triangular masked
    softmax CUDA extension.

    Builds the kernel sources that live next to this file with
    ``torch.utils.cpp_extension.load``. Torch caches the compiled
    extension, so repeated calls (or later imports of
    ``scaled_upper_triang_masked_softmax_cuda``) do not recompile.

    Returns:
        The loaded extension module exposing ``forward``/``backward``.
        (Callers that previously ignored the ``None`` return are
        unaffected.)
    """

    def _get_cuda_bare_metal_version(cuda_dir):
        # Query nvcc directly: torch.version.cuda reports the toolkit
        # torch was built with, which may differ from the one that will
        # actually compile this kernel.
        raw_output = subprocess.check_output(
            [str(pathlib.Path(cuda_dir) / 'bin' / 'nvcc'), '-V'],
            universal_newlines=True)
        output = raw_output.split()
        release_idx = output.index("release") + 1
        release = output[release_idx].split(".")
        bare_metal_major = release[0]
        bare_metal_minor = release[1][0]
        return raw_output, bare_metal_major, bare_metal_minor

    # Emit compute capability 8.0 (A100) code only when the installed
    # toolkit is CUDA >= 11 -- earlier nvcc cannot target sm_80.
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(
        cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    srcpath = pathlib.Path(__file__).parent.absolute()
    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
        name='scaled_upper_triang_masked_softmax_cuda',
        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
        extra_cflags=['-O3'],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',
                           '-U__CUDA_NO_HALF_OPERATORS__',
                           '-U__CUDA_NO_HALF_CONVERSIONS__',
                           '--expt-relaxed-constexpr',
                           '--expt-extended-lambda',
                           '--use_fast_math'] + cc_flag,
        verbose=True)
    # Hand the module back so callers can use it directly instead of
    # re-importing it by name.
    return scaled_upper_triang_masked_softmax_cuda
+69 −0
Original line number Diff line number Diff line
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input, 
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads, 
    torch::Tensor const& softmax_results,
    float scale_factor);

// Forward entry point: softmax(scale_factor * input) with the strictly
// upper triangle masked. Expects a half-precision tensor shaped
// [attn_batches, seq_len, seq_len].
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  // The CUDA kernel reads size(1) for both sequence dimensions, so a
  // non-square attention matrix would be silently mis-indexed. Reject
  // it here (the backward path already asserts squareness).
  AT_ASSERTM(input.size(1) == input.size(2),
      "expected a square attention matrix");
  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half, 
      "Only HALF is supported");

  return fwd_cuda(input, scale_factor);
}

// Backward entry point: validates the incoming gradient and the saved
// softmax output, then hands off to the CUDA implementation. Both
// tensors must be half-precision [attn_batches, seq_len, seq_len].
torch::Tensor bwd(
    torch::Tensor const& output_grads, 
    torch::Tensor const& softmax_results,
    float scale_factor) {
  // Rank checks first, then dtype checks, so the reported failure
  // pinpoints which precondition was violated.
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported");
  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

// Python bindings for the fused scaled upper-triangular masked softmax.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  namespace stums =
      multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax;
  m.def("forward", &stums::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward", &stums::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
+439 −0

File added.

Preview size limit exceeded, changes collapsed.

+89 −0
Original line number Diff line number Diff line
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

// Launch the forward fused softmax kernel.
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len].
// Returns a fresh tensor of the same shape holding the softmax results.
torch::Tensor fwd_cuda(
    torch::Tensor const& input, 
    float scale_factor)
{
  // The kernel indexes raw pointers, so it needs densely-packed memory.
  // contiguous() is a no-op for already-contiguous tensors; bwd_cuda
  // does the same for its inputs.
  auto input_c = input.contiguous();

  const int attn_batches = input_c.size(0);
  const int seq_len = input_c.size(1);
  // The kernel uses size(1) for both sequence dimensions, so the
  // attention matrix must be square, and rows are capped at 2048.
  TORCH_INTERNAL_ASSERT(input_c.size(1) == input_c.size(2));
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output buffer; detached from autograd (requires_grad false) because
  // the Python-side autograd Function tracks gradients itself.
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results = 
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input_c.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
      reinterpret_cast<half*>(softmax_results_ptr),
      reinterpret_cast<const half*>(input_ptr),
      scale_factor,
      seq_len,
      seq_len,
      attn_batches);
  return softmax_results;
}

// Launch the backward fused softmax kernel.
// output_grads_ and softmax_results_ are 3d tensors with dimensions
// [attn_batches, seq_len, seq_len]; the gradient is computed in place
// in (a contiguous copy of) output_grads_ and returned.
torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_, 
    torch::Tensor const& softmax_results_, 
    float scale_factor)  {

  // The kernel indexes raw pointers, so densely-packed memory is required.
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
  // Same per-row limit the forward kernel enforces.
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax grad: output_grads serves as both the incoming-gradient
  // source and the destination buffer (in-place).
  dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
      reinterpret_cast<half*>(output_grads_ptr), 
      reinterpret_cast<half*>(output_grads_ptr), 
      reinterpret_cast<half const*>(softmax_results.data_ptr()),
      scale_factor,
      seq_len,
      seq_len,
      attn_batches);

  // Backward pass is completely in-place.
  return output_grads;
}
}
}
}
Loading