Commit 143edeca authored by Rob Suderman's avatar Rob Suderman
Browse files

[mlir][tosa] Shape inference for a few remaining easy cases:

Handles shape inference for identity, cast, and rescale. These were missed
during the initialy elementwise work. This includes resize shape propagation
which includes both attribute and input type based propagation.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105845
parent 2d9759c7
Loading
Loading
Loading
Loading
+13 −4
Original line number Diff line number Diff line
@@ -1582,7 +1582,10 @@ def Tosa_ScatterOp : Tosa_Op<"scatter", [
//===----------------------------------------------------------------------===//
// Operator: resize
//===----------------------------------------------------------------------===//
def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> {
def Tosa_ResizeOp : Tosa_Op<"resize", [
      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                              ["inferReturnTypeComponents"]>,
      NoSideEffect]> {

  let summary = "Resize operation, supports various resize/upsample modes";

@@ -1617,7 +1620,9 @@ def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: cast
//===----------------------------------------------------------------------===//
def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect]> {
def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect,
      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                              ["inferReturnTypeComponents"]>]> {

  let summary = "Cast operation";

@@ -1655,7 +1660,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: rescale
//===----------------------------------------------------------------------===//
def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> {
def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect, 
      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                              ["inferReturnTypeComponents"]>]> {
  let summary = "Tosa rescale operator";

  let description = [{
@@ -1723,7 +1730,9 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, NoSideEffect,
//===----------------------------------------------------------------------===//
// Operator: identity
//===----------------------------------------------------------------------===//
def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect]> {
def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect,
      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                              ["inferReturnTypeComponents"]>]> {
  let summary = "Identity operator";
  let description = [{
    Returns a tensor with the same shape, size, type
+104 −23
Original line number Diff line number Diff line
@@ -345,6 +345,12 @@ static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
  }
}

static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<FloatAttr>().getValueAsDouble());
  }
}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -386,13 +392,13 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(

    // Copy the Operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandTy.getRank(), -1);
      outputShape.resize(operandTy.getRank(), ShapedType::kDynamicSize);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandTy.getRank(); i < s; i++) {
      if (i == axis || operandTy.isDynamicDim(i))
        continue;
      if (outputShape[i] == -1)
      if (outputShape[i] == ShapedType::kDynamicSize)
        outputShape[i] = operandTy.getDimSize(i);
      if (outputShape[i] != operandTy.getDimSize(i))
        return failure();
@@ -414,7 +420,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandTy.hasRank() || operandTy.isDynamicDim(axis)) {
      concatDimSize = -1;
      concatDimSize = ShapedType::kDynamicSize;
      break;
    }

@@ -437,7 +443,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, -1);
  outShape.resize(2, ShapedType::kDynamicSize);

  if (inputTy.hasRank()) {
    outShape[0] = inputTy.getDimSize(0);
@@ -448,7 +454,8 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
  }

  if (biasTy.hasRank()) {
    outShape[1] = outShape[1] == -1 ? biasTy.getDimSize(0) : outShape[1];
    outShape[1] = outShape[1] == ShapedType::kDynamicSize ? biasTy.getDimSize(0)
                                                          : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -464,7 +471,7 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, -1);
  outShape.resize(3, ShapedType::kDynamicSize);

  if (lhsTy.hasRank()) {
    outShape[0] = lhsTy.getDimSize(0);
@@ -472,7 +479,8 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
  }

  if (rhsTy.hasRank()) {
    outShape[0] = outShape[0] == -1 ? rhsTy.getDimSize(0) : outShape[0];
    outShape[0] = outShape[0] == ShapedType::kDynamicSize ? rhsTy.getDimSize(0)
                                                          : outShape[0];
    outShape[2] = rhsTy.getDimSize(2);
  }

@@ -503,7 +511,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
      return success();
    }

    outputShape.resize(paddingTy.getDimSize(0), -1);
    outputShape.resize(paddingTy.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }
@@ -511,7 +519,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(operands[1], m_Constant(&paddings))) {
    outputShape.resize(inputTy.getRank(), -1);
    outputShape.resize(inputTy.getRank(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }
@@ -524,7 +532,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
  outputShape.reserve(inputTy.getRank());
  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
    if (inputTy.isDynamicDim(i)) {
      outputShape.push_back(-1);
      outputShape.push_back(ShapedType::kDynamicSize);
      continue;
    }

@@ -574,7 +582,7 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  SmallVector<int64_t> outputShape;
  if (!inputTy.hasRank()) {
    outputShape.resize(multiples.size(), -1);
    outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }
@@ -590,7 +598,7 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
    int dim = inputTy.getDimSize(i);
    if (dim != -1)
    if (dim != ShapedType::kDynamicSize)
      dim *= multipleValues[i];
    outputShape.push_back(dim);
  }
@@ -622,14 +630,14 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
  int64_t numElements = type.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (val != -1) {
    if (val != ShapedType::kDynamicSize) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (val == -1)
    if (val == ShapedType::kDynamicSize)
      val = numElements / staticMul;
  }

@@ -655,7 +663,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
  // can determine the output rank.
  SmallVector<int64_t> outputShape;
  if (!inputTy.hasRank()) {
    outputShape.resize(permsTy.getDimSize(0), -1);
    outputShape.resize(permsTy.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }
@@ -684,7 +692,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
  }

  DenseIntElementsAttr perms;
  outputShape.resize(inputTy.getRank(), -1);
  outputShape.resize(inputTy.getRank(), ShapedType::kDynamicSize);
  // If the permuations are a constant we can directly determine the output
  // shape.
  if (matchPattern(operands[1], m_Constant(&perms))) {
@@ -708,7 +716,7 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, -1);
  outputShape.resize(3, ShapedType::kDynamicSize);

  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ty.getDimSize(0);
@@ -716,9 +724,9 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
  }

  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = ty.getDimSize(0);
    if (outputShape[1] == -1)
    if (outputShape[1] == ShapedType::kDynamicSize)
      outputShape[1] = ty.getDimSize(1);
  }

