Unverified Commit 783b4d91 authored by Sarthak Gupta's avatar Sarthak Gupta Committed by GitHub
Browse files

[mlir][tosa] Check for 0-ranked-tensors during fold (#68512)

Fixes https://github.com/llvm/llvm-project/issues/67761
Trying `getDimSize()` before checking for 0-ranked-tensors throws assert
errors. This PR ensures that it is checked for.
Or should we throw an error if we have a 0-ranked-tensor in a tosa
operation?
parent a4803d8a
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -771,7 +771,7 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy.getDimSize(getAxis()) == 1)                                    \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }
@@ -874,7 +874,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
+1 −1
Original line number Diff line number Diff line
@@ -1109,7 +1109,7 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  if (!operandShape.hasRank()) {
  if (!operandShape.hasRank() || operandShape.getRank() == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }
+12 −0
Original line number Diff line number Diff line
@@ -591,3 +591,15 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  %1 = tosa.abs %0 : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %1 : tensor<?x1xf32>
}

// -----

// CHECK-LABEL: @fold_reduce_rank_zero
func.func nested @fold_reduce_rank_zero() {
  // CHECK-NOT: tosa.reduce_min
  // CHECK-NOT: tosa.reverse
  %0 = tensor.empty() : tensor<i32>
  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
  return
}