Commit d242aa24 authored by Shraiysh Vaishay's avatar Shraiysh Vaishay Committed by Alex Zinenko
Browse files

[MLIR] Added llvm.invoke and llvm.landingpad

Summary:
I have tried to implement `llvm.invoke` and `llvm.landingpad`.

  # `llvm.invoke` is similar to `llvm.call` with two successors added, the first one is the normal label and the second one is unwind label.
  # `llvm.launchpad` takes a variable number of args with either `catch` or `filter` associated with them. Catch clauses are not array types and filter clauses are array types. This is same as the criteria used by LLVM (https://github.com/llvm/llvm-project/blob/4f82af81a04d711721300f6ca32f402f2ea6faf4/llvm/include/llvm/IR/Instructions.h#L2866

)

Examples:
LLVM IR
```
define i32 @caller(i32 %a) personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
    invoke i32 @foo(i32 2) to label %success unwind label %fail

  success:
    ret i32 2

  fail:
    landingpad {i8*, i32} catch i8** @_ZTIi catch i8** null catch i8* bitcast (i8** @_ZTIi to i8*) filter [1 x i8] [ i8 1 ]
    ret i32 3
}
```
MLIR LLVM Dialect
```
llvm.func @caller(%arg0: !llvm.i32) -> !llvm.i32 {
  %0 = llvm.mlir.constant(3 : i32) : !llvm.i32
  %1 = llvm.mlir.constant("\01") : !llvm<"[1 x i8]">
  %2 = llvm.mlir.addressof @_ZTIi : !llvm<"i8**">
  %3 = llvm.bitcast %2 : !llvm<"i8**"> to !llvm<"i8*">
  %4 = llvm.mlir.null : !llvm<"i8**">
  %5 = llvm.mlir.addressof @_ZTIi : !llvm<"i8**">
  %6 = llvm.mlir.constant(2 : i32) : !llvm.i32
  %7 = llvm.invoke @foo(%6) to ^bb1 unwind ^bb2 : (!llvm.i32) -> !llvm.i32
^bb1:	// pred: ^bb0
  llvm.return %6 : !llvm.i32
^bb2:	// pred: ^bb0
  %8 = llvm.landingpad (catch %5 : !llvm<"i8**">) (catch %4 : !llvm<"i8**">) (catch %3 : !llvm<"i8*">) (filter %1 : !llvm<"[1 x i8]">) : !llvm<"{ i8*, i32 }">
  llvm.return %0 : !llvm.i32
}
```

Signed-off-by: default avatarShraiysh Vaishay <cs17btech11050@iith.ac.in>

Differential Revision: https://reviews.llvm.org/D72006
parent 06e12893
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
@@ -315,6 +315,41 @@ def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">;
def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;

// Call-related operations.
def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>,
                    Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                               Variadic<LLVM_Type>)>,
                    Results<(outs Variadic<LLVM_Type>)> {
  let builders = [OpBuilder<
    "Builder *b, OperationState &result, ArrayRef<Type> tys, "
    "FlatSymbolRefAttr callee, ValueRange ops, Block* normal, "
    "ValueRange normalOps, Block* unwind, ValueRange unwindOps",
    [{
      result.addAttribute("callee", callee);
      build(b, result, tys, ops, normal, normalOps, unwind, unwindOps);
    }]>,
    OpBuilder<
    "Builder *b, OperationState &result, ArrayRef<Type> tys, "
    "ValueRange ops, Block* normal, "
    "ValueRange normalOps, Block* unwind, ValueRange unwindOps",
    [{
      result.addTypes(tys);
      result.addOperands(ops);
      result.addSuccessor(normal, normalOps);
      result.addSuccessor(unwind, unwindOps);
    }]>];
  let verifier = [{ return ::verify(*this);  }];
  let parser = [{ return parseInvokeOp(parser, result); }];
  let printer = [{ printInvokeOp(p, *this); }];
}

