Commit 8662a2f2 authored by Rob Suderman's avatar Rob Suderman
Browse files

[mlir][tosa] Relax ranked constraint on quantization builder

TosaOp definition had an artificial constraint that the input/output types
needed to be ranked to invoke the quantization builder. This is incorrect, as an
unranked tensor could still be quantized.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D109863
parent e03c7e36
Loading
Loading
Loading
Loading
+5 −7
Original line number Diff line number Diff line
@@ -350,8 +350,8 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = a.getType().dyn_cast<RankedTensorType>();
    assert(inputType && "Input must be a ranked tensor type!");
    auto inputType = a.getType().dyn_cast<ShapedType>();
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = inputType.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();
@@ -359,17 +359,15 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
    assert(outputShapedType && "Output must be a ranked tensor type");

    auto outputShape = outputShapedType.getShape();
    auto outputShapedType = outputType.dyn_cast<ShapedType>();
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = RankedTensorType::get(outputShape, accElementType);
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
+11 −13
Original line number Diff line number Diff line
@@ -102,8 +102,8 @@ ConvOpQuantizationAttr
mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
                                        Value weight) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto weightType = weight.getType().dyn_cast<RankedTensorType>();
  auto inputType = input.getType().dyn_cast<ShapedType>();
  auto weightType = weight.getType().dyn_cast<ShapedType>();

  if (!inputType || !weightType)
    return nullptr;
@@ -151,8 +151,8 @@ MatMulOpQuantizationAttr
mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
                                          Value b) {

  auto aType = a.getType().dyn_cast<RankedTensorType>();
  auto bType = b.getType().dyn_cast<RankedTensorType>();
  auto aType = a.getType().dyn_cast<ShapedType>();
  auto bType = b.getType().dyn_cast<ShapedType>();

  if (!aType || !bType)
    return nullptr;
@@ -187,8 +187,8 @@ UnaryOpQuantizationAttr
mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
                                         Type outputRawType) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto outputType = outputRawType.dyn_cast<RankedTensorType>();
  auto inputType = input.getType().dyn_cast<ShapedType>();
  auto outputType = outputRawType.dyn_cast<ShapedType>();

  if (!inputType || !outputType)
    return nullptr;
@@ -220,7 +220,7 @@ mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
                                                             Value input) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto inputType = input.getType().dyn_cast<ShapedType>();

  if (!inputType)
    return nullptr;
@@ -245,8 +245,8 @@ PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
                                           Value input, Value weight) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto weightType = weight.getType().dyn_cast<RankedTensorType>();
  auto inputType = input.getType().dyn_cast<ShapedType>();
  auto weightType = weight.getType().dyn_cast<ShapedType>();

  assert(inputType && weightType &&
         "Could not extract input or weight tensors from Conv op");
@@ -260,18 +260,16 @@ Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
  unsigned weightBits = weightQType.getStorageTypeIntegralWidth();

  auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
  auto outputShapedType = outputType.dyn_cast<ShapedType>();
  assert(outputShapedType &&
         "Could not extract output shape type from Conv op");

  auto outputShape = outputShapedType.getShape();

  IntegerType accElementType;
  if (inputBits == 16 && weightBits == 8)
    accElementType = builder.getIntegerType(48);
  else
    accElementType = builder.getI32Type();
  auto accType = RankedTensorType::get(outputShape, accElementType);
  auto accType = outputShapedType.clone(accElementType);
  return accType;
}