Commit 7c74a250 authored by Matthias Springer's avatar Matthias Springer
Browse files

[mlir][SCF][NFC] Add helper functions to get body of scf.while

Add two new helper functions `getBeforeBody` and `getAfterBody` to be consistent with "scf.for" (`getBody`) and to show in the API that both regions have exactly one block. Also simplify some code that assumed that there can be more than one block in a region.

Differential Revision: https://reviews.llvm.org/D157860
parent 86f3dc83
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -1110,6 +1110,8 @@ def WhileOp : SCF_Op<"while",
    YieldOp getYieldOp();
    Block::BlockArgListType getBeforeArguments();
    Block::BlockArgListType getAfterArguments();
    Block *getBeforeBody() { return &getBefore().front(); }
    Block *getAfterBody() { return &getAfter().front(); }
  }];

  let hasCanonicalizer = 1;
+10 −18
Original line number Diff line number Diff line
@@ -542,10 +542,8 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

  // Inline both regions.
  Block *after = &whileOp.getAfter().front();
  Block *afterLast = &whileOp.getAfter().back();
  Block *before = &whileOp.getBefore().front();
  Block *beforeLast = &whileOp.getBefore().back();
  Block *after = whileOp.getAfterBody();
  Block *before = whileOp.getBeforeBody();
  rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
  rewriter.inlineRegionBefore(whileOp.getBefore(), after);

@@ -556,14 +554,14 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
  // Replace terminators with branches. Assuming bodies are SESE, which holds
  // given only the patterns from this file, we only need to look at the last
  // block. This should be reconsidered if we allow break/continue in SCF.
  rewriter.setInsertionPointToEnd(beforeLast);
  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
  rewriter.setInsertionPointToEnd(before);
  auto condOp = cast<ConditionOp>(before->getTerminator());
  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
                                                after, condOp.getArgs(),
                                                continuation, ValueRange());

  rewriter.setInsertionPointToEnd(afterLast);
  auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
  rewriter.setInsertionPointToEnd(after);
  auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
  rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
                                            yieldOp.getResults());

@@ -577,12 +575,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
LogicalResult
DoWhileLowering::matchAndRewrite(WhileOp whileOp,
                                 PatternRewriter &rewriter) const {
  if (!llvm::hasSingleElement(whileOp.getAfter()))
    return rewriter.notifyMatchFailure(whileOp,
                                       "do-while simplification applicable to "
                                       "single-block 'after' region only");

  Block &afterBlock = whileOp.getAfter().front();
  Block &afterBlock = *whileOp.getAfterBody();
  if (!llvm::hasSingleElement(afterBlock))
    return rewriter.notifyMatchFailure(whileOp,
                                       "do-while simplification applicable "
@@ -601,8 +594,7 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

  // Only the "before" region should be inlined.
  Block *before = &whileOp.getBefore().front();
  Block *beforeLast = &whileOp.getBefore().back();
  Block *before = whileOp.getBeforeBody();
  rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);

  // Branch to the "before" region.
@@ -610,8 +602,8 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
  rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());

  // Loop around the "before" region based on condition.
  rewriter.setInsertionPointToEnd(beforeLast);
  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
  rewriter.setInsertionPointToEnd(before);
  auto condOp = cast<ConditionOp>(before->getTerminator());
  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
                                                before, condOp.getArgs(),
                                                continuation, ValueRange());
+19 −21
Original line number Diff line number Diff line
@@ -3177,19 +3177,19 @@ OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
}

ConditionOp WhileOp::getConditionOp() {
  return cast<ConditionOp>(getBefore().front().getTerminator());
  return cast<ConditionOp>(getBeforeBody()->getTerminator());
}

YieldOp WhileOp::getYieldOp() {
  return cast<YieldOp>(getAfter().front().getTerminator());
  return cast<YieldOp>(getAfterBody()->getTerminator());
}

Block::BlockArgListType WhileOp::getBeforeArguments() {
  return getBefore().front().getArguments();
  return getBeforeBody()->getArguments();
}

