Commit a4803d8a authored by Adrian Kuegel's avatar Adrian Kuegel
Browse files

[mlir][Tosa] Fix Clamp verifier to handle quantized types.

parent 7c7896b1
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -312,10 +312,18 @@ LogicalResult tosa::AvgPool2dOp::verify() {
LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type maxFpType = getMaxFpAttr().getType();
  mlir::Type minFpType = getMinFpAttr().getType();
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy != outputETy)
+7 −0
Original line number Diff line number Diff line
@@ -152,6 +152,13 @@ func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
  return %0 : tensor<13x21x3xbf16>
}

// -----
// CHECK-LABEL: clamp_quantized
func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>> {
  %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
  return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
}

// -----
// CHECK-LABEL: sigmoid
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {