Commit 0fa7175f authored by Mohammad Shoeybi, committed by Jared Casper
Browse files

Bfloat fused softmax + fused layer norm

parent d9b1c681
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -133,8 +133,7 @@ def parse_args(extra_args_provider=None, defaults={},
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # No fusion is supported for bfloat for now
        assert not args.masked_softmax_fusion
        # Jitting fusion is not supported for bfloat for now
        assert not args.bias_gelu_fusion
        assert not args.bias_dropout_fusion

+5 −6
Original line number Diff line number Diff line
@@ -82,7 +82,6 @@ def load(args):
    # Mixed precision fused layer norm.
    # =================================

    if args.fp32_residual_connection:
    extra_cuda_flags = ['-maxrregcount=50']
    sources=[srcpath / 'layer_norm_cuda.cpp',
             srcpath / 'layer_norm_cuda_kernel.cu']
+31 −50
Original line number Diff line number Diff line
@@ -24,12 +24,12 @@
#include "compat.h"

namespace {

void compute_n1_n2(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2)
{
    int& n2) {
    int idiff = input.ndimension() - normalized_shape.size();
    n2 = 1;
    for (int i = 0;  i < (int)normalized_shape.size();  ++i) {
@@ -118,39 +118,33 @@ void cuda_layer_norm(
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Non-affine layer-norm forward (no gamma/beta scale or shift).
// Returns {output, mean, invvar}; mean and invvar are the per-row
// statistics that the backward entry point takes as inputs.
std::vector<at::Tensor> layer_norm(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    double epsilon) {
  CHECK_INPUT(input);
  int n1,n2;
  // n1/n2 are the flattened row count and per-row element count derived
  // from input's shape and normalized_shape (see compute_n1_n2 above).
  check_args(input,normalized_shape,n1,n2);
  at::Tensor output = at::empty_like(input);
  // Statistics are kept in float when the input is half precision,
  // otherwise in the input's own dtype.
  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  // NULL gamma/beta selects the non-affine kernel path.
  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
      normalized_shape,NULL,NULL,epsilon);
  return {output, mean, invvar};
}
// Affine layer-norm forward: normalizes input, then scales by gamma and
// shifts by beta. Returns {output, mean, invvar}; mean and invvar are the
// per-row statistics consumed by the affine backward entry point.
std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {

  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  // n1/n2 are the flattened row count and per-row element count; also
  // validates gamma/beta against normalized_shape.
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  // BUGFIX: the block contained two conflicting declarations of `output`
  // and `mean` (a leftover pair hard-coding at::ScalarType::Half, plus the
  // replacement pair) — a redefinition error. Keep only the dtype-generic
  // versions: output follows gamma's dtype (supports fp16/bf16/fp32),
  // statistics are always accumulated in float.
  at::Tensor output = at::empty_like(
      input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean = at::empty(
      {n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon);

  return {output, mean, invvar};
}


void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
@@ -167,25 +161,6 @@ void cuda_layer_norm_gradient(
    at::Tensor* grad_beta
    );

// Non-affine layer-norm backward: computes the gradient with respect to
// the input only (no gamma/beta, hence no weight gradients).
// mean/invvar are the statistics saved by the forward pass.
at::Tensor layer_norm_gradient(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  int n1,n2;
  check_args(input,normalized_shape,n1,n2);
  at::Tensor grad_input = at::empty_like(input);
  // NULL gamma/beta and NULL grad_gamma/grad_beta select the non-affine
  // kernel path; only grad_input is produced.
  cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
      normalized_shape,NULL,NULL,epsilon,
      &grad_input,NULL,NULL);
  return grad_input;
}
std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
@@ -195,6 +170,7 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {

  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
@@ -203,18 +179,23 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon,
      &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};

}


// Python bindings for the fused layer-norm extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // BUGFIX: "forward_affine" and "backward_affine" were each registered
  // twice (a stale one-line m.def left next to its reformatted
  // replacement), creating ambiguous pybind11 overloads. Register every
  // entry point exactly once.
  m.def("forward_affine", &layer_norm_affine,
	"LayerNorm forward (CUDA)");
  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine,
	"LayerNorm backward (CUDA)");
  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
}
+33 −33
Original line number Diff line number Diff line
@@ -285,15 +285,6 @@ struct SharedMemory <float>
    }
};

