Commit ee2de955 authored by Frank Laub's avatar Frank Laub
Browse files

[MLIR] LLVM dialect: modernize and cleanups

Summary:
Modernize some of the existing custom parsing code in the LLVM dialect.
While this reduces some boilerplate code, it also reduces the precision
of the diagnostic error messges.

Reviewers: ftynse, nicolasvasilache, rriddle

Reviewed By: rriddle

Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72967
parent df7900e2
Loading
Loading
Loading
Loading
+36 −61
Original line number Diff line number Diff line
@@ -55,47 +55,42 @@ template <typename CmpPredicateType>
static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
  Builder &builder = parser.getBuilder();

  Attribute predicate;
  SmallVector<NamedAttribute, 4> attrs;
  StringAttr predicateAttr;
  OpAsmParser::OperandType lhs, rhs;
  Type type;
  llvm::SMLoc predicateLoc, trailingTypeLoc;
  if (parser.getCurrentLocation(&predicateLoc) ||
      parser.parseAttribute(predicate, "predicate", attrs) ||
      parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
      parser.parseOperand(lhs) || parser.parseComma() ||
      parser.parseOperand(rhs) || parser.parseOptionalAttrDict(attrs) ||
      parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
      parser.parseType(type) ||
      parser.parseOperand(rhs) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
      parser.resolveOperand(lhs, type, result.operands) ||
      parser.resolveOperand(rhs, type, result.operands))
    return failure();

  // Replace the string attribute `predicate` with an integer attribute.
  auto predicateStr = predicate.dyn_cast<StringAttr>();
  if (!predicateStr)
    return parser.emitError(predicateLoc,
                            "expected 'predicate' attribute of string type");

  int64_t predicateValue = 0;
  if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
    Optional<ICmpPredicate> predicate =
        symbolizeICmpPredicate(predicateStr.getValue());
        symbolizeICmpPredicate(predicateAttr.getValue());
    if (!predicate)
      return parser.emitError(predicateLoc)
             << "'" << predicateStr.getValue()
             << "'" << predicateAttr.getValue()
             << "' is an incorrect value of the 'predicate' attribute";
    predicateValue = static_cast<int64_t>(predicate.getValue());
  } else {
    Optional<FCmpPredicate> predicate =
        symbolizeFCmpPredicate(predicateStr.getValue());
        symbolizeFCmpPredicate(predicateAttr.getValue());
    if (!predicate)
      return parser.emitError(predicateLoc)
             << "'" << predicateStr.getValue()
             << "'" << predicateAttr.getValue()
             << "' is an incorrect value of the 'predicate' attribute";
    predicateValue = static_cast<int64_t>(predicate.getValue());
  }

  attrs[0].second = parser.getBuilder().getI64IntegerAttr(predicateValue);
  result.attributes[0].second =
      parser.getBuilder().getI64IntegerAttr(predicateValue);

  // The result type is either i1 or a vector type <? x i1> if the inputs are
  // vectors.
@@ -108,7 +103,6 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
    resultType = LLVMType::getVectorTy(
        resultType, argType.getUnderlyingType()->getVectorNumElements());

  result.attributes = attrs;
  result.addTypes({resultType});
  return success();
}
@@ -134,14 +128,13 @@ static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
// <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
//                 `:` type `,` type
static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType arraySize;
  Type type, elemType;
  llvm::SMLoc trailingTypeLoc;
  if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
      parser.parseType(elemType) || parser.parseOptionalAttrDict(attrs) ||
      parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
      parser.parseType(type))
      parser.parseType(elemType) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
    return failure();

  // Extract the result type from the trailing function type.
@@ -155,7 +148,6 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
  if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
    return failure();

  result.attributes = attrs;
  result.addTypes({funcType.getResult(0)});
  return success();
}
@@ -177,14 +169,13 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
//                 attribute-dict? `:` type
static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType base;
  SmallVector<OpAsmParser::OperandType, 8> indices;
  Type type;
  llvm::SMLoc trailingTypeLoc;
  if (parser.parseOperand(base) ||
      parser.parseOperandList(indices, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
    return failure();

@@ -202,7 +193,6 @@ static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
                             parser.getNameLoc(), result.operands))
    return failure();

  result.attributes = attrs;
  result.addTypes(funcType.getResults());
  return success();
}
@@ -233,20 +223,18 @@ static Type getLoadStoreElementType(OpAsmParser &parser, Type type,

// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType addr;
  Type type;
  llvm::SMLoc trailingTypeLoc;

  if (parser.parseOperand(addr) || parser.parseOptionalAttrDict(attrs) ||
      parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
      parser.parseType(type) ||
  if (parser.parseOperand(addr) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
      parser.resolveOperand(addr, type, result.operands))
    return failure();

  Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);

  result.attributes = attrs;
  result.addTypes(elemTy);
  return success();
}
@@ -263,15 +251,14 @@ static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {

// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType addr, value;
  Type type;
  llvm::SMLoc trailingTypeLoc;

  if (parser.parseOperand(value) || parser.parseComma() ||
      parser.parseOperand(addr) || parser.parseOptionalAttrDict(attrs) ||
      parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) ||
      parser.parseType(type))
      parser.parseOperand(addr) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
    return failure();

  Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
@@ -282,7 +269,6 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
      parser.resolveOperand(addr, type, result.operands))
    return failure();

  result.attributes = attrs;
  return success();
}

@@ -316,7 +302,6 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
//                 attribute-dict? `:` function-type
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<NamedAttribute, 4> attrs;
  SmallVector<OpAsmParser::OperandType, 8> operands;
  Type type;
  SymbolRefAttr funcAttr;
@@ -332,11 +317,11 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {

  // Optionally parse a function identifier.
  if (isDirect)
    if (parser.parseAttribute(funcAttr, "callee", attrs))
    if (parser.parseAttribute(funcAttr, "callee", result.attributes))
      return failure();

  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
      parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
    return failure();

@@ -396,7 +381,6 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
    result.addTypes(llvmResultType);
  }

  result.attributes = attrs;
  return success();
}

@@ -461,23 +445,18 @@ static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
// resulting type wrapped in MLIR, or nullptr on error.
static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
                                                       Type containerType,
                                                       Attribute positionAttr,
                                                       ArrayAttr positionAttr,
                                                       llvm::SMLoc attributeLoc,
                                                       llvm::SMLoc typeLoc) {
  auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
  if (!wrappedContainerType)
    return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;

  auto positionArrayAttr = positionAttr.dyn_cast<ArrayAttr>();
  if (!positionArrayAttr)
    return parser.emitError(attributeLoc, "expected an array attribute"),
           nullptr;

  // Infer the element type from the structure type: iteratively step inside the
  // type by taking the element type, indexed by the position attribute for
  // structures.  Check the position index before accessing, it is supposed to
  // be in bounds.
  for (Attribute subAttr : positionArrayAttr) {
  for (Attribute subAttr : positionAttr) {
    auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
    if (!positionElementAttr)
      return parser.emitError(attributeLoc,
@@ -512,16 +491,15 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
//                 attribute-dict? `:` type
static ParseResult parseExtractValueOp(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType container;
  Type containerType;
  Attribute positionAttr;
  ArrayAttr positionAttr;
  llvm::SMLoc attributeLoc, trailingTypeLoc;

  if (parser.parseOperand(container) ||
      parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(positionAttr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
      parser.parseAttribute(positionAttr, "position", result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) ||
      parser.parseType(containerType) ||
      parser.resolveOperand(container, containerType, result.operands))
@@ -532,7 +510,6 @@ static ParseResult parseExtractValueOp(OpAsmParser &parser,
  if (!elementType)
    return failure();

  result.attributes = attrs;
  result.addTypes(elementType);
  return success();
}
@@ -599,7 +576,7 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType container, value;
  Type containerType;
  Attribute positionAttr;
  ArrayAttr positionAttr;
  llvm::SMLoc attributeLoc, trailingTypeLoc;

  if (parser.parseOperand(value) || parser.parseComma() ||
@@ -1080,15 +1057,15 @@ static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
                                        OperationState &result) {
  llvm::SMLoc loc;
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType v1, v2;
  Attribute maskAttr;
  ArrayAttr maskAttr;
  Type typeV1, typeV2;
  if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
      parser.parseComma() || parser.parseOperand(v2) ||
      parser.parseAttribute(maskAttr, "mask", attrs) ||
      parser.parseOptionalAttrDict(attrs) || parser.parseColonType(typeV1) ||
      parser.parseComma() || parser.parseType(typeV2) ||
      parser.parseAttribute(maskAttr, "mask", result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(typeV1) || parser.parseComma() ||
      parser.parseType(typeV2) ||
      parser.resolveOperand(v1, typeV1, result.operands) ||
      parser.resolveOperand(v2, typeV2, result.operands))
    return failure();
@@ -1097,10 +1074,8 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
      !wrappedContainerType1.getUnderlyingType()->isVectorTy())
    return parser.emitError(
        loc, "expected LLVM IR dialect vector type for operand #1");
  auto vType =
      LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(),
                            maskAttr.cast<ArrayAttr>().size());
  result.attributes = attrs;
  auto vType = LLVMType::getVectorTy(
      wrappedContainerType1.getVectorElementType(), maskAttr.size());
  result.addTypes(vType);
  return success();
}
+3 −3
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ func @invalid_noalias(%arg0: !llvm.i32 {llvm.noalias = 3}) {
// -----

func @icmp_non_string(%arg0 : !llvm.i32, %arg1 : !llvm<"i16">) {
  // expected-error@+1 {{expected 'predicate' attribute of string type}}
  // expected-error@+1 {{invalid kind of attribute specified}}
  llvm.icmp 42 %arg0, %arg0 : !llvm.i32
  return
}
@@ -156,7 +156,7 @@ func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
func @insertvalue_non_array_position() {
  // Note the double-type, otherwise attribute parsing consumes the trailing
  // type of the op as the (wrong) attribute type.
  // expected-error@+1 {{expected an array attribute}}
  // expected-error@+1 {{invalid kind of attribute specified}}
  llvm.insertvalue %a, %b 0 : i32 : !llvm<"{i32}">
}

@@ -200,7 +200,7 @@ func @extractvalue_non_llvm_type(%a : i32, %b : i32) {
func @extractvalue_non_array_position() {
  // Note the double-type, otherwise attribute parsing consumes the trailing
  // type of the op as the (wrong) attribute type.
  // expected-error@+1 {{expected an array attribute}}
  // expected-error@+1 {{invalid kind of attribute specified}}
  llvm.extractvalue %b 0 : i32 : !llvm<"{i32}">
}