Commit d4319513 authored by George Mitenkov's avatar George Mitenkov
Browse files

[MLIR][SPIRVToLLVM] SPIRV function fix and nits

This patch addresses the comments from https://reviews.llvm.org/D83030 and
https://reviews.llvm.org/D82639. `this->` is removed when not inside the
template. Also, type conversion for `spv.func` takes `convertRegionTypes()`
in order to apply type conversion on all blocks within the function.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D83786
parent 1919c8bf
Loading
Loading
Loading
Loading
+11 −7
Original line number Diff line number Diff line
@@ -83,11 +83,12 @@ static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
/// Creates `llvm.mlir.constant` with all bits set for the given type.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
                                      PatternRewriter &rewriter) {
  if (srcType.isa<VectorType>())
  if (srcType.isa<VectorType>()) {
    return rewriter.create<LLVM::ConstantOp>(
        loc, dstType,
        SplatElementsAttr::get(srcType.cast<ShapedType>(),
                               minusOneIntegerAttribute(srcType, rewriter)));
  }
  return rewriter.create<LLVM::ConstantOp>(
      loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
}
@@ -239,7 +240,7 @@ public:
  matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = this->typeConverter.convertType(srcType);
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return failure();
    Location loc = op.getLoc();
@@ -328,7 +329,7 @@ public:
  matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = this->typeConverter.convertType(srcType);
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return failure();
    Location loc = op.getLoc();
@@ -381,7 +382,7 @@ public:
  matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = this->typeConverter.convertType(srcType);
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return failure();
    Location loc = op.getLoc();
@@ -473,7 +474,7 @@ public:
    }

    // Function returns a single result.
    auto dstType = this->typeConverter.convertType(callOp.getType(0));
    auto dstType = typeConverter.convertType(callOp.getType(0));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands,
                                              callOp.getAttrs());
    return success();
@@ -638,7 +639,7 @@ public:
    auto funcType = funcOp.getType();
    TypeConverter::SignatureConversion signatureConverter(
        funcType.getNumInputs());
    auto llvmType = this->typeConverter.convertFunctionSignature(
    auto llvmType = typeConverter.convertFunctionSignature(
        funcOp.getType(), /*isVariadic=*/false, signatureConverter);
    if (!llvmType)
      return failure();
@@ -675,7 +676,10 @@ public:

    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                           &signatureConverter))) {
      return failure();
    }
    rewriter.eraseOp(funcOp);
    return success();
  }