@@ -726,12 +734,82 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
  return success();
}

LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamicSize);

  int32_t inHeight = ShapedType::kDynamicSize;
  int32_t inWidth = ShapedType::kDynamicSize;

  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ty.getDimSize(0);
    outputShape[3] = ty.getDimSize(3);

    inHeight = ty.getDimSize(1);
    inWidth = ty.getDimSize(2);
  }

  int32_t shift =
      attributes.get("shift").cast<IntegerAttr>().getValue().getSExtValue();
  llvm::SmallVector<int64_t> newShape;
  getI64Values(attributes.get("output_size").cast<ArrayAttr>(), newShape);
  outputShape[1] = newShape[0];
  outputShape[2] = newShape[1];

  llvm::SmallVector<int64_t> strideInt;
  llvm::SmallVector<int64_t> offsetInt;
  llvm::SmallVector<double> strideFp;
  llvm::SmallVector<double> offsetFp;
  getI64Values(attributes.get("offset").cast<ArrayAttr>(), offsetInt);
  getF64Values(attributes.get("offset_fp").cast<ArrayAttr>(), offsetFp);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), strideInt);
  getF64Values(attributes.get("stride_fp").cast<ArrayAttr>(), strideFp);

  // If we have a 0 zero in integers we know that the resize indexing needs to
  // be performed in floating point. Use the floating point varient to compute
  // the resize shape.
  bool fpMode = strideInt[0] == 0;

  // We can compute the output shape if attribute specifies unknown dimensions
  // based on the offset and stride. If we perfectly line up to the last index
  // we need to round up the size to include it.
  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
    float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[1] = std::ceil(sizeFp) + round;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
    float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[2] = std::ceil(sizeFp) + round;
  }

  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
    int64_t size = (inHeight - 1);
    size = ((size << shift) - offsetInt[0]) / strideInt[0];
    outputShape[1] = size + 1;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
    int64_t size = (inWidth - 1);
    size = ((size << shift) - offsetInt[1]) / strideInt[1];
    outputShape[2] = size + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, -1);
  outputShape.resize(3, ShapedType::kDynamicSize);

  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ty.getDimSize(0);
@@ -740,14 +818,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
  }

  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = ty.getDimSize(0);
  }

  if (auto ty = operands[2].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = ty.getDimSize(0);
    if (outputShape[2] == -1)
    if (outputShape[2] == ShapedType::kDynamicSize)
      outputShape[2] = ty.getDimSize(2);
  }

@@ -859,6 +937,7 @@ NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
@@ -868,6 +947,7 @@ NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
@@ -882,6 +962,7 @@ NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReluNOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
+71 −0
Original line number Diff line number Diff line
@@ -65,6 +65,9 @@ func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {

  // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
  %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>

  // CHECK: "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<4xi32>
  %13 = "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<*xi32>
  return
}

@@ -92,6 +95,12 @@ func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {

  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32>
  %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor<?xi32>

  // CHECK: "tosa.rescale"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi16>
  %7 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32], shift = [14 : i32, 15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<4xi32>)  -> (tensor<*xi16>)

  // CHECK: "tosa.identity"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
  %8 = "tosa.identity"(%arg0) : (tensor<4xi32>) -> tensor<?xi32>
  return
}

@@ -971,3 +980,65 @@ func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1
  return
}

// -----

// CHECK-LABEL: @resize_output_size
func @resize_output_size(%arg0: tensor<2x?x?x3xi32>) {
  // CHECK: -> tensor<2x4x5x3xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 1], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [4, 5], shift = 8 : i32, stride = [1, 1], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<2x?x?x3xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @resize_int_horizontal
func @resize_int_horizontal(%arg0: tensor<1x2x4x1xi32>) {
  // CHECK: -> tensor<1x2x7x1xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [256, 128], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @resize_int_vertical
func @resize_int_vertical(%arg0: tensor<1x2x4x1xi32>) {
  // CHECK: -> tensor<1x3x4x1xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [128, 256], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @resize_int_offsetted
func @resize_int_offsetted(%arg0: tensor<1x2x4x1xi32>) {
  // CHECK: -> tensor<1x4x6x1xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [64, 64], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [64, 128], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @resize_fp_horizontal
func @resize_fp_horizontal(%arg0: tensor<1x2x4x1xi32>) {
  // CHECK: -> tensor<1x2x7x1xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [1.000000e+00 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @resize_fp_vertical
func @resize_fp_vertical(%arg0: tensor<1x2x4x1xi32>) {
  // CHECK: -> tensor<1x3x4x1xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @resize_fp_offsetted
func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) {
  // CHECK: -> tensor<1x4x6x1xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}