Commit 47c6ab2b authored by Lei Zhang's avatar Lei Zhang
Browse files

[mlir][spirv] Properly support SPIR-V conversion target

This commit defines a new SPIR-V dialect attribute for specifying
a SPIR-V target environment. It is a dictionary attribute containing
the SPIR-V version, supported extension list, and allowed capability
list. A SPIRVConversionTarget subclass is created to take in the
target environment and sets proper dynmaically legal ops by querying
the op availability interface of SPIR-V ops to make sure they are
available in the specified target environment. All existing conversions
targeting SPIR-V is changed to use this SPIRVConversionTarget. It
probes whether the input IR has a `spv.target_env` attribute,
otherwise, it uses the default target environment: SPIR-V 1.0 with
Shader capability and no extra extensions.

Differential Revision: https://reviews.llvm.org/D72256
parent 60d39479
Loading
Loading
Loading
Loading
+36 −8
Original line number Diff line number Diff line
@@ -725,6 +725,28 @@ func @foo() -> () {
}
```

## Target environment

SPIR-V aims to support multiple execution environments as specified by client
APIs. These execution environments affect the availability of certain SPIR-V
features. For example, a [Vulkan 1.1][VulkanSpirv] implementation must support
the 1.0, 1.1, 1.2, and 1.3 versions of SPIR-V and the 1.0 version of the SPIR-V
extended instructions for GLSL. Further Vulkan extensions may enable more SPIR-V
instructions.

SPIR-V compilation should also take into consideration of the execution
environment, so we generate SPIR-V modules valid for the target environment.
This is conveyed by the `spv.target_env` attribute. It is a triple of

*   `version`: a 32-bit integer indicating the target SPIR-V version.
*   `extensions`: a string array attribute containing allowed extensions.
*   `capabilities`: a 32-bit integer array attribute containing allowed
    capabilities.

Dialect conversion framework will utilize the information in `spv.target_env`
to properly filter out patterns and ops not available in the target execution
environment.

## Shader interface (ABI)

SPIR-V itself is just expressing computation happening on GPU device. SPIR-V
@@ -852,12 +874,18 @@ classes are provided.
additional rules are imposed by [Vulkan execution environment][VulkanSpirv]. The
lowering described below implements both these requirements.)

### `SPIRVConversionTarget`

The `mlir::spirv::SPIRVConversionTarget` class derives from the
`mlir::ConversionTarget` class and serves as a utility to define a conversion
target satisfying a given [`spv.target_env`](#target-environment). It registers
proper hooks to check the dynamic legality of SPIR-V ops. Users can further
register other legality constraints into the returned `SPIRVConversionTarget`.

### SPIRVTypeConverter
### `SPIRVTypeConverter`

The `mlir::spirv::SPIRVTypeConverter` derives from
`mlir::TypeConverter` and provides type conversion for standard
types to SPIR-V types:
The `mlir::SPIRVTypeConverter` derives from `mlir::TypeConverter` and provides
type conversion for standard types to SPIR-V types:

*   [Standard Integer][MlirIntegerType] -> Standard Integer
*   [Standard Float][MlirFloatType] -> Standard Float
@@ -874,11 +902,11 @@ supported in SPIR-V. Currently the `index` type is converted to `i32`.
(TODO: Allow for configuring the integer width to use for `index` types in the
SPIR-V dialect)

### SPIRVOpLowering
### `SPIRVOpLowering`

`mlir::spirv::SPIRVOpLowering` is a base class that can be used to define the
patterns used for implementing the lowering. For now this only provides derived
classes access to an instance of `mlir::spirv::SPIRVTypeLowering` class.
`mlir::SPIRVOpLowering` is a base class that can be used to define the patterns
used for implementing the lowering. For now this only provides derived classes
access to an instance of `mlir::SPIRVTypeLowering` class.

### Utility functions for lowering

+26 −1
Original line number Diff line number Diff line
@@ -13,8 +13,10 @@
#ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H

#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {

@@ -48,7 +50,30 @@ protected:
};

namespace spirv {
enum class BuiltIn : uint32_t;
class SPIRVConversionTarget : public ConversionTarget {
public:
  /// Creates a SPIR-V conversion target for the given target environment.
  static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetEnv,
                                                    MLIRContext *context);

private:
  SPIRVConversionTarget(TargetEnvAttr targetEnv, MLIRContext *context);

  // Be explicit that instance of this class cannot be copied or moved: there
  // are lambdas capturing fields of the instance.
  SPIRVConversionTarget(const SPIRVConversionTarget &) = delete;
  SPIRVConversionTarget(SPIRVConversionTarget &&) = delete;
  SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
  SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;

  /// Returns true if the given `op` is legal to use under the current target
  /// environment.
  bool isLegalOp(Operation *op);

  Version givenVersion;                            /// SPIR-V version to target
  llvm::SmallSet<Extension, 4> givenExtensions;    /// Allowed extensions
  llvm::SmallSet<Capability, 8> givenCapabilities; /// Allowed capabilities
};

/// Returns a value that represents a builtin variable value within the SPIR-V
/// module.
+16 −4
Original line number Diff line number Diff line
@@ -27,21 +27,33 @@ class Value;
namespace spirv {
enum class StorageClass : uint32_t;

/// Attribute name for specifying argument ABI information.
/// Returns the attribute name for specifying argument ABI information.
StringRef getInterfaceVarABIAttrName();

/// Get the InterfaceVarABIAttr given its fields.
/// Gets the InterfaceVarABIAttr given its fields.
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet,
                                           unsigned binding,
                                           StorageClass storageClass,
                                           MLIRContext *context);

/// Attribute name for specifying entry point information.
/// Returns the attribute name for specifying entry point information.
StringRef getEntryPointABIAttrName();

/// Get the EntryPointABIAttr given its fields.
/// Gets the EntryPointABIAttr given its fields.
EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
                                       MLIRContext *context);

/// Returns the attribute name for specifying SPIR-V target environment.
StringRef getTargetEnvAttrName();

/// Returns the default target environment: SPIR-V 1.0 with Shader capability
/// and no extra extensions.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);

/// Queries the target environment from the given `op` or returns the default
/// target environment (SPIR-V 1.0 with Shader capability and no extra
/// extensions) if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
} // namespace spirv
} // namespace mlir

+35 −22
Original line number Diff line number Diff line
//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
//===- TargetAndABI.td - SPIR-V Target and ABI definitions -*- tablegen -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,21 +6,20 @@
//
//===----------------------------------------------------------------------===//
//
// This is the base file for supporting lowering to SPIR-V dialect. This
// file defines SPIR-V attributes used for specifying the shader
// interface or ABI. This is because SPIR-V module is expected to work in
// an execution environment as specified by a client API. A SPIR-V module
// needs to "link" correctly with the execution environment regarding the
// resources that are used in the SPIR-V module and get populated with
// data via the client API. The shader interface (or ABI) is passed into
// SPIR-V lowering path via attributes defined in this file. A
// compilation flow targeting SPIR-V is expected to attach such
// This is the base file for supporting lowering to SPIR-V dialect. This file
// defines SPIR-V attributes used for specifying the shader interface or ABI.
// This is because SPIR-V module is expected to work in an execution environment
// as specified by a client API. A SPIR-V module needs to "link" correctly with
// the execution environment regarding the resources that are used in the SPIR-V
// module and get populated with data via the client API. The shader interface
// (or ABI) is passed into SPIR-V lowering path via attributes defined in this
// file. A compilation flow targeting SPIR-V is expected to attach such
// attributes to resources and other suitable places.
//
//===----------------------------------------------------------------------===//

#ifndef SPIRV_LOWERING
#define SPIRV_LOWERING
#ifndef SPIRV_TARGET_AND_ABI
#define SPIRV_TARGET_AND_ABI

include "mlir/Dialect/SPIRV/SPIRVBase.td"

@@ -30,17 +29,31 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td"
// 1) Descriptor Set.
// 2) Binding number.
// 3) Storage class.
def SPV_InterfaceVarABIAttr:
    StructAttr<"InterfaceVarABIAttr", SPV_Dialect,
               [StructFieldAttr<"descriptor_set", I32Attr>,
def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
    StructFieldAttr<"descriptor_set", I32Attr>,
    StructFieldAttr<"binding", I32Attr>,
                StructFieldAttr<"storage_class", SPV_StorageClassAttr>]>;
    StructFieldAttr<"storage_class", SPV_StorageClassAttr>
]>;

// For entry functions, this attribute specifies information related to entry
// points in the generated SPIR-V module:
// 1) WorkGroup Size.
def SPV_EntryPointABIAttr:
    StructAttr<"EntryPointABIAttr", SPV_Dialect,
               [StructFieldAttr<"local_size", I32ElementsAttr>]>;
def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPV_Dialect, [
    StructFieldAttr<"local_size", I32ElementsAttr>
]>;

#endif // SPIRV_LOWERING
def SPV_ExtensionArrayAttr : TypedArrayAttrBase<
    SPV_ExtensionAttr, "SPIR-V extension array attribute">;

def SPV_CapabilityArrayAttr : TypedArrayAttrBase<
    SPV_CapabilityAttr, "SPIR-V capability array attribute">;

// For the generated SPIR-V module, this attribute specifies the target version,
// allowed extensions and capabilities.
def SPV_TargetEnvAttr : StructAttr<"TargetEnvAttr", SPV_Dialect, [
    StructFieldAttr<"version", SPV_VersionAttr>,
    StructFieldAttr<"extensions", SPV_ExtensionArrayAttr>,
    StructFieldAttr<"capabilities", SPV_CapabilityArrayAttr>
]>;

#endif // SPIRV_TARGET_AND_ABI
+6 −6
Original line number Diff line number Diff line
@@ -55,8 +55,8 @@ private:
} // namespace

void GPUToSPIRVPass::runOnModule() {
  auto context = &getContext();
  auto module = getModule();
  MLIRContext *context = &getContext();
  ModuleOp module = getModule();

  SmallVector<Operation *, 1> kernelModules;
  OpBuilder builder(context);
@@ -73,12 +73,12 @@ void GPUToSPIRVPass::runOnModule() {
  populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
  populateStandardToSPIRVPatterns(context, typeConverter, patterns);

  ConversionTarget target(*context);
  target.addLegalDialect<spirv::SPIRVDialect>();
  target.addDynamicallyLegalOp<FuncOp>(
  std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
      spirv::lookupTargetEnvOrDefault(module), context);
  target->addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });

  if (failed(applyFullConversion(kernelModules, target, patterns,
  if (failed(applyFullConversion(kernelModules, *target, patterns,
                                 &typeConverter))) {
    return signalPassFailure();
  }
Loading