Commit 121aab84 authored by Krzysztof Drewniak's avatar Krzysztof Drewniak
Browse files

[MLIR][Affine] Simplify nested modulo operations when able

It is the case that, for all positive a and b such that b divides a
(e mod (a * b)) mod b = e mod b. For example, ((d0 mod 35) mod 5) can
be simplified to (d0 mod 5), but ((d0 mod 35) mod 4) cannot be simplified
further (x = 36 is a counterexample).

This change enables more complex simplifications. For example,
((d0 * 72 + d1) mod 144) mod 9 can now simplify to (d0 * 72 + d1) mod 9
and thus to d1 mod 9. Expressions with chained modulus operators are
reasonably common in tensor applications, and this change _should_
improve code generation for such expressions.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D109930
parent 08f0cb77
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -829,6 +829,15 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
      return lBin.getLHS() % rhsConst.getValue();
  }

  // Simplify (e % a) % b to e % b when b evenly divides a
  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
    auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
    if (intermediate && intermediate.getValue() >= 1 &&
        mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
      return lBin.getLHS() % rhsConst.getValue();
    }
  }

  return nullptr;
}

+6 −0
Original line number Diff line number Diff line
@@ -189,6 +189,9 @@
// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 * 3, (d0 + d1) * 2, d0 mod 2)>
#map58 = affine_map<(d0, d1) -> (4*d0 - 2*d0 + d0, (d0 + d1) + (d0 + d1), 2 * (d0 mod 2) - d0 mod 2)>

// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 mod 5, (d1 mod 35) mod 4)>
#map59 = affine_map<(d0, d1) -> ((d0 mod 35) mod 5, (d1 mod 35) mod 4)>

// Single identity maps are removed.
// CHECK: @f0(memref<2x4xi8, 1>)
func private @f0(memref<2x4xi8, #map0, 1>)
@@ -373,3 +376,6 @@ func private @f56(memref<1x1xi8, #map56>)

// CHECK: "f58"() {map = #map{{[0-9]+}}} : () -> ()
"f58"() {map = #map58} : () -> ()

// CHECK: "f59"() {map = #map{{[0-9]+}}} : () -> ()
"f59"() {map = #map59} : () -> ()
+4 −4
Original line number Diff line number Diff line
@@ -576,9 +576,9 @@ func @fuse_across_varying_dims_complex(%arg0: f32) {
}
// MAXIMAL-DAG: [[$MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 72 + d1) floordiv 2304)>
// MAXIMAL-DAG: [[$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 72 + d1) mod 2304) floordiv 1152)>
// MAXIMAL-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) floordiv 9) floordiv 8)>
// MAXIMAL-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) floordiv 3)>
// MAXIMAL-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) mod 3)>
// MAXIMAL-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> ((((d0 * 72 + d1) mod 1152) floordiv 9) floordiv 8)>
// MAXIMAL-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> ((d1 mod 9) floordiv 3)>
// MAXIMAL-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (d1 mod 3)>
// MAXIMAL-DAG: [[$MAP7:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
// MAXIMAL-DAG: [[$MAP8:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 16 - d1 + 15)>
// MAXIMAL-LABEL: func @fuse_across_varying_dims_complex
+9 −9
Original line number Diff line number Diff line
@@ -737,15 +737,15 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
//
// CHECK-DAG: [[$MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 9 + d1) floordiv 288)>
// CHECK-DAG: [[$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 288) floordiv 144)>
// CHECK-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> ((((d0 * 9 + d1) mod 288) mod 144) floordiv 48)>
// CHECK-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) floordiv 16)>
// CHECK-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) mod 16)>
// CHECK-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 144) floordiv 48)>
// CHECK-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 48) floordiv 16)>
// CHECK-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 9 + d1) mod 16)>
// CHECK-DAG: [[$MAP11:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 9 + d1)>
// CHECK-DAG: [[$MAP12:#map[0-9]+]] = affine_map<(d0) -> (d0 floordiv 288)>
// CHECK-DAG: [[$MAP13:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 288) floordiv 144)>
// CHECK-DAG: [[$MAP14:#map[0-9]+]] = affine_map<(d0) -> (((d0 mod 288) mod 144) floordiv 48)>
// CHECK-DAG: [[$MAP15:#map[0-9]+]] = affine_map<(d0) -> ((((d0 mod 288) mod 144) mod 48) floordiv 16)>
// CHECK-DAG: [[$MAP16:#map[0-9]+]] = affine_map<(d0) -> ((((d0 mod 288) mod 144) mod 48) mod 16)>
// CHECK-DAG: [[$MAP14:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 144) floordiv 48)>
// CHECK-DAG: [[$MAP15:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 48) floordiv 16)>
// CHECK-DAG: [[$MAP16:#map[0-9]+]] = affine_map<(d0) -> (d0 mod 16)>
// CHECK-DAG: [[$MAP17:#map[0-9]+]] = affine_map<(d0) -> (0)>

//
@@ -761,7 +761,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
// CHECK-NEXT:      affine.apply [[$MAP3]](%{{.*}}, %{{.*}})
// CHECK-NEXT:      affine.apply [[$MAP4]](%{{.*}}, %{{.*}})
// CHECK-NEXT:      "foo"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (index, index, index, index, index, index) -> i32
// CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, (((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) floordiv 48, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) floordiv 16, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT:      affine.apply [[$MAP11]](%{{.*}}, %{{.*}})
// CHECK-NEXT:      affine.apply [[$MAP12]](%{{.*}})
// CHECK-NEXT:      affine.apply [[$MAP13]](%{{.*}})
@@ -769,7 +769,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
// CHECK-NEXT:      affine.apply [[$MAP15]](%{{.*}})
// CHECK-NEXT:      affine.apply [[$MAP16]](%{{.*}})
// CHECK-NEXT:      affine.apply [[$MAP17]](%{{.*}})
// CHECK-NEXT:      affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, (((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) floordiv 48, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) floordiv 16, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT:      affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xi32>
// CHECK-NEXT:      affine.load %{{.*}}[0, 0] : memref<1x1xi32>
// CHECK-NEXT:      muli %{{.*}}, %{{.*}} : i32