Commit b04885a5 authored by Joonsoo Jeon's avatar Joonsoo Jeon Committed by Lei Zhang
Browse files

[mlir][ods] Added RankedIntElementsAttr class

Defines a tablegen class RankedIntElementsAttr. This is an integer
version of RankedFloatElementsAttr.

Differential Revision: https://reviews.llvm.org/D73764
parent 9ce6dc98
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -1040,6 +1040,26 @@ class IntElementsAttr<int width> : ElementsAttrBase<
def I32ElementsAttr : IntElementsAttr<32>;
def I64ElementsAttr : IntElementsAttr<64>;

// A `width`-bit integer elements attribute. The attribute should be ranked and
// has a shape as specified in `dims`.
class RankedIntElementsAttr<int width, list<int> dims> : IntElementsAttr<width> {
  // Check that this has the specified shape.
  let predicate = And<[
    IntElementsAttr<width>.predicate,
    CPred<"$_self.cast<DenseIntElementsAttr>().getType().getShape() == "
        "ArrayRef<int64_t>({" # StrJoinInt<dims>.result # "})">]>;

  let description = width # "-bit int elements attribute of shape [" #
                    StrJoinInt<dims>.result # "]";

  let constBuilderCall = "DenseIntElementsAttr::get("
    "RankedTensorType::get({" # StrJoinInt<dims>.result #
    "}, $_builder.getIntegerType(" # width # ")), makeArrayRef($0))";
}

class RankedI32ElementsAttr<list<int> dims> : RankedIntElementsAttr<32, dims>;
class RankedI64ElementsAttr<list<int> dims> : RankedIntElementsAttr<64, dims>;

class FloatElementsAttr<int width> : ElementsAttrBase<
  CPred<"$_self.isa<DenseFPElementsAttr>() &&"
      "$_self.cast<DenseElementsAttr>().getType()."
+50 −0
Original line number Diff line number Diff line
@@ -243,3 +243,53 @@ func @fn() { return }

// expected-error @+1 {{referencing to a 'FuncOp' symbol}}
"test.symbol_ref_attr"() {symbol = @foo} : () -> ()

// -----

//===----------------------------------------------------------------------===//
// Test IntElementsAttr
//===----------------------------------------------------------------------===//

func @correct_type_pass() {
  "test.int_elements_attr"() {
    // CHECK: matrix_i64_attr = dense<6> : tensor<4x8xi64>
    // CHECK: vector_i32_attr = dense<5> : tensor<2xi32>
    matrix_i64_attr = dense<6> : tensor<4x8xi64>,
    vector_i32_attr = dense<5> : tensor<2xi32>
  } : () -> ()
  return
}

// -----

func @wrong_element_type_fail() {
  // expected-error @+1 {{failed to satisfy constraint: 32-bit int elements attribute of shape [2]}}
  "test.int_elements_attr"() {
    matrix_i64_attr = dense<6> : tensor<4x8xi64>,
    vector_i32_attr = dense<5> : tensor<2xi64>
  } : () -> ()
  return
}

// -----

func @wrong_shape_fail() {
  // expected-error @+1 {{failed to satisfy constraint: 64-bit int elements attribute of shape [4, 8]}}
  "test.int_elements_attr"() {
    matrix_i64_attr = dense<6> : tensor<4xi64>,
    vector_i32_attr = dense<5> : tensor<2xi32>
  } : () -> ()
  return
}

// -----

func @wrong_shape_fail() {
  // expected-error @+1 {{failed to satisfy constraint: 32-bit int elements attribute of shape [2]}}
  "test.int_elements_attr"() {
    matrix_i64_attr = dense<6> : tensor<4x8xi64>,
    vector_i32_attr = dense<5> : tensor<i32>
  } : () -> ()
  return
}
+7 −0
Original line number Diff line number Diff line
@@ -204,6 +204,13 @@ def UpdateFloatElementsAttr : Pat<
    ConstantAttr<RankedF32ElementsAttr<[2]>, "{5.0f, 6.0f}">:$f32attr,
    $f64attr)>;

def IntElementsAttrOp : TEST_Op<"int_elements_attr"> {
  let arguments = (ins
      RankedI32ElementsAttr<[2]>:$vector_i32_attr,
      RankedI64ElementsAttr<[4, 8]>:$matrix_i64_attr
  );
}

//===----------------------------------------------------------------------===//
// Test Attribute Constraints
//===----------------------------------------------------------------------===//