Unverified Commit 98dcd98a authored by Quinn Dawkins's avatar Quinn Dawkins Committed by GitHub
Browse files

[mlir][vector] Hoist uniform scalar loop code after scf.for distribution (#71422)

After propagation of `vector.warp_execute_on_lane_0` through `scf.for`,
uniform operations like those on the loop iterators can now be hoisted
out of the inner warp op.
parent 876bd794
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -1560,6 +1560,9 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

+5 −3
Original line number Diff line number Diff line
@@ -322,10 +322,11 @@ func.func @extract_scalar_vector_broadcast(%laneid: index) {
// CHECK-PROP:   %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP:   vector.yield %[[INI1]] : vector<128xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]) -> (vector<4xf32>) {
// CHECK-PROP: %[[F:.*]] = scf.for %[[IT:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]) -> (vector<4xf32>) {
// CHECK-PROP:   %[[A:.*]] = arith.addi %[[IT]], %{{.*}} : index
// CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]] : vector<4xf32>) -> (vector<4xf32>) {
// CHECK-PROP:    ^bb0(%[[ARG:.*]]: vector<128xf32>):
// CHECK-PROP:      %[[ACC:.*]] = "some_def"(%[[ARG]]) : (vector<128xf32>) -> vector<128xf32>
// CHECK-PROP:      %[[ACC:.*]] = "some_def"(%[[A]], %[[ARG]]) : (index, vector<128xf32>) -> vector<128xf32>
// CHECK-PROP:      vector.yield %[[ACC]] : vector<128xf32>
// CHECK-PROP:   }
// CHECK-PROP:   scf.yield %[[W]] : vector<4xf32>
@@ -338,7 +339,8 @@ func.func @warp_scf_for(%arg0: index) {
  %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
    %ini = "some_def"() : () -> (vector<128xf32>)
    %3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
      %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
      %add = arith.addi %arg3, %c1 : index
      %acc = "some_def"(%add, %arg4) : (index, vector<128xf32>) -> (vector<128xf32>)
      scf.yield %acc : vector<128xf32>
    }
    vector.yield %3 : vector<128xf32>