// Specialization of the SharedMemory accessor for double, mirroring the
// <float> specialization above: exposes the kernel's dynamically sized
// extern __shared__ buffer as a typed pointer. A distinctly named array
// (s_double) is used per specialization so each type gets its own
// extern-shared declaration.
template <>
struct SharedMemory <double>
{
    __device__ double *getPointer()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};
}

template<typename T, typename U, typename V> __global__
@@ -656,6 +647,9 @@ void cuComputeGradInput(
  }
}




template<typename T, typename U, typename V> 
void HostApplyLayerNorm(
    V* output,
@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const dim3 threads(32,4,1);
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
    int nshared = 
        threads.y > 1 ? 
@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
            gamma,beta);
}


void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
@@ -704,21 +700,21 @@ void cuda_layer_norm(
    double epsilon)
{
    using namespace at;
    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        using output_t = at::Half;
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
        HostApplyLayerNorm(
        output->DATA_PTR<output_t>(),
	    mean->DATA_PTR<accscalar_t>(),
	    invvar->DATA_PTR<accscalar_t>(),
	    input->DATA_PTR<scalar_t_0>(),
	    output->DATA_PTR<scalar_t_out>(),
	    mean->DATA_PTR<float>(),
	    invvar->DATA_PTR<float>(),
	    input->DATA_PTR<scalar_t_in>(),
	    n1,n2,
	    epsilon,
	    gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
	    beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
	    gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
	    beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
      )
}


template<typename T, typename U, typename V>
void HostLayerNormGradient(
    const V* dout,
@@ -742,10 +738,12 @@ void HostLayerNormGradient(
      const int part_size = 16;
      const dim3 threads2(32,4,1);
      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
	(threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
      at::Tensor part_grad_gamma = at::empty(
	  {part_size,n2}, input->options().dtype(at::ScalarType::Float));
      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
		      dout,
@@ -770,7 +768,8 @@ void HostLayerNormGradient(
    }

    // compute grad_input
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    const dim3 threads1(32,4,1);
    int nshared =
@@ -788,6 +787,7 @@ void HostLayerNormGradient(
            grad_input);
}


void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
    at::Tensor* grad_beta)
{
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        using output_t = at::Half;
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), gamma->scalar_type(),
	"cuda_layer_norm_gradient_kernel",
        HostLayerNormGradient(
	    dout->DATA_PTR<output_t>(),
	    mean->DATA_PTR<accscalar_t>(),
	    invvar->DATA_PTR<accscalar_t>(),
	    dout->DATA_PTR<scalar_t_out>(),
	    mean->DATA_PTR<float>(),
	    invvar->DATA_PTR<float>(),
	    input,
	    n1,n2,
            // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
            // if gamma Tensor is NULL on input.
	    gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
	    gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
	    gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
	    gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
	    epsilon,
	    grad_input->DATA_PTR<scalar_t_0>(),
	    gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
	    gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
	    grad_input->DATA_PTR<scalar_t_in>(),
	    gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
	    gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
      )
}
+9 −6
Original line number Diff line number Diff line
@@ -37,8 +37,9 @@ torch::Tensor fwd(
    torch::Tensor const& mask,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half, 
      "Only HALF is supported");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
	     (input.scalar_type() == at::ScalarType::BFloat16), 
      "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
@@ -52,10 +53,12 @@ torch::Tensor bwd(
  AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half, 
      "Only HALF is supported");
  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half, 
      "Only HALF is supported");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
	     (output_grads.scalar_type() == at::ScalarType::BFloat16), 
      "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
	     (softmax_results.scalar_type() == at::ScalarType::BFloat16), 
      "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}
Loading