Commit 28fe1a4e authored by Tai Ly's avatar Tai Ly Committed by Eric Kunze
Browse files

[mlir] Add trait SameOperandsAndResultRank



This adds a native op trait SameOperandsAndResultRank
and associated verifier that checks that an operator's
operands and result types have same ranks if their ranks
are known.

Signed-off-by: default avatarTai Ly <tai.ly@arm.com>
Change-Id: I2d536f77be10f3710d0c8d84c907ff492a984fda

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D156369
parent 1eab92bd
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -341,6 +341,7 @@ LogicalResult verifySameOperandsAndResultShape(Operation *op);
LogicalResult verifySameOperandsElementType(Operation *op);
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
LogicalResult verifySameOperandsAndResultRank(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
@@ -1114,6 +1115,17 @@ public:
  }
};

/// This class verifies that op has same ranks for all
/// operands and results types, if known.
template <typename ConcreteType>
class SameOperandsAndResultRank
    : public TraitBase<ConcreteType, SameOperandsAndResultRank> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsAndResultRank(op);
  }
};

/// This class verifies that any results of the specified op have a boolean
/// type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
+3 −0
Original line number Diff line number Diff line
@@ -369,4 +369,7 @@ def ReifyRankedShapedTypeOpInterface :
// TODO: Change from hard coded to utilizing type inference trait.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;

// Op has the same ranks for all operands and results types, if known.
def SameOperandsAndResultRank : NativeOpTrait<"SameOperandsAndResultRank">;

#endif // MLIR_INFERTYPEOPINTERFACE
+45 −0
Original line number Diff line number Diff line
@@ -1082,6 +1082,51 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
  return success();
}

LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)))
    return failure();

  // delegate function that returns true if type is a shaped type with known
  // rank
  auto hasRank = [](const Type type) {
    if (auto shaped_type = dyn_cast<ShapedType>(type))
      return shaped_type.hasRank();

    return false;
  };

  auto rankedOperandTypes =
      llvm::make_filter_range(op->getOperandTypes(), hasRank);
  auto rankedResultTypes =
      llvm::make_filter_range(op->getResultTypes(), hasRank);

  // If all operands and results are unranked, then no further verification.
  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
    return success();

  // delegate function that returns rank of shaped type with known rank
  auto getRank = [](const Type type) {
    return type.cast<ShapedType>().getRank();
  };

  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
                                          : getRank(*rankedResultTypes.begin());

  for (const auto type : rankedOperandTypes) {
    if (rank != getRank(type)) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }

  for (const auto type : rankedResultTypes) {
    if (rank != getRank(type)) {
      return op->emitOpError("result type has different rank than operands");
    }
  }

  return success();
}

LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
  Block *block = op->getBlock();
  // Verify that the operation is at the end of the respective parent block.
+6 −0
Original line number Diff line number Diff line
@@ -692,6 +692,12 @@ def OperandZeroAndResultHaveSameRank :
  let results = (outs AnyShaped:$res);
}

def OperandsAndResultHaveSameRank :
    TEST_Op<"operands_and_result_have_same_rank", [SameOperandsAndResultRank]> {
  let arguments = (ins AnyShaped:$x, AnyShaped:$y);
  let results = (outs AnyShaped:$res);
}

def OperandZeroAndResultHaveSameShape :
    TEST_Op<"operand0_and_result_have_same_shape",
            [AllShapesMatch<["x", "res"]>]> {
+27 −0
Original line number Diff line number Diff line
@@ -377,6 +377,33 @@ func.func @same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) {

// -----

// CHECK-LABEL: same_rank_if_known_success
func.func @same_rank_if_known_success(%t1xi : tensor<1xi32>, %t2xf : tensor<2xf32>, %m3xi : memref<3xi32>, %t1x2xf : tensor<1x2xf32>, %tuxi : tensor<*xi32>) {
  %0 = "test.operands_and_result_have_same_rank"(%t1xi, %t2xf) : (tensor<1xi32>, tensor<2xf32>) -> (tensor<3xf64>)
  %1 = "test.operands_and_result_have_same_rank"(%t1xi, %m3xi) : (tensor<1xi32>, memref<3xi32>) -> (tensor<3xi64>)
  %3 = "test.operands_and_result_have_same_rank"(%tuxi, %t2xf) : (tensor<*xi32>, tensor<2xf32>) -> (tensor<2xf32>)
  %4 = "test.operands_and_result_have_same_rank"(%t1x2xf, %tuxi) : (tensor<1x2xf32>, tensor<*xi32>) -> (tensor<1x2xf64>)
  return
}

// -----

func.func @same_rank_if_known_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
  // expected-error@+1 {{operands don't have matching ranks}}
  %0 = "test.operands_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<*xf32>)
  return
}

// -----

func.func @same_rank_if_known_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
  // expected-error@+1 {{result type has different rank than operands}}
  %0 = "test.operands_and_result_have_same_rank"(%arg1, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2x3xf32>)
  return
}

// -----

// CHECK-LABEL: same_shape_success
func.func @same_shape_success(%t2x3: tensor<2x3xi32>, %m2x3: memref<2x3xf32>, %v2x3 : vector<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
  "test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<2x3xf32>)