Commit fd11cda2 authored by Pierre Oechsel's avatar Pierre Oechsel Committed by Alex Zinenko
Browse files

[mlir] StdToLLVM: Add error when the sourceMemRef of a subview is not a llvm type.

A memref_cast casting to a memref with a non identity map can't be
lowered to llvm. Take the following case:

```

func @invalid_memref_cast(%arg0: memref<?x?xf64>) {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %5 = memref_cast %arg0 : memref<?x?xf64> to memref<?x?xf64, #map1>
  %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
  return
}
```

When lowering the subview mlir was assuming `%5` to have an llvm type
(which is not the case as mlir failed to lower the memref_cast).

Differential Revision: https://reviews.llvm.org/D74466
parent a19de320
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -385,7 +385,8 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
/*============================================================================*/
StructBuilder::StructBuilder(Value v) : value(v) {
  assert(value != nullptr && "value cannot be null");
  structType = value.getType().cast<LLVM::LLVMType>();
  structType = value.getType().dyn_cast<LLVM::LLVMType>();
  assert(structType && "expected llvm type");
}

Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
@@ -2303,6 +2304,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
      return matchFailure();

    // Create the descriptor.
    if (!operands.front().getType().isa<LLVM::LLVMType>())
      return matchFailure();
    MemRefDescriptor sourceMemRef(operands.front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

+13 −0
Original line number Diff line number Diff line
// RUN: mlir-opt %s -verify-diagnostics -split-input-file

#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>

func @invalid_memref_cast(%arg0: memref<?x?xf64>) {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  // expected-error@+1: 'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values,
  %5 = memref_cast %arg0 : memref<?x?xf64> to memref<?x?xf64, #map1>
  %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
  return
}