def LLVM_LandingpadOp : LLVM_OneResultOp<"landingpad">,
                        Arguments<(ins UnitAttr:$cleanup, 
                                       Variadic<LLVM_Type>)> {
  let verifier = [{ return ::verify(*this); }];
  let parser = [{ return parseLandingpadOp(parser, result); }];
  let printer = [{ printLandingpadOp(p, *this); }];
}

def LLVM_CallOp : LLVM_Op<"call">,
                  Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                             Variadic<LLVM_Type>)>,
+225 −0
Original line number Diff line number Diff line
@@ -272,6 +272,231 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
  return success();
}

///===----------------------------------------------------------------------===//
/// Verifying/Printing/Parsing for LLVM::InvokeOp.
///===----------------------------------------------------------------------===//

static LogicalResult verify(InvokeOp op) {
  if (op.getNumResults() > 1)
    return op.emitOpError("must have 0 or 1 result");
  if (op.getNumSuccessors() != 2)
    return op.emitOpError("must have normal and unwind destinations");

  if (op.getSuccessor(1)->empty())
    return op.emitError(
        "must have at least one operation in unwind destination");

  // In unwind destination, first operation must be LandingpadOp
  if (!isa<LandingpadOp>(op.getSuccessor(1)->front()))
    return op.emitError("first operation in unwind destination should be a "
                        "llvm.landingpad operation");

  return success();
}

static void printInvokeOp(OpAsmPrinter &p, InvokeOp &op) {
  auto callee = op.callee();
  bool isDirect = callee.hasValue();

  p << op.getOperationName() << ' ';

  // Either function name or pointer
  if (isDirect)
    p.printSymbolName(callee.getValue());
  else
    p << op.getOperand(0);

  p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
  p << " to ";
  p.printSuccessorAndUseList(op.getOperation(), 0);
  p << " unwind ";
  p.printSuccessorAndUseList(op.getOperation(), 1);

  p.printOptionalAttrDict(op.getAttrs(), {"callee"});

  SmallVector<Type, 8> argTypes(
      llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));

  p << " : "
    << FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
}

/// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
///                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
///                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
///                  attribute-dict? `:` function-type
static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> operands;
  FunctionType funcType;
  SymbolRefAttr funcAttr;
  llvm::SMLoc trailingTypeLoc;
  Block *normalDest, *unwindDest;
  SmallVector<Value, 4> normalOperands, unwindOperands;

  // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
  // case of an indirect call, there will be 1 operand before `(`.  In case of a
  // direct call, there will be no operands and the parser will stop at the
  // function identifier without complaining.
  if (parser.parseOperandList(operands))
    return failure();
  bool isDirect = operands.empty();

  // Optionally parse a function identifier.
  if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
    return failure();

  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
      parser.parseKeyword("unwind") ||
      parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
    return failure();

  if (isDirect) {
    // Make sure types match.
    if (parser.resolveOperands(operands, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return failure();
    result.addTypes(funcType.getResults());
  } else {
    // Construct the LLVM IR Dialect function type that the first operand
    // should match.
    if (funcType.getNumResults() > 1)
      return parser.emitError(trailingTypeLoc,
                              "expected function with 0 or 1 result");

    Builder &builder = parser.getBuilder();
    auto *llvmDialect =
        builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    LLVM::LLVMType llvmResultType;
    if (funcType.getNumResults() == 0) {
      llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
    } else {
      llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
      if (!llvmResultType)
        return parser.emitError(trailingTypeLoc,
                                "expected result to have LLVM type");
    }

    SmallVector<LLVM::LLVMType, 8> argTypes;
    argTypes.reserve(funcType.getNumInputs());
    for (Type ty : funcType.getInputs()) {
      if (auto argType = ty.dyn_cast<LLVM::LLVMType>())
        argTypes.push_back(argType);
      else
        return parser.emitError(trailingTypeLoc,
                                "expected LLVM types as inputs");
    }

    auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
                                                      /*isVarArg=*/false);
    auto wrappedFuncType = llvmFuncType.getPointerTo();

    auto funcArguments = llvm::makeArrayRef(operands).drop_front();

    // Make sure that the first operand (indirect callee) matches the wrapped
    // LLVM IR function type, and that the types of the other call operands
    // match the types of the function arguments.
    if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
        parser.resolveOperands(funcArguments, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return failure();

    result.addTypes(llvmResultType);
  }
  result.addSuccessor(normalDest, normalOperands);
  result.addSuccessor(unwindDest, unwindOperands);
  return success();
}

///===----------------------------------------------------------------------===//
/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
///===----------------------------------------------------------------------===//

static LogicalResult verify(LandingpadOp op) {
  Value value;

  if (!op.cleanup() && op.getOperands().empty())
    return op.emitError("landingpad instruction expects at least one clause or "
                        "cleanup attribute");

  for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
    value = op.getOperand(idx);
    bool isFilter = value.getType().cast<LLVMType>().isArrayTy();
    if (isFilter) {
      // FIXME: Verify filter clauses when arrays are appropriately handled
    } else {
      // catch - global addresses only.
      // Bitcast ops should have global addresses as their args.
      if (auto bcOp = dyn_cast_or_null<BitcastOp>(value.getDefiningOp())) {
        if (auto addrOp =
                dyn_cast_or_null<AddressOfOp>(bcOp.arg().getDefiningOp()))
          continue;
        return op.emitError("constant clauses expected")
                   .attachNote(bcOp.getLoc())
               << "global addresses expected as operand to "
                  "bitcast used in clauses for landingpad";
      }
      // NullOp and AddressOfOp allowed
      if (dyn_cast_or_null<NullOp>(value.getDefiningOp()))
        continue;
      if (dyn_cast_or_null<AddressOfOp>(value.getDefiningOp()))
        continue;
      return op.emitError("clause #")
             << idx << " is not a known constant - null, addressof, bitcast";
    }
  }
  return success();
}

