Commit 6fe77b10 authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][Linalg] Fail comprehensive bufferization if a memref is returned.

Summary:

Reviewers:

Subscribers:

Differential revision: https://reviews.llvm.org/D109824
parent 446e11fa
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -37,7 +37,10 @@ def LinalgComprehensiveModuleBufferize :
  let options = [
    Option<"testAnalysisOnly", "test-analysis-only", "bool",
            /*default=*/"false",
           "Only runs inplaceability analysis (for testing purposes only)">
           "Only runs inplaceability analysis (for testing purposes only)">,
    Option<"allowReturnMemref", "allow-return-memref", "bool",
            /*default=*/"false",
           "Allows the return of memrefs (for testing purposes only)">
  ];
  let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
}
+8 −0
Original line number Diff line number Diff line
@@ -2914,6 +2914,14 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
      signalPassFailure();
      return;
    }
    if (!allowReturnMemref &&
        llvm::any_of(funcOp.getType().getResults(), [](Type t) {
          return t.isa<MemRefType, UnrankedMemRefType>();
        })) {
      funcOp->emitError("memref return type is unsupported");
      signalPassFailure();
      return;
    }
  }

  // Perform a post-processing pass of layout modification at function boundary
+10 −0
Original line number Diff line number Diff line
@@ -130,3 +130,13 @@ func @unknown_op(%A : tensor<4xf32>) -> tensor<4xf32>
  %r = "marklar"(%A) : (tensor<4xf32>) -> (tensor<4xf32>)
  return %r: tensor<4xf32>
}

// -----

// expected-error @+1 {{memref return type is unsupported}}
func @mini_test_case1() -> tensor<10x20xf32> {
  %f0 = constant 0.0 : f32
  %t = linalg.init_tensor [10, 20] : tensor<10x20xf32>
  %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32>
  return %r : tensor<10x20xf32>
}
+1 −1
Original line number Diff line number Diff line
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=allow-return-memref -split-input-file | FileCheck %s

// CHECK-LABEL: func @transfer_read(%{{.*}}: memref<?xf32, #map>) -> vector<4xf32> {
func @transfer_read(%A : tensor<?xf32>) -> (vector<4xf32>) {