Block::BlockArgListType WhileOp::getAfterArguments() {
  return getAfter().front().getArguments();
  return getAfterBody()->getArguments();
}

void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
@@ -3260,8 +3260,7 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {

/// Prints a `while` op.
void scf::WhileOp::print(OpAsmPrinter &p) {
  printInitializationList(p, getBefore().front().getArguments(), getInits(),
                          " ");
  printInitializationList(p, getBeforeArguments(), getInits(), " ");
  p << " : ";
  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
  p << ' ';
@@ -3411,7 +3410,7 @@ struct RemoveLoopInvariantArgsFromBeforeBlock

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &afterBlock = op.getAfter().front();
    Block &afterBlock = *op.getAfterBody();
    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();
@@ -3493,7 +3492,7 @@ struct RemoveLoopInvariantArgsFromBeforeBlock
        &newWhile.getBefore(), /*insertPt*/ {},
        ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);

    Block &beforeBlock = op.getBefore().front();
    Block &beforeBlock = *op.getBeforeBody();
    SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
    // For each i-th before block argument we find it's replacement value as :-
    //   1. If i-th before block argument is a loop invariant, we fetch it's
@@ -3563,7 +3562,7 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &beforeBlock = op.getBefore().front();
    Block &beforeBlock = *op.getBeforeBody();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();

@@ -3616,7 +3615,7 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
        *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
                              newAfterBlockType, newAfterBlockArgLocs);

    Block &afterBlock = op.getAfter().front();
    Block &afterBlock = *op.getAfterBody();
    // Since a new scf.condition op was created, we need to fetch the new
    // `after` block arguments which will be used while replacing operations of
    // previous scf.while's `after` blocks. We'd also be fetching new result
@@ -3733,7 +3732,7 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
                                newWhile.getBefore().begin());

    Block &afterBlock = op.getAfter().front();
    Block &afterBlock = *op.getAfterBody();
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);

    rewriter.replaceOp(op, newResults);
@@ -3774,8 +3773,7 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
    if (!cmp)
      return failure();
    bool changed = false;
    for (auto tup :
         llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
    for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
      for (size_t opIdx = 0; opIdx < 2; opIdx++) {
        if (std::get<0>(tup) != cmp.getOperand(opIdx))
          continue;
@@ -3839,8 +3837,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
      }
    }

    Block &beforeBlock = op.getBefore().front();
    Block &afterBlock = op.getAfter().front();
    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();

    beforeBlock.eraseArguments(argsToErase);

@@ -3848,8 +3846,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
    auto newWhileOp =
        rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
                                 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
    Block &newBeforeBlock = newWhileOp.getBefore().front();
    Block &newAfterBlock = newWhileOp.getAfter().front();
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yield);
@@ -3899,8 +3897,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
        /*afterBody*/ nullptr);
    Block &newBeforeBlock = newWhileOp.getBefore().front();
    Block &newAfterBlock = newWhileOp.getAfter().front();
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();

    SmallVector<Value> afterArgsMapping;
    SmallVector<Value> resultsMapping;
@@ -3917,8 +3915,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
    rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
                                             argsRange);

    Block &beforeBlock = op.getBefore().front();
    Block &afterBlock = op.getAfter().front();
    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();

    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                         newBeforeBlock.getArguments());
+2 −9
Original line number Diff line number Diff line
@@ -760,13 +760,6 @@ struct WhileOpInterface
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
@@ -827,7 +820,7 @@ struct WhileOpInterface
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
@@ -835,7 +828,7 @@ struct WhileOpInterface
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
+1 −1
Original line number Diff line number Diff line
@@ -57,7 +57,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
    // arguments to the 'after' region.
    auto *beforeBlock = rewriter.createBlock(
        &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
    rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
    rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
    auto cmpOp = rewriter.create<arith::CmpIOp>(
        whileOp.getLoc(), arith::CmpIPredicate::slt,
        beforeBlock->getArgument(0), forOp.getUpperBound());