Unverified Commit 3651f377 authored by Sarthak Gupta's avatar Sarthak Gupta Committed by GitHub
Browse files

[mlir][tosa] Check for unranked tensors during validation (#68509)

Fixes https://github.com/llvm/llvm-project/issues/67760
`levelCheckRank` ensures that the tensors for tosa operations are not
unranked

During tosa validation in `levelCheckRank`, we were trying to get the
rank of a tensor without checking if it is ranked or unranked, which
leads to an `assert` error. I see two ways to fix this:

- Only check `type.getRank() > tosa_level.MAX_RANK` if the tensor is
ranked, and then proceed as usual.
(like `if (type.hasRank() && type.getRank() > tosa_level.MAX_RANK)` , OR
- Throw an error for unranked tensors as result.
parent 5b189d6f
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -156,6 +156,10 @@ private:
  bool levelCheckRank(Operation *op, const Value &v,
                      const std::string &checkDesc) {
    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
      if (!type.hasRank()) {
        op->emitOpError() << "failed level check: unranked tensor";
        return false;
      }
      if (type.getRank() > tosaLevel.MAX_RANK) {
        op->emitOpError() << "failed level check: " << checkDesc;
        return false;
+9 −0
Original line number Diff line number Diff line
@@ -695,4 +695,13 @@ func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x
  return %0 : tensor<1x1x1x1x1x1x10xi32>
}

// -----

// CHECK-LABEL: unranked_tensor
func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
  // expected-error@+1 {{'tosa.slice' op failed level check: unranked tensor}}
  %0 = "tosa.slice"(%arg0) {start = array<i64>, size = array<i64>} :
          (tensor<*xf32>) -> tensor<*xf32>
  return
}