Unverified Commit 1d5ccce1 authored by Andrzej Warzyński's avatar Andrzej Warzyński Committed by GitHub
Browse files

[mlir][transform] Update transform.loop.peel (#67482)

This patch updates `transform.loop.peel` so that this Op returns two
rather than one handle:
  * one for the peeled loop, and
  * one for the remainder loop.

Also, following this change this Op will fail if peeling fails. This is
consistent with other similar Ops that also fail if no transformation
takes place.
parent 0bc68ca4
Loading
Loading
Loading
Loading
+9 −10
Original line number Diff line number Diff line
@@ -142,23 +142,22 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",

     This operation ignores non-scf::ForOp ops and drops them in the return.

     This operation always succeeds and returns the scf::ForOp with the
     postcondition: "the loop trip count is divisible by the step".
     This operation may return the same unmodified loop handle when peeling did
     not modify the IR (i.e. the loop trip count was already divisible).
     This operation returns two scf::ForOp, with the first Op satisfying the
     postcondition: "the loop trip count is divisible by the step". The second
     loop Op contains the remaining iteration. Note that even though the
     Payload IR modification may be performed in-place, this operation consumes
     the operand handle and produces a new one.

     Note that even though the Payload IR modification may be performed
     in-place, this operation consumes the operand handle and produces a new
     one.
     #### Return Modes

     TODO: Return both the peeled loop and the remainder loop.
     Produces a definite failure if peeling fails.
  }];

  let arguments =
      (ins Transform_ScfForOp:$target,
           DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
  // TODO: Return both the peeled loop and the remainder loop.
  let results = (outs TransformHandleTypeInterface:$transformed);
  let results = (outs TransformHandleTypeInterface:$peeled_loop,
                      TransformHandleTypeInterface:$remainder_loop);

  let assemblyFormat =
    "$target attr-dict `:` functional-type(operands, results)";
+8 −6
Original line number Diff line number Diff line
@@ -226,14 +226,16 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
                                  transform::ApplyToEachResultList &results,
                                  transform::TransformState &state) {
  scf::ForOp result;
  // This helper returns failure when peeling does not occur (i.e. when the IR
  // is not modified). This is not a failure for the op as the postcondition:
  //    "the loop trip count is divisible by the step"
  // is valid.
  LogicalResult status =
      scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
  // TODO: Return both the peeled loop and the remainder loop.
  results.push_back(failed(status) ? target : result);
  if (failed(status)) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "failed to peel";
    return diag;
  }
  results.push_back(target);
  results.push_back(result);

  return DiagnosedSilenceableFailure::success();
}

+1 −1
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ transform.sequence failures(propagate) {
  %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
    : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
  transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> !transform.any_op
  transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
}

// -----
+20 −0
Original line number Diff line number Diff line
@@ -59,3 +59,23 @@ transform.sequence failures(propagate) {
  // expected-error @below {{failed to outline}}
  transform.loop.outline %0 {func_name = "foo"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
}

// -----

func.func @test_loops_do_not_get_peeled() {
  %lb = arith.constant 0 : index
  %ub = arith.constant 40 : index
  %step = arith.constant 5 : index
  scf.for %i = %lb to %ub step %step {
    arith.addi %i, %i : index
  }
  return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
  // expected-error @below {{failed to peel}}
  transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
}
+9 −4
Original line number Diff line number Diff line
@@ -87,16 +87,18 @@ transform.sequence failures(propagate) {
// CHECK-LABEL: @loop_peel_op
func.func @loop_peel_op() {
  // CHECK: %[[C0:.+]] = arith.constant 0
  // CHECK: %[[C42:.+]] = arith.constant 42
  // CHECK: %[[C41:.+]] = arith.constant 41
  // CHECK: %[[C5:.+]] = arith.constant 5
  // CHECK: %[[C40:.+]] = arith.constant 40
  // CHECK: scf.for %{{.+}} = %[[C0]] to %[[C40]] step %[[C5]]
  // CHECK:   arith.addi
  // CHECK: scf.for %{{.+}} = %[[C40]] to %[[C42]] step %[[C5]]
  // CHECK: scf.for %{{.+}} = %[[C40]] to %[[C41]] step %[[C5]]
  // CHECK:   arith.addi
  %0 = arith.constant 0 : index
  %1 = arith.constant 42 : index
  %1 = arith.constant 41 : index
  %2 = arith.constant 5 : index
  // expected-remark @below {{main loop}}
  // expected-remark @below {{remainder loop}}
  scf.for %i = %0 to %1 step %2 {
    arith.addi %i, %i : index
  }
@@ -107,7 +109,10 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
  transform.loop.peel %1 : (!transform.op<"scf.for">) -> !transform.any_op
  %main_loop, %remainder = transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">)
  // Make sure 
  transform.test_print_remark_at_operand %main_loop, "main loop" : !transform.op<"scf.for">
  transform.test_print_remark_at_operand %remainder, "remainder loop" : !transform.op<"scf.for">
}

// -----