Commit fdcecefe authored by Stephan Herhut
Browse files

Add lowering for loop.parallel to cfg.

Summary:
This also removes the explicit pattern for loop.terminator to ensure
that the terminator is only erased if the parent op is rewritten.

Reductions are not yet supported.

Reviewers: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73348
parent 8ed47b74
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -177,6 +177,13 @@ def ParallelOp : Loop_Op<"parallel",
                       Variadic<Index>:$step);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$body);

  let extraClassDeclaration = [{
    /// Returns the range of all arguments of the first block of the body
    /// region.
    iterator_range<Block::args_iterator> getInductionVars() {
      Block &entry = body().front();
      return iterator_range<Block::args_iterator>(entry.args_begin(),
                                                  entry.args_end());
    }
  }];
}

def ReduceOp : Loop_Op<"reduce", [HasParent<"ParallelOp">]> {
+42 −8
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
@@ -142,14 +143,11 @@ struct IfLowering : public OpRewritePattern<IfOp> {
                                     PatternRewriter &rewriter) const override;
};

// NOTE(review): this diff hunk appears to interleave the old
// TerminatorLowering lines with the new ParallelLowering lines; the commit
// summary states that the explicit loop.terminator pattern is removed.
// Confirm against the committed file.
struct TerminatorLowering : public OpRewritePattern<TerminatorOp> {
  using OpRewritePattern<TerminatorOp>::OpRewritePattern;
/// Rewrite pattern matching loop.parallel ops. matchAndRewrite is only
/// declared here and is defined out of line.
struct ParallelLowering : public OpRewritePattern<mlir::loop::ParallelOp> {
  using OpRewritePattern<mlir::loop::ParallelOp>::OpRewritePattern;

  // Overload taking TerminatorOp: erases the matched op and reports success.
  PatternMatchResult matchAndRewrite(TerminatorOp op,
                                     PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return matchSuccess();
  }
  // Overload taking ParallelOp: declaration only.
  PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
                                     PatternRewriter &rewriter) const override;
};
} // namespace

@@ -178,6 +176,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
  // Append the induction variable stepping logic to the last body block and
  // branch back to the condition block.  Construct an expression f :
  // (x -> x+step) and apply this expression to the induction variable.
  rewriter.eraseOp(lastBodyBlock->getTerminator());
  rewriter.setInsertionPointToEnd(lastBodyBlock);
  auto step = forOp.step();
  auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
@@ -220,6 +219,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
  // place it before the continuation block, and branch to it.
  auto &thenRegion = ifOp.thenRegion();
  auto *thenBlock = &thenRegion.front();
  rewriter.eraseOp(thenRegion.back().getTerminator());
  rewriter.setInsertionPointToEnd(&thenRegion.back());
  rewriter.create<BranchOp>(loc, continueBlock);
  rewriter.inlineRegionBefore(thenRegion, continueBlock);
@@ -231,6 +231,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
  auto &elseRegion = ifOp.elseRegion();
  if (!elseRegion.empty()) {
    elseBlock = &elseRegion.front();
    rewriter.eraseOp(elseRegion.back().getTerminator());
    rewriter.setInsertionPointToEnd(&elseRegion.back());
    rewriter.create<BranchOp>(loc, continueBlock);
    rewriter.inlineRegionBefore(elseRegion, continueBlock);
@@ -246,9 +247,42 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
  return matchSuccess();
}

