Commit de5a81b1 authored by River Riddle's avatar River Riddle
Browse files

[mlir] Update several usages of IntegerType to properly handled unsignedness.

Summary: For example, DenseElementsAttr currently does not properly round-trip unsigned integer values.

Differential Revision: https://reviews.llvm.org/D75374
parent 4167645d
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -93,9 +93,8 @@ struct constant_int_op_binder {
      return false;
    auto type = op->getResult(0).getType();

    if (type.isSignlessIntOrIndex()) {
    if (type.isa<IntegerType>() || type.isa<IndexType>())
      return attr_value_binder<IntegerAttr>(bind_value).match(attr);
    }
    if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
      if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
        return attr_value_binder<IntegerAttr>(bind_value)
+24 −0
Original line number Diff line number Diff line
@@ -339,6 +339,30 @@ def I16 : I<16>;
def I32 : I<32>;
def I64 : I<64>;

// Unsigned integer types.
// Any unsigned integer type irrespective of its width.
def AnyUnsignedInteger : Type<
  CPred<"$_self.isUnsignedInteger()">, "unsigned integer">;

// Unsigned integer type of a specific width.
class UI<int width>
    : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
                  width # "-bit unsigned integer">,
      BuildableType<"$_builder.getIntegerType(" # width #
                    ", /*isSigned=*/false)"> {
  int bitwidth = width;
}

class UnsignedIntOfWidths<list<int> widths> :
    AnyTypeOf<!foreach(w, widths, UI<w>),
              StrJoinInt<widths, "/">.result # "-bit unsigned integer">;

def UI1  : UI<1>;
def UI8  : UI<8>;
def UI16 : UI<16>;
def UI32 : UI<32>;
def UI64 : UI<64>;

// Floating point types.

// Any float type irrespective of its width.
+3 −2
Original line number Diff line number Diff line
@@ -328,8 +328,9 @@ public:
    // Note: Non standard/builtin types are allowed to exist within tensor
    // types. Dialects are expected to verify that tensor types have a valid
    // element type within that dialect.
    return type.isSignlessIntOrFloat() || type.isa<ComplexType>() ||
           type.isa<VectorType>() || type.isa<OpaqueType>() ||
    return type.isa<ComplexType>() || type.isa<FloatType>() ||
           type.isa<IntegerType>() || type.isa<OpaqueType>() ||
           type.isa<VectorType>() ||
           (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
  }

+3 −0
Original line number Diff line number Diff line
@@ -169,6 +169,9 @@ public:
  /// Return true of this is a signless integer or a float type.
  bool isSignlessIntOrFloat();

  /// Return true of this is an integer(of any signedness) or a float type.
  bool isIntOrFloat();

  /// Print the current type.
  void print(raw_ostream &os);
  void dump();
+2 −2
Original line number Diff line number Diff line
@@ -314,7 +314,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isSignlessIntOrFloat()) {
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
@@ -358,7 +358,7 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return None;
  auto elementType = memRefType.getElementType();
  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>())
  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
    return None;

  uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
Loading