Unverified Commit 3c07a216 authored by Finn Plummer's avatar Finn Plummer Committed by GitHub
Browse files

[mlir][index][spirv] Add conversion for index to spirv (#68085)

Due to an issue when lowering from scf to spirv as there was no
conversion pass for index to spirv, we are motivated to add a conversion
pass from the Index dialect to the SPIR-V dialect. Furthermore, we add
the new conversion patterns to the scf-to-spirv conversion.

Fixes #63713
parent 1bc42666
Loading
Loading
Loading
Loading
+30 −0
Original line number Diff line number Diff line
//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
class RewritePatternSet;
class SPIRVTypeConverter;
class Pass;

#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"

namespace index {
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
                                  RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
} // namespace index
} // namespace mlir

#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+1 −0
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+22 −0
Original line number Diff line number Diff line
@@ -644,6 +644,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
  ];
}

//===----------------------------------------------------------------------===//
// ConvertIndexToSPIRVPass
//===----------------------------------------------------------------------===//

def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
  let summary = "Lower the `index` dialect to the `spirv` dialect.";
  let description = [{
    This pass lowers Index dialect operations to SPIR-V dialect operations.
    Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
    `ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as
    specified by use-64bit-index.
  }];

  let dependentDialects = ["::mlir::spirv::SPIRVDialect"];

  let options = [
    Option<"use64bitIndex", "use-64bit-index",
           "bool", /*default=*/"false",
           "Use 64-bit integers to convert index types">
  ];
}

//===----------------------------------------------------------------------===//
// LinalgToStandard
//===----------------------------------------------------------------------===//
+8 −3
Original line number Diff line number Diff line
@@ -55,13 +55,13 @@ struct SPIRVConversionOptions {
  /// values will be packed into one 32-bit value to be memory efficient.
  bool emulateLT32BitScalarTypes{true};

  /// Use 64-bit integers to convert index types.
  bool use64bitIndex{false};

  /// Whether to enable fast math mode during conversion. If true, various
  /// patterns would assume no NaN/infinity numbers as inputs, and thus there
  /// will be no special guards emitted to check and handle such cases.
  bool enableFastMathMode{false};

  /// Use 64-bit integers when converting index types.
  bool use64bitIndex{false};
};

/// Type conversion from builtin types to SPIR-V types for shader interface.
@@ -77,6 +77,11 @@ public:
  /// Gets the SPIR-V correspondence for the standard index type.
  Type getIndexType() const;

  /// Gets the bitwidth of the index type when converted to SPIR-V.
  unsigned getIndexTypeBitwidth() const {
    return options.use64bitIndex ? 64 : 32;
  }

  const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }

  /// Returns the options controlling the SPIR-V type converter.
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ add_subdirectory(GPUToROCDL)
add_subdirectory(GPUToSPIRV)
add_subdirectory(GPUToVulkan)
add_subdirectory(IndexToLLVM)
add_subdirectory(IndexToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
Loading