/// Rewrites a loop.parallel op without results into a nest of loop.for ops,
/// one per (induction variable, lower bound, upper bound, step) tuple, and
/// clones the parallel body (minus its terminator) into the innermost loop.
/// The created loop.for ops are expected to be lowered by a further rewrite.
///
/// Returns matchFailure() when the op has results, i.e. reductions, which are
/// not supported yet; otherwise erases the op and returns matchSuccess().
PatternMatchResult
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
                                  PatternRewriter &rewriter) const {
  // TODO: Implement lowering of parallelOp with reductions.
  if (parallelOp.getNumResults() != 0)
    return matchFailure();

  Location loc = parallelOp.getLoc();

  // Remembers which loop.for induction variable stands in for each induction
  // variable of the parallel loop when the body is cloned below.
  BlockAndValueMapping ivMapping;

  // Build the n-dimensional loop nest: every dimension becomes one loop.for,
  // and the insertion point moves into its body so the next dimension (and
  // finally the cloned body) nests inside it.
  for (auto dim :
       llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
                 parallelOp.upperBound(), parallelOp.step())) {
    Value parallelIv = std::get<0>(dim);
    Value lowerBound = std::get<1>(dim);
    Value upperBound = std::get<2>(dim);
    Value stepSize = std::get<3>(dim);

    ForOp loop = rewriter.create<ForOp>(loc, lowerBound, upperBound, stepSize);
    ivMapping.map(parallelIv, loop.getInductionVar());
    rewriter.setInsertionPointToStart(loop.getBody());
  }

  // Copy the body operations, remapping uses of the parallel induction
  // variables to the new loop.for induction variables.
  auto &parallelBody = parallelOp.body().front();
  for (auto &bodyOp : parallelBody.without_terminator())
    rewriter.clone(bodyOp, ivMapping);

  rewriter.eraseOp(parallelOp);
  return matchSuccess();
}

/// Inserts the loop lowering patterns into `patterns`, constructing each of
/// them with `ctx`.
// NOTE(review): the two `insert` lines below look like the before/after lines
// of this diff hunk (TerminatorLowering replaced by ParallelLowering) rather
// than two calls in the final file — confirm against the committed source.
void mlir::populateLoopToStdConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
  patterns.insert<ForLowering, IfLowering, ParallelLowering>(ctx);
}

void LoopToStandardPass::runOnOperation() {
+33 −0
Original line number Diff line number Diff line
@@ -147,3 +147,36 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
  }
  return
}

// CHECK-LABEL:   func @parallel_loop(
// CHECK-SAME:                        [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
// CHECK:           [[VAL_5:%.*]] = constant 1 : index
// CHECK:           br ^bb1([[VAL_0]] : index)
// CHECK:         ^bb1([[VAL_6:%.*]]: index):
// CHECK:           [[VAL_7:%.*]] = cmpi "slt", [[VAL_6]], [[VAL_2]] : index
// CHECK:           cond_br [[VAL_7]], ^bb2, ^bb6
// CHECK:         ^bb2:
// CHECK:           br ^bb3([[VAL_1]] : index)
// CHECK:         ^bb3([[VAL_8:%.*]]: index):
// CHECK:           [[VAL_9:%.*]] = cmpi "slt", [[VAL_8]], [[VAL_3]] : index
// CHECK:           cond_br [[VAL_9]], ^bb4, ^bb5
// CHECK:         ^bb4:
// CHECK:           [[VAL_10:%.*]] = constant 1 : index
// CHECK:           [[VAL_11:%.*]] = addi [[VAL_8]], [[VAL_5]] : index
// CHECK:           br ^bb3([[VAL_11]] : index)
// CHECK:         ^bb5:
// CHECK:           [[VAL_12:%.*]] = addi [[VAL_6]], [[VAL_4]] : index
// CHECK:           br ^bb1([[VAL_12]] : index)
// CHECK:         ^bb6:
// CHECK:           return
// CHECK:         }

// Test input: a two-dimensional loop.parallel without results. The lower
// bounds are (%arg0, %arg1), the upper bounds (%arg2, %arg3) and the steps
// (%arg4, %step), where %step is a constant 1 defined outside the loop. The
// body holds a single constant op.
func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
                        %arg3 : index, %arg4 : index) {
  %step = constant 1 : index
  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                                          step (%arg4, %step) {
    %c1 = constant 1 : index
  }
  return
}