Commit c95a7246 authored by Matthias Springer's avatar Matthias Springer
Browse files

[mlir][linalg] Tiling: Use loop ub in extract_slice size computation if possible

When tiling a LinalgOp, extract_slice/insert_slice pairs are inserted. To avoid going out of bounds when the tile size does not divide the shape size evenly (at the boundary), AffineMin ops are inserted. Some ops have assumptions regarding the dimensions of their inputs/outputs. E.g., in an `A * B` matmul, `dim(A, 1) == dim(B, 0)`. However, the loop bounds are derived from only one of these two equal dimensions — either `dim(A, 1)` or `dim(B, 0)` — so the slice sizes of the other operand end up expressed in terms of a different (but equal) SSA value.

With this change, AffineMin ops are expressed in terms of the loop upper bounds instead of the individual tensor sizes. (Both have the same value at runtime.) Because equal-but-distinct SSA values no longer appear as AffineMin operands, duplicate affine.min/memref.dim ops fold away and the generated IR canonicalizes more readily.

Differential Revision: https://reviews.llvm.org/D109267
parent d96e0c53
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -91,7 +91,7 @@ SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
/// at offsets `lbs` and with sizes `subShapeSizes`.
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ValueRange tileSizes, AffineMap map, ValueRange lbs,
                     ValueRange subShapeSizes);
                     ValueRange ubs, ValueRange subShapeSizes);

/// Creates extract_slice/subview ops for all `valuesToTile` of the given
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
+3 −4
Original line number Diff line number Diff line
@@ -177,19 +177,18 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
  auto one = b.create<ConstantIndexOp>(loc, 1);

  for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
    auto shapeDim = getShapeDefiningLoopRange(producer, i);
    Value dim = createOrFoldDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
    sizeBounds.push_back(dim);
    auto it = fusedLoopsAndRanges.find(i);
    if (it != fusedLoopsAndRanges.end()) {
      ivs.push_back(it->second.offset);
      tileSizes.push_back(it->second.size);
      sizeBounds.push_back(nullptr);
      loopRanges.push_back(it->second);
      LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
                              << loopRanges.back() << "\n");
    } else {
      auto shapeDim = getShapeDefiningLoopRange(producer, i);
      Value dim = createOrFoldDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
      tileSizes.push_back(zero);
      sizeBounds.push_back(dim);
      loopRanges.push_back(Range{zero, dim, one});
      LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
                              << loopRanges.back() << "\n");
+6 −4
Original line number Diff line number Diff line
@@ -370,8 +370,9 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
  assert(static_cast<int64_t>(tileSizes.size()) == rank);
  // Compute lower and upper bounds of the loop nest.
  SmallVector<Range> ranges = op.getLoopBounds(builder);
  SmallVector<Value> lbs, dims, steps;
  SmallVector<Value> lbs, dims, allDims, steps;
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
@@ -388,13 +389,14 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes =
            computeTileSizes(b, loc, localIvs, tileSizes, dims);
            computeTileSizes(b, loc, localIvs, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the PadTensorOp.
        // Note: The PadTensorOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput = makeTiledShape(b, loc, newPadOp->getResult(0),
                                           tileSizes, map, offsets, sizes);
        Value tiledOutput =
            makeTiledShape(b, loc, newPadOp->getResult(0), tileSizes, map,
                           offsets, allDims, sizes);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
+4 −4
Original line number Diff line number Diff line
@@ -514,7 +514,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ValueRange tileSizes, AffineMap map, ValueRange lbs,
                     ValueRange subShapeSizes) {
                     ValueRange ubs, ValueRange subShapeSizes) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
@@ -567,7 +567,7 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
          AffineMap::inferFromExprList(
              ArrayRef<ArrayRef<AffineExpr>>{{dim0, dim1 - dim2}})
              .front();
      Value d = createOrFoldDimOp(builder, loc, valueToTile, r);
      Value d = applyMapToValues(builder, loc, m, ubs).front();
      SmallVector<Value, 4> operands{size, d, offset};
      fullyComposeAffineMapAndOperands(&minMap, &operands);
      size = builder.create<AffineMinOp>(loc, builder.getIndexType(), minMap,
@@ -656,8 +656,8 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    tiledShapes.push_back(
        makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs, subShapeSizes));
    tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
                                         sizeBounds, subShapeSizes));
  }

  return tiledShapes;
+17 −36
Original line number Diff line number Diff line
@@ -43,12 +43,10 @@ module {
//      CHECK:     %[[TILE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]]
//      CHECK:     %[[SV2:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
// CHECK-SAME:       %[[K_2]], %[[TILE_N]]
//      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME:       [%[[TILE_M]], %[[TILE_N]]]
//      CHECK:     %[[M_2:.+]] = memref.dim %[[ARG2]], %[[C0]]
//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
//      CHECK:     %[[N_2:.+]] = memref.dim %[[ARG2]], %[[C1]]
//      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
//      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME:       [%[[TILE_M_2]], %[[TILE_N_2]]]
//      CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_2]], %[[M]]]
//      CHECK:     %[[TILE_N_3:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N_2]], %[[N]]]
//      CHECK:     %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
@@ -59,9 +57,8 @@ module {
//      CHECK:       %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
//      CHECK:       %[[SV4:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
// CHECK-SAME:         [%[[TILE_M]], %[[TILE_K]]]
//      CHECK:       %[[TILE_K_2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K_2]]]
//      CHECK:       %[[SV5:.+]] = memref.subview %[[SV2]][%[[IV2]], 0]
// CHECK-SAME:         [%[[TILE_K_2]], %[[TILE_N]]]
// CHECK-SAME:         [%[[TILE_K]], %[[TILE_N]]]
//      CHECK:       linalg.matmul
// CHECK-SAME:         __internal_linalg_transform__ = "after_basic_fusion"
// CHECK-SAME:         ins(%[[SV4]], %[[SV5]]
@@ -112,18 +109,15 @@ module {
//      CHECK:     %[[SV1:.+]] = memref.subview %[[ARG2]][0, %[[IV0]]]
// CHECK-SAME:       [%[[K]], %[[TILE_N]]]
//      CHECK:     %[[M:.+]] = memref.dim %[[ARG3]], %[[C0]]
//      CHECK:     %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
//      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
//      CHECK:     %[[SV2:.+]] = memref.subview %[[ARG3]][0, %[[IV0]]]
// CHECK-SAME:       [%[[M]], %[[TILE_N_2]]]
//      CHECK:     %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]]
// CHECK-SAME:       [%[[M]], %[[TILE_N]]
//      CHECK:     %[[N_3:.+]] = memref.dim %[[ARG1]], %[[C1]]
//      CHECK:     %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]]
//      CHECK:     %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[N_3]], %[[N]]]
//      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG1]][0, %[[IV0]]]
// CHECK-SAME:       [%[[K_2]], %[[TILE_N_3]]]
//      CHECK:     %[[TILE_N_4:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[N]], %[[N]]]
//      CHECK:     %[[SV3_2:.+]] = memref.subview %[[ARG2]][0, %[[IV0]]]
// CHECK-SAME:       [%[[K]], %[[TILE_N_4]]]
// CHECK-SAME:       [%[[K]], %[[TILE_N_3]]]
//      CHECK:     linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME:       __internal_linalg_transform__ = "after_rhs_fusion_producer"
//  CHECK-NOT:     linalg.fill
@@ -136,12 +130,10 @@ module {
//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K_2]]]
//      CHECK:         %[[SV4:.+]] = memref.subview %[[ARG0]][%[[IV1]], %[[IV2]]]
// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
//      CHECK:         %[[SV5:.+]] = memref.subview %[[SV1]][%[[IV2]], 0]
// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
//      CHECK:         %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[M]]]
// CHECK-SAME:           [%[[TILE_K]], %[[TILE_N]]]
//      CHECK:         %[[SV6:.+]] = memref.subview %[[SV2]][%[[IV1]], 0]
// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
// CHECK-SAME:           [%[[TILE_M]], %[[TILE_N]]]
//      CHECK:         linalg.matmul
// CHECK-SAME:           __internal_linalg_transform__ = "after_rhs_fusion"
// CHECK-SAME:           ins(%[[SV4]], %[[SV5]]
@@ -195,11 +187,10 @@ module {
//      CHECK:     %[[K:.+]] = memref.dim %[[ARG1]], %[[C1]]
//      CHECK:     %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
//      CHECK:     %[[M_2:.+]] = memref.dim %[[ARG3]], %[[C0]]
//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
//      CHECK:     %[[N:.+]] = memref.dim %[[ARG3]], %[[C1]]
//      CHECK:     %[[SV2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
// CHECK-SAME:       [%[[TILE_M]], %[[N]]]
//      CHECK:     %[[M_2:.+]] = memref.dim %[[ARG3]], %[[C0]]
//      CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_2]], %[[M]]]
//      CHECK:     %[[SV2_2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M_3]], %[[N]]]
@@ -208,9 +199,8 @@ module {
//      CHECK:     %[[K_3:.+]] = memref.dim %[[ARG0]], %[[C1]]
//      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M_4]], %[[K_3]]]
//      CHECK:     %[[TILE_M_5:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M]], %[[M]]]
//      CHECK:     %[[SV3_2:.+]] = memref.subview %[[ARG1]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M_5]], %[[K]]]
// CHECK-SAME:       [%[[TILE_M_4]], %[[K]]]
//      CHECK:     linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
//      CHECK:     linalg.fill(%[[CST]], %[[SV2_2]])
@@ -222,14 +212,11 @@ module {
//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]]
//      CHECK:         %[[SV4:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
//      CHECK:         %[[K_2:.+]] = memref.dim %[[ARG2]], %[[C0]]
//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K_2]]]
//      CHECK:         %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]]
//      CHECK:         %[[SV5:.+]] = memref.subview %[[ARG2]][%[[IV2]], %[[IV1]]]
// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
//      CHECK:         %[[TILE_N_2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
// CHECK-SAME:           [%[[TILE_K]], %[[TILE_N]]]
//      CHECK:         %[[SV6:.+]] = memref.subview %[[SV2]][0, %[[IV1]]]
// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
// CHECK-SAME:           [%[[TILE_M]], %[[TILE_N]]]
//      CHECK:         linalg.matmul
// CHECK-SAME:           __internal_linalg_transform__ = "after_two_operand_fusion"
// CHECK-SAME:           ins(%[[SV4]], %[[SV5]]
@@ -280,19 +267,16 @@ module {
//      CHECK:     %[[K2:.+]] = memref.dim %[[ARG2]], %[[C1]]
//      CHECK:     %[[SV1:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M]], %[[K2]]]
//      CHECK:     %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
//      CHECK:     %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]]
//      CHECK:     %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
// CHECK-SAME:       [%[[TILE_M]], %[[N]]]
//      CHECK:     %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
//      CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_3]], %[[M]]]
//      CHECK:     %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]]
//      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M_3]], %[[K1]]]
//      CHECK:     %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M]], %[[M]]]
//      CHECK:     %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
// CHECK-SAME:       [%[[TILE_M_4]], %[[K2]]]
// CHECK-SAME:       [%[[TILE_M_3]], %[[K2]]]
//      CHECK:     linalg.matmul
// CHECK-SAME:         __internal_linalg_transform__ = "after_lhs_fusion_producer"
// CHECK-SAME:         ins(%[[SV3]], %[[ARG1]]
@@ -305,14 +289,11 @@ module {
//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]]
//      CHECK:         %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
//      CHECK:         %[[K_2:.+]] = memref.dim %[[ARG3]], %[[C0]]
//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K_2]]]
//      CHECK:         %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]]
//      CHECK:         %[[SV7:.+]] = memref.subview %[[ARG3]][%[[IV2]], %[[IV1]]]
// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
//      CHECK:         %[[TILE_N_2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
// CHECK-SAME:           [%[[TILE_K]], %[[TILE_N]]]
//      CHECK:         %[[SV8:.+]] = memref.subview %[[SV2]][0, %[[IV1]]]
// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
// CHECK-SAME:           [%[[TILE_M]], %[[TILE_N]]]
//      CHECK:         linalg.matmul
// CHECK-SAME:           __internal_linalg_transform__ = "after_lhs_fusion"
// CHECK-SAME:           ins(%[[SV6]], %[[SV7]]
Loading