Commit 138df298 authored by Markus Böck's avatar Markus Böck
Browse files

[mlir] Revamp `RegionBranchOpInterface` successor mechanism

The `RegionBranchOpInterface` had a few fundamental issues caused by the API design of `getSuccessorRegions`.

It always required passing values for the `operands` parameter. This is problematic as the operands parameter actually changes meaning depending on which predecessor `index` is referring to. If coming from a region, you'd have to find a `RegionBranchTerminatorOpInterface` in that region, get its operand count, and then create a `SmallVector` of that size.
This is not only inconvenient, but also error-prone, which has lead to a bug in the implementation of a previously existing `getSuccessorRegions` overload.

Additionally, this made the method dual-use, trying to serve two different use-cases: 1) Trying to determine possible control flow edges between regions and 2) Trying to determine the region being branched to based on constant operands.

This patch fixes these issues by changing the interface methods and adding new ones:
* The `operands` argument of `getSuccessorRegions` has been removed. The method is now only responsible for returning possible control flow edges between regions.
* An optional `getEntrySuccessorRegions` method has been added. This is used to determine which regions are branched to from the parent op based on constant operands of the parent op. By default, it calls `getSuccessorRegions`. This is analogous to `getSuccessorForOperands` from `BranchOpInterface`.
* Add `getSuccessorRegions` to `RegionBranchTerminatorOpInterface`. This is used to get the possible successors of the terminator based on constant operands. By default, it calls the containing `RegionBranchOpInterface`s `getSuccessorRegions` method.
* `getSuccessorEntryOperands` was renamed to `getEntrySuccessorOperands` for consistency.

Differential Revision: https://reviews.llvm.org/D157506
parent f5ccac71
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -2147,7 +2147,7 @@ def fir_DoLoopOp : region_Op<"do_loop",
}

def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
    "getRegionInvocationBounds"]>, RecursiveMemoryEffects,
    "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
    NoRegionArguments]> {
  let summary = "if-then-else conditional operation";
  let description = [{
+25 −21
Original line number Diff line number Diff line
@@ -3461,15 +3461,13 @@ void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
  }
}

// These 2 functions copied from scf.if implementation.
// These 3 functions copied from scf.if implementation.

/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
/// during the flow of control.
void fir::IfOp::getSuccessorRegions(
    std::optional<unsigned> index, llvm::ArrayRef<mlir::Attribute> operands,
    std::optional<unsigned> index,
    llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
  // The `then` and the `else` region branch back to the parent operation.
  if (index) {
@@ -3477,27 +3475,33 @@ void fir::IfOp::getSuccessorRegions(
    return;
  }

  // Don't consider the else region if it is empty.
  regions.push_back(mlir::RegionSuccessor(&getThenRegion()));

  // Don't consider the else region if it is empty.
  mlir::Region *elseRegion = &this->getElseRegion();
  if (elseRegion->empty())
    elseRegion = nullptr;

  // Otherwise, the successor is dependent on the condition.
  bool condition;
  if (auto condAttr = operands.front().dyn_cast_or_null<mlir::IntegerAttr>()) {
    condition = condAttr.getValue().isOne();
  } else {
    // If the condition isn't constant, both regions may be executed.
    regions.push_back(mlir::RegionSuccessor(&getThenRegion()));
    // If the else region does not exist, it is not a viable successor.
    if (elseRegion)
    regions.push_back(mlir::RegionSuccessor());
  else
    regions.push_back(mlir::RegionSuccessor(elseRegion));
    return;
}

  // Add the successor regions using the condition.
  regions.push_back(
      mlir::RegionSuccessor(condition ? &getThenRegion() : elseRegion));
void fir::IfOp::getEntrySuccessorRegions(
    llvm::ArrayRef<mlir::Attribute> operands,
    llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
  FoldAdaptor adaptor(operands);
  auto boolAttr =
      mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition());
  if (!boolAttr || boolAttr.getValue())
    regions.emplace_back(&getThenRegion());

  // If the else region is empty, execution continues after the parent op.
  if (!boolAttr || !boolAttr.getValue()) {
    if (!getElseRegion().empty())
      regions.emplace_back(&getElseRegion());
    else
      regions.emplace_back(getResults());
  }
}

void fir::IfOp::getRegionInvocationBounds(
+3 −1
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ class CallOpInterface;
class CallableOpInterface;
class BranchOpInterface;
class RegionBranchOpInterface;
class RegionBranchTerminatorOpInterface;

namespace dataflow {

@@ -207,7 +208,8 @@ private:
  /// Visit the given terminator operation that exits a region under an
  /// operation with control-flow semantics. These are terminators with no CFG
  /// successors.
  void visitRegionTerminator(Operation *op, RegionBranchOpInterface branch);
  void visitRegionTerminator(RegionBranchTerminatorOpInterface op,
                             RegionBranchOpInterface branch);

  /// Visit the given terminator operation that exits a callable region. These
  /// are terminators with no CFG successors.
+1 −1
Original line number Diff line number Diff line
@@ -123,7 +123,7 @@ def AffineForOp : Affine_Op<"for",
     ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
      "getSingleUpperBound"]>,
     DeclareOpInterfaceMethods<RegionBranchOpInterface,
     ["getSuccessorEntryOperands"]>]> {
     ["getEntrySuccessorOperands"]>]> {
  let summary = "for operation";
  let description = [{
    Syntax:
+2 −4
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ class Async_Op<string mnemonic, list<Trait> traits = []> :
def Async_ExecuteOp :
  Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
                       DeclareOpInterfaceMethods<RegionBranchOpInterface,
                                                 ["getSuccessorEntryOperands",
                                                 ["getEntrySuccessorOperands",
                                                  "areTypesCompatible"]>,
                       AttrSizedOperandSegments,
                       AutomaticAllocationScope]> {
@@ -312,8 +312,7 @@ def Async_ReturnOp : Async_Op<"return",

def Async_YieldOp :
    Async_Op<"yield", [
      HasParent<"ExecuteOp">, Pure, Terminator,
      DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
      HasParent<"ExecuteOp">, Pure, Terminator, ReturnLike]> {
  let summary = "terminator for Async execute operation";
  let description = [{
    The `async.yield` is a special terminator operation for the block inside
@@ -322,7 +321,6 @@ def Async_YieldOp :

  let arguments = (ins Variadic<AnyType>:$operands);
  let assemblyFormat = "($operands^ `:` type($operands))? attr-dict";
  let hasVerifier = 1;
}

def Async_AwaitOp : Async_Op<"await"> {
Loading