Loading megatron/arguments.py +1 −2 Original line number Diff line number Diff line Loading @@ -133,8 +133,7 @@ def parse_args(extra_args_provider=None, defaults={}, if args.bf16: assert not args.fp16 args.params_dtype = torch.bfloat16 # No fusion is support for bfloat for now assert not args.masked_softmax_fusion # Jitting fusion is not supported for bfloat for now assert not args.bias_gelu_fusion assert not args.bias_dropout_fusion Loading megatron/fused_kernels/__init__.py +5 −6 Original line number Diff line number Diff line Loading @@ -82,7 +82,6 @@ def load(args): # Mixed precision fused layer norm. # ================================= if args.fp32_residual_connection: extra_cuda_flags = ['-maxrregcount=50'] sources=[srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu'] Loading megatron/fused_kernels/layer_norm_cuda.cpp +31 −50 Original line number Diff line number Diff line Loading @@ -24,12 +24,12 @@ #include "compat.h" namespace { void compute_n1_n2( at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) { int& n2) { int idiff = input.ndimension() - normalized_shape.size(); n2 = 1; for (int i = 0; i < (int)normalized_shape.size(); ++i) { Loading Loading @@ -118,39 +118,33 @@ void cuda_layer_norm( #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) std::vector<at::Tensor> layer_norm( at::Tensor input, at::IntArrayRef normalized_shape, double epsilon) { CHECK_INPUT(input); int n1,n2; check_args(input,normalized_shape,n1,n2); at::Tensor output = at::empty_like(input); at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type())); at::Tensor invvar = at::empty_like(mean); cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, normalized_shape,NULL,NULL,epsilon); return {output, mean, invvar}; } std::vector<at::Tensor> layer_norm_affine( at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, double epsilon) { CHECK_INPUT(input); CHECK_INPUT(gamma); CHECK_INPUT(beta); int n1, n2; check_args(input, normalized_shape, gamma, beta, n1, n2); at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half)); at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type())); at::Tensor output = at::empty_like( input, gamma.options().dtype(gamma.scalar_type())); at::Tensor mean = at::empty( {n1}, input.options().dtype(at::ScalarType::Float)); at::Tensor invvar = at::empty_like(mean); cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta, epsilon); return {output, mean, invvar}; } void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, Loading @@ -167,25 +161,6 @@ void cuda_layer_norm_gradient( at::Tensor* grad_beta ); at::Tensor layer_norm_gradient( at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, at::IntArrayRef normalized_shape, double epsilon) { CHECK_INPUT(dout); CHECK_INPUT(mean); CHECK_INPUT(invvar); CHECK_INPUT(input); int n1,n2; check_args(input,normalized_shape,n1,n2); at::Tensor grad_input = at::empty_like(input); cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, normalized_shape,NULL,NULL,epsilon, &grad_input,NULL,NULL); return grad_input; } std::vector<at::Tensor> layer_norm_gradient_affine( at::Tensor dout, at::Tensor mean, Loading @@ -195,6 +170,7 @@ std::vector<at::Tensor> layer_norm_gradient_affine( at::Tensor gamma, at::Tensor beta, double epsilon) { CHECK_INPUT(dout); CHECK_INPUT(mean); CHECK_INPUT(invvar); Loading @@ -203,18 +179,23 @@ std::vector<at::Tensor> layer_norm_gradient_affine( CHECK_INPUT(beta); int n1, n2; check_args(input, normalized_shape, gamma, beta, n1, n2); at::Tensor grad_input = at::empty_like(input); at::Tensor grad_gamma = at::empty_like(gamma); at::Tensor grad_beta = at::empty_like(beta); cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta, epsilon, &grad_input, &grad_gamma, &grad_beta); return {grad_input, grad_gamma, grad_beta}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); m.def("forward", &layer_norm, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)"); m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); } megatron/fused_kernels/layer_norm_cuda_kernel.cu +33 −33 Original line number Diff line number Diff line Loading @@ -285,15 +285,6 @@ struct SharedMemory <float> } }; template <> struct SharedMemory <double> { __device__ double *getPointer() { extern __shared__ double s_double[]; return s_double; } }; } template<typename T, typename U, typename V> __global__ Loading Loading @@ -656,6 +647,9 @@ void cuComputeGradInput( } } template<typename T, typename U, typename V> void HostApplyLayerNorm( V* output, Loading @@ -671,7 +665,8 @@ void HostApplyLayerNorm( { auto stream = at::cuda::getCurrentCUDAStream().stream(); const dim3 threads(32,4,1); const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); int nshared = threads.y > 1 ? Loading @@ -687,6 +682,7 @@ void HostApplyLayerNorm( gamma,beta); } void cuda_layer_norm( at::Tensor* output, at::Tensor* mean, Loading @@ -704,21 +700,21 @@ void cuda_layer_norm( double epsilon) { using namespace at; DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel", using accscalar_t = at::acc_type<scalar_t_0, true>; using output_t = at::Half; DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", HostApplyLayerNorm( output->DATA_PTR<output_t>(), mean->DATA_PTR<accscalar_t>(), invvar->DATA_PTR<accscalar_t>(), input->DATA_PTR<scalar_t_0>(), output->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(), input->DATA_PTR<scalar_t_in>(), n1,n2, epsilon, gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL, beta != NULL ? beta->DATA_PTR<output_t>() : NULL); gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL, beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL); ) } template<typename T, typename U, typename V> void HostLayerNormGradient( const V* dout, Loading @@ -742,10 +738,12 @@ void HostLayerNormGradient( const int part_size = 16; const dim3 threads2(32,4,1); const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type())); at::Tensor part_grad_gamma = at::empty( {part_size,n2}, input->options().dtype(at::ScalarType::Float)); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>( dout, Loading @@ -770,7 +768,8 @@ void HostLayerNormGradient( } // compute grad_input const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); const dim3 threads1(32,4,1); int nshared = Loading @@ -788,6 +787,7 @@ void HostLayerNormGradient( grad_input); } void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, Loading @@ -808,22 +808,22 @@ void cuda_layer_norm_gradient( at::Tensor* grad_beta) { using namespace at; DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput", using accscalar_t = at::acc_type<scalar_t_0, true>; using output_t = at::Half; DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel", HostLayerNormGradient( dout->DATA_PTR<output_t>(), mean->DATA_PTR<accscalar_t>(), invvar->DATA_PTR<accscalar_t>(), dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(), input, n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL, gamma != NULL ? beta->DATA_PTR<output_t>() : NULL, gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL, gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon, grad_input->DATA_PTR<scalar_t_0>(), gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL, gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL); grad_input->DATA_PTR<scalar_t_in>(), gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL, gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL); ) } megatron/fused_kernels/scaled_masked_softmax.cpp +9 −6 Original line number Diff line number Diff line Loading @@ -37,8 +37,9 @@ torch::Tensor fwd( torch::Tensor const& mask, float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); return fwd_cuda(input, mask, scale_factor); Loading @@ -52,10 +53,12 @@ torch::Tensor bwd( AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } Loading Loading
megatron/arguments.py +1 −2 Original line number Diff line number Diff line Loading @@ -133,8 +133,7 @@ def parse_args(extra_args_provider=None, defaults={}, if args.bf16: assert not args.fp16 args.params_dtype = torch.bfloat16 # No fusion is support for bfloat for now assert not args.masked_softmax_fusion # Jitting fusion is not supported for bfloat for now assert not args.bias_gelu_fusion assert not args.bias_dropout_fusion Loading
megatron/fused_kernels/__init__.py +5 −6 Original line number Diff line number Diff line Loading @@ -82,7 +82,6 @@ def load(args): # Mixed precision fused layer norm. # ================================= if args.fp32_residual_connection: extra_cuda_flags = ['-maxrregcount=50'] sources=[srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu'] Loading
megatron/fused_kernels/layer_norm_cuda.cpp +31 −50 Original line number Diff line number Diff line Loading @@ -24,12 +24,12 @@ #include "compat.h" namespace { void compute_n1_n2( at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) { int& n2) { int idiff = input.ndimension() - normalized_shape.size(); n2 = 1; for (int i = 0; i < (int)normalized_shape.size(); ++i) { Loading Loading @@ -118,39 +118,33 @@ void cuda_layer_norm( #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) std::vector<at::Tensor> layer_norm( at::Tensor input, at::IntArrayRef normalized_shape, double epsilon) { CHECK_INPUT(input); int n1,n2; check_args(input,normalized_shape,n1,n2); at::Tensor output = at::empty_like(input); at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type())); at::Tensor invvar = at::empty_like(mean); cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, normalized_shape,NULL,NULL,epsilon); return {output, mean, invvar}; } std::vector<at::Tensor> layer_norm_affine( at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, double epsilon) { CHECK_INPUT(input); CHECK_INPUT(gamma); CHECK_INPUT(beta); int n1, n2; check_args(input, normalized_shape, gamma, beta, n1, n2); at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half)); at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type())); at::Tensor output = at::empty_like( input, gamma.options().dtype(gamma.scalar_type())); at::Tensor mean = at::empty( {n1}, input.options().dtype(at::ScalarType::Float)); at::Tensor invvar = at::empty_like(mean); cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta, epsilon); return {output, mean, invvar}; } void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, Loading @@ -167,25 +161,6 @@ void cuda_layer_norm_gradient( at::Tensor* grad_beta ); at::Tensor layer_norm_gradient( at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, at::IntArrayRef normalized_shape, double epsilon) { CHECK_INPUT(dout); CHECK_INPUT(mean); CHECK_INPUT(invvar); CHECK_INPUT(input); int n1,n2; check_args(input,normalized_shape,n1,n2); at::Tensor grad_input = at::empty_like(input); cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, normalized_shape,NULL,NULL,epsilon, &grad_input,NULL,NULL); return grad_input; } std::vector<at::Tensor> layer_norm_gradient_affine( at::Tensor dout, at::Tensor mean, Loading @@ -195,6 +170,7 @@ std::vector<at::Tensor> layer_norm_gradient_affine( at::Tensor gamma, at::Tensor beta, double epsilon) { CHECK_INPUT(dout); CHECK_INPUT(mean); CHECK_INPUT(invvar); Loading @@ -203,18 +179,23 @@ std::vector<at::Tensor> layer_norm_gradient_affine( CHECK_INPUT(beta); int n1, n2; check_args(input, normalized_shape, gamma, beta, n1, n2); at::Tensor grad_input = at::empty_like(input); at::Tensor grad_gamma = at::empty_like(gamma); at::Tensor grad_beta = at::empty_like(beta); cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta, epsilon, &grad_input, &grad_gamma, &grad_beta); return {grad_input, grad_gamma, grad_beta}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); m.def("forward", &layer_norm, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)"); m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); }
megatron/fused_kernels/layer_norm_cuda_kernel.cu +33 −33 Original line number Diff line number Diff line Loading @@ -285,15 +285,6 @@ struct SharedMemory <float> } }; template <> struct SharedMemory <double> { __device__ double *getPointer() { extern __shared__ double s_double[]; return s_double; } }; } template<typename T, typename U, typename V> __global__ Loading Loading @@ -656,6 +647,9 @@ void cuComputeGradInput( } } template<typename T, typename U, typename V> void HostApplyLayerNorm( V* output, Loading @@ -671,7 +665,8 @@ void HostApplyLayerNorm( { auto stream = at::cuda::getCurrentCUDAStream().stream(); const dim3 threads(32,4,1); const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); int nshared = threads.y > 1 ? Loading @@ -687,6 +682,7 @@ void HostApplyLayerNorm( gamma,beta); } void cuda_layer_norm( at::Tensor* output, at::Tensor* mean, Loading @@ -704,21 +700,21 @@ void cuda_layer_norm( double epsilon) { using namespace at; DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel", using accscalar_t = at::acc_type<scalar_t_0, true>; using output_t = at::Half; DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", HostApplyLayerNorm( output->DATA_PTR<output_t>(), mean->DATA_PTR<accscalar_t>(), invvar->DATA_PTR<accscalar_t>(), input->DATA_PTR<scalar_t_0>(), output->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(), input->DATA_PTR<scalar_t_in>(), n1,n2, epsilon, gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL, beta != NULL ? beta->DATA_PTR<output_t>() : NULL); gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL, beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL); ) } template<typename T, typename U, typename V> void HostLayerNormGradient( const V* dout, Loading @@ -742,10 +738,12 @@ void HostLayerNormGradient( const int part_size = 16; const dim3 threads2(32,4,1); const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type())); at::Tensor part_grad_gamma = at::empty( {part_size,n2}, input->options().dtype(at::ScalarType::Float)); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>( dout, Loading @@ -770,7 +768,8 @@ void HostLayerNormGradient( } // compute grad_input const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); const dim3 threads1(32,4,1); int nshared = Loading @@ -788,6 +787,7 @@ void HostLayerNormGradient( grad_input); } void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, Loading @@ -808,22 +808,22 @@ void cuda_layer_norm_gradient( at::Tensor* grad_beta) { using namespace at; DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput", using accscalar_t = at::acc_type<scalar_t_0, true>; using output_t = at::Half; DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel", HostLayerNormGradient( dout->DATA_PTR<output_t>(), mean->DATA_PTR<accscalar_t>(), invvar->DATA_PTR<accscalar_t>(), dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(), input, n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL, gamma != NULL ? beta->DATA_PTR<output_t>() : NULL, gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL, gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon, grad_input->DATA_PTR<scalar_t_0>(), gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL, gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL); grad_input->DATA_PTR<scalar_t_in>(), gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL, gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL); ) }
megatron/fused_kernels/scaled_masked_softmax.cpp +9 −6 Original line number Diff line number Diff line Loading @@ -37,8 +37,9 @@ torch::Tensor fwd( torch::Tensor const& mask, float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); return fwd_cuda(input, mask, scale_factor); Loading @@ -52,10 +53,12 @@ torch::Tensor bwd( AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } Loading