static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
  p << op.getOperationName() << (op.cleanup() ? " cleanup " : " ");

  // Clauses
  for (auto value : op.getOperands()) {
    // Similar to llvm - if clause is an array type then it is filter
    // clause else catch clause
    bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy();
    p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
      << value.getType() << ") ";
  }

  p.printOptionalAttrDict(op.getAttrs(), {"cleanup"});

  p << ": " << op.getType();
}

/// <operation> ::= `llvm.landingpad` `cleanup`?
///                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
static ParseResult parseLandingpadOp(OpAsmParser &parser,
                                     OperationState &result) {
  // Check for cleanup
  if (succeeded(parser.parseOptionalKeyword("cleanup")))
    result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());

  // Parse clauses with types
  while (succeeded(parser.parseOptionalLParen()) &&
         (succeeded(parser.parseOptionalKeyword("filter")) ||
          succeeded(parser.parseOptionalKeyword("catch")))) {
    OpAsmParser::OperandType operand;
    Type ty;
    if (parser.parseOperand(operand) || parser.parseColon() ||
        parser.parseType(ty) ||
        parser.resolveOperand(operand, ty, result.operands) ||
        parser.parseRParen())
      return failure();
  }

  Type type;
  if (parser.parseColon() || parser.parseType(type))
    return failure();

  result.addTypes(type);
  return success();
}

//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::CallOp.
//===----------------------------------------------------------------------===//
+53 −15
Original line number Diff line number Diff line
@@ -76,7 +76,7 @@ private:
  /// `br` branches to `target`. Append the block arguments to attach to the
  /// generated branch op to `blockArguments`. These should be in the same order
  /// as the PHIs in `target`.
  LogicalResult processBranchArgs(llvm::BranchInst *br,
  LogicalResult processBranchArgs(llvm::Instruction *br,
                                  llvm::BasicBlock *target,
                                  SmallVectorImpl<Value> &blockArguments);
  /// Returns the standard type equivalent to be used in attributes for the
@@ -422,21 +422,26 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
}

Value Importer::processConstant(llvm::Constant *c) {
  OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
  if (Attribute attr = getConstantAsAttr(c)) {
    // These constants can be represented as attributes.
    OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
    LLVMType type = processType(c->getType());
    if (!type)
      return nullptr;
    return instMap[c] = b.create<ConstantOp>(unknownLoc, type, attr);
    return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
  }
  if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
    OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
    LLVMType type = processType(cn->getType());
    if (!type)
      return nullptr;
    return instMap[c] = b.create<NullOp>(unknownLoc, type);
    return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
  }
  if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
    return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
                                      processGlobal(GV),
                                      ArrayRef<NamedAttribute>());

  if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
    llvm::Instruction *i = ce->getAsInstruction();
    OpBuilder::InsertionGuard guard(b);
@@ -471,16 +476,6 @@ Value Importer::processValue(llvm::Value *value) {
    return unknownInstMap[value]->getResult(0);
  }

  if (auto *GV = dyn_cast<llvm::GlobalVariable>(value)) {
    auto global = processGlobal(GV);
    if (!global)
      return nullptr;
    return b.create<AddressOfOp>(UnknownLoc::get(context), global,
                                 ArrayRef<NamedAttribute>());
  }

  // Note, constant global variables are both GlobalVariables and Constants,
  // so we handle GlobalVariables first above.
  if (auto *c = dyn_cast<llvm::Constant>(value))
    return processConstant(c);

@@ -570,7 +565,7 @@ static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
// `br` branches to `target`. Return the branch arguments to `br`, in the
// same order of the PHIs in `target`.
LogicalResult
Importer::processBranchArgs(llvm::BranchInst *br, llvm::BasicBlock *target,
Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
                            SmallVectorImpl<Value> &blockArguments) {
  for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
    auto *PN = cast<llvm::PHINode>(&*inst);
@@ -719,6 +714,49 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
      v = op->getResult(0);
    return success();
  }
  case llvm::Instruction::LandingPad: {
    llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
    SmallVector<Value, 4> ops;

    for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
      ops.push_back(processConstant(lpi->getClause(i)));

    b.create<LandingpadOp>(loc, processType(lpi->getType()), lpi->isCleanup(),
                           ops);
    return success();
  }
  case llvm::Instruction::Invoke: {
    llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);

    SmallVector<Type, 2> tys;
    if (!ii->getType()->isVoidTy())
      tys.push_back(processType(inst->getType()));

    SmallVector<Value, 4> ops;
    ops.reserve(inst->getNumOperands() + 1);
    for (auto &op : ii->arg_operands())
      ops.push_back(processValue(op.get()));

    SmallVector<Value, 4> normalArgs, unwindArgs;
    processBranchArgs(ii, ii->getNormalDest(), normalArgs);
    processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);

    Operation *op;
    if (llvm::Function *callee = ii->getCalledFunction()) {
      op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
                              ops, blocks[ii->getNormalDest()], normalArgs,
                              blocks[ii->getUnwindDest()], unwindArgs);
    } else {
      ops.insert(ops.begin(), processValue(ii->getCalledValue()));
      op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
                              normalArgs, blocks[ii->getUnwindDest()],
                              unwindArgs);
    }

    if (!ii->getType()->isVoidTy())
      v = op->getResult(0);
    return success();
  }
  case llvm::Instruction::GetElementPtr: {
    // FIXME: Support inbounds GEPs.
    llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
+28 −0
Original line number Diff line number Diff line
@@ -307,6 +307,34 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
    return success(result->getType()->isVoidTy());
  }

  if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
    auto operands = lookupValues(opInst.getOperands());
    ArrayRef<llvm::Value *> operandsRef(operands);
    if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee"))
      builder.CreateInvoke(functionMapping.lookup(attr.getValue()),
                           blockMapping[invOp.getSuccessor(0)],
                           blockMapping[invOp.getSuccessor(1)], operandsRef);
    else
      builder.CreateInvoke(
          operandsRef.front(), blockMapping[invOp.getSuccessor(0)],
          blockMapping[invOp.getSuccessor(1)], operandsRef.drop_front());
    return success();
  }

  if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
    llvm::Type *ty = lpOp.getType().dyn_cast<LLVMType>().getUnderlyingType();
    llvm::LandingPadInst *lpi =
        builder.CreateLandingPad(ty, lpOp.getNumOperands());

    // Add clauses
    for (auto operand : lookupValues(lpOp.getOperands())) {
      // All operands should be constant - checked by verifier
      if (auto constOperand = dyn_cast<llvm::Constant>(operand))
        lpi->addClause(constOperand);
    }
    return success();
  }

  // Emit branches.  We need to look up the remapped blocks and ignore the block
  // arguments that were transformed into PHI nodes.
  if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
+52 −0
Original line number Diff line number Diff line
@@ -509,3 +509,55 @@ func @cmpxchg_failure_acq_rel(%i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) {
  %0 = llvm.cmpxchg %i32_ptr, %i32, %i32 acq_rel acq_rel : !llvm.i32
  llvm.return
}

// -----

llvm.func @foo(!llvm.i32) -> !llvm.i32
llvm.func @__gxx_personality_v0(...) -> !llvm.i32

llvm.func @bad_landingpad(%arg0: !llvm<"i8**">) {
  %0 = llvm.mlir.constant(3 : i32) : !llvm.i32
  %1 = llvm.mlir.constant(2 : i32) : !llvm.i32
  %2 = llvm.invoke @foo(%1) to ^bb1 unwind ^bb2 : (!llvm.i32) -> !llvm.i32
^bb1:  // pred: ^bb0
  llvm.return %1 : !llvm.i32
^bb2:  // pred: ^bb0
  // expected-error@+1 {{clause #0 is not a known constant - null, addressof, bitcast}}
  %3 = llvm.landingpad cleanup (catch %1 : !llvm.i32) (catch %arg0 : !llvm<"i8**">) : !llvm<"{ i8*, i32 }">
  llvm.return %0 : !llvm.i32
}

// -----

llvm.func @foo(!llvm.i32) -> !llvm.i32
llvm.func @__gxx_personality_v0(...) -> !llvm.i32

llvm.func @caller(%arg0: !llvm.i32) -> !llvm.i32 {
  %0 = llvm.mlir.constant(1 : i32) : !llvm.i32
  %1 = llvm.alloca %0 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
  // expected-note@+1 {{global addresses expected as operand to bitcast used in clauses for landingpad}}
  %2 = llvm.bitcast %1 : !llvm<"i8**"> to !llvm<"i8*">
  %3 = llvm.invoke @foo(%0) to ^bb1 unwind ^bb2 : (!llvm.i32) -> !llvm.i32
^bb1: // pred: ^bb0
  llvm.return %0 : !llvm.i32
^bb2: // pred: ^bb0
  // expected-error@+1 {{constant clauses expected}}
  %5 = llvm.landingpad (catch %2 : !llvm<"i8*">) : !llvm<"{ i8*, i32 }">
  llvm.return %0 : !llvm.i32
}

// -----

llvm.func @foo(!llvm.i32) -> !llvm.i32
llvm.func @__gxx_personality_v0(...) -> !llvm.i32

llvm.func @caller(%arg0: !llvm.i32) -> !llvm.i32 {
  %0 = llvm.mlir.constant(1 : i32) : !llvm.i32
  %1 = llvm.invoke @foo(%0) to ^bb1 unwind ^bb2 : (!llvm.i32) -> !llvm.i32
^bb1: // pred: ^bb0
  llvm.return %0 : !llvm.i32
^bb2: // pred: ^bb0
  // expected-error@+1 {{landingpad instruction expects at least one clause or cleanup attribute}}
  %2 = llvm.landingpad : !llvm<"{ i8*, i32 }">
  llvm.return %0 : !llvm.i32
}
Loading