Commit 009c053e authored by Quinn Dawkins's avatar Quinn Dawkins
Browse files

[mlir][linalg] Allow outer dims perm and untiled dims in pack/unpack generalization

Extends the pack/unpack generalization patterns to work for any packing
op with only full tiles. This produces a combination of rank-reduced
insert/extract slice ops paired with a transpose on the reduced shape,
similar to what the pattern currently produces for fully tiled
pack/unpacks. Note that only the outer dims are rank-reduced in this
pattern, leaving the shape of the inner tile intact.

Differential Revision: https://reviews.llvm.org/D147555
parent 0d5b51e0
Loading
Loading
Loading
Loading
+125 −44
Original line number Diff line number Diff line
@@ -1241,66 +1241,124 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                 /*nofold=*/false, loc, builder);
}

// Normalizes a permutation on a higher rank space to its actual size, e.g.
//   perm = [1, 4, 2]
// becomes
//   norm = [0, 2, 1]
static SmallVector<int64_t>
getPackUnpackNormalizedInnerPerm(int rank, ArrayRef<int64_t> innerDimsPos) {
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(innerDimsPos))
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> perm = llvm::to_vector(llvm::make_filter_range(
  SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
  // This inverts the permutation in addition to normalizing so invert back.
  return invertPermutationVector(normalizedPerm);
}

// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
// assuming rank reduction of unit outer dims.
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;

  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the inner most dims after packing.
  perm.append(innerDims);

  return perm;
}

LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
    tensor::PackOp packOp, PatternRewriter &rewriter) const {
  // TODO: support the case that outer dimensions are not all 1s A
  // tensor.expand_shape will be generated in this case.
  int64_t srcRank = packOp.getSourceRank();
  if (llvm::any_of(packOp.getDestType().getShape().take_front(srcRank),
                   [](int64_t val) { return val != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "require the outer dimension of the result are all 1s");
  }

  if (llvm::any_of(packOp.getMixedTiles(),
                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
    return rewriter.notifyMatchFailure(packOp,
                                       "require inner tile sizes being static");
  }

  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
  // TODO: support the case that outer dimensions are not all 1s. A
  // tensor.expand_shape will be generated in this case.
  auto innerDimsPos = packOp.getInnerDimsPos();
  int64_t srcRank = packOp.getSourceRank();
  auto destShape = packOp.getDestType().getShape();
  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
        return destShape[index] != 1;
      })) {
    return rewriter.notifyMatchFailure(
        packOp, "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
  // outer dims.
  Location loc = packOp.getLoc();
  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  auto inputShape = packOp.getSourceType().getShape();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
  SmallVector<OpFoldResult> readSizes;
  SmallVector<int64_t> readShape;
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (!dimAndTileMapping.count(i)) {
      readSizes.push_back(oneIdxAttr);
      continue;
    }
    readSizes.push_back(dimAndTileMapping[i]);
    if (dimAndTileMapping.count(i)) {
      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
                              .value_or(ShapedType::kDynamic));
      readSizes.push_back(dimAndTileMapping[i]);
      continue;
    }
    if (ShapedType::isDynamic(inputShape[i])) {
      readSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
    } else {
      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
    }
    if (inputShape[i] != 1)
      readShape.push_back(inputShape[i]);
  }

  Type elemType = packOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShape, elemType);

  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  Value tile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, input, readOffsets, readSizes, readStrides);

  // 2. Transpose the tile to match the inner tile order.
  SmallVector<int64_t> perm =
      getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos());
  // The permutation is inverted when normalizing so invert back to match the
  // ordering in the pack op.
  perm = invertPermutationVector(perm);

  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      inputShape, innerDimsPos, packOp.getOuterDimsPerm());

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1316,9 +1374,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
  int64_t destRank = packOp.getDestRank();
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  SmallVector<OpFoldResult> writeSizes(srcRank, oneIdxAttr);
  for (auto size : transpShape)
    writeSizes.push_back(rewriter.getIndexAttr(size));
  SmallVector<OpFoldResult> writeSizes =
      tensor::getMixedSizes(rewriter, loc, packOp.getDest());

  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
@@ -1333,35 +1390,59 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  if (llvm::any_of(srcShape.take_front(destRank),
                   [](int64_t val) { return val != 1; })) {
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
        return srcShape[index] != 1;
      })) {
    return rewriter.notifyMatchFailure(
        unpackOp, "require the outer dimension of the result are all 1s");
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
  SmallVector<OpFoldResult> readSizes;
  SmallVector<int64_t> readShape;
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i)) {
      readSizes.push_back(oneIdxAttr);
      continue;
    }

    if (ShapedType::isDynamic(srcShape[i])) {
      readSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, source, i).getResult());
    } else {
      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    if (srcShape[i] != 1)
      readShape.push_back(srcShape[i]);
  }
  auto mixedTiles = unpackOp.getMixedTiles();
  SmallVector<OpFoldResult> readSizes(destRank, oneIdxAttr);
  readSizes.append(mixedTiles.begin(), mixedTiles.end());

  // Explicitly create the type for extract_slice op because the inner tile
  // size could be 1. We want to represent the whole inner tile in this case.
  ArrayRef<int64_t> readShape = srcShape.drop_front(destRank);
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShape.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShape, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);

  // 2. Transpose the tile to match the outer corresponding tile order.
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  SmallVector<int64_t> perm =
      getPackUnpackNormalizedInnerPerm(srcRank, innerDimsPos);
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space so we invert the permutation.
  perm = invertPermutationVector(perm);
  SmallVector<int64_t> transpShape(readShape);
  applyPermutationToVector<int64_t>(transpShape, perm);

@@ -1375,11 +1456,13 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  for (int dim : innerDimsPos)
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(getAsOpFoldResult(
        rewriter.createOrFold<tensor::DimOp>(loc, unpackOp.getDest(), dim)));
          rewriter.createOrFold<tensor::DimOp>(loc, unpackOp.getDest(), i)));
  }

  applyPermutationToVector<OpFoldResult>(tileSizes, perm);
  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

@@ -1387,10 +1470,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
  SmallVector<OpFoldResult> writeSizes;
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i))
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
+19 −0
Original line number Diff line number Diff line
@@ -76,3 +76,22 @@ func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x
// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1]
// CHECK:         return %[[INSERT]]

// -----

func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>) -> tensor<3x1x1x1x8x32xf32> {
  %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
  return %0 : tensor<3x1x1x1x8x32xf32>
}
// CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [3, 1, 32, 8] [1, 1, 1, 1]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<3x8x32xf32>
// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
// CHECK-SAME:      ins(%[[TILE]] : tensor<3x32x8xf32>)
// CHECK-SAME:      outs(%[[EMPTY]] : tensor<3x8x32xf32>)
// CHECK-SAME:      permutation = [0, 2, 1]
// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [3, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK:         return %[[INSERT]]
+39 −0
Original line number Diff line number Diff line
@@ -55,3 +55,42 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
//                They have the same type, so the insert_slice op is folded
//                away.
// CHECK:         return %[[TRANSP]]

// -----

func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x32x16x8xf32>) -> tensor<2x32x16x8xf32> {
  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : tensor<2x1x16x8x32xf32> -> tensor<2x32x16x8xf32>
  return %0 : tensor<2x32x16x8xf32>
}
// CHECK-LABEL: func.func @simple_NCHWc_to_NCHW
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x32x16x8xf32>
// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
// CHECK-SAME:      ins(%[[TILE]] : tensor<2x16x8x32xf32>)
// CHECK-SAME:      outs(%[[EMPTY]] : tensor<2x32x16x8xf32>)
// CHECK-SAME:      permutation = [0, 3, 1, 2]
//                They have the same type, so the insert_slice op is folded
//                away.
// CHECK:         return %[[TRANSP]]


// -----

func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> {
  %0 = tensor.unpack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [] inner_tiles = [] into %arg1 : tensor<1x16x8x32xf32> -> tensor<1x32x16x8xf32>
  return %0 : tensor<1x32x16x8xf32>
}
// CHECK-LABEL: func.func @simple_NHWC_to_NCHW
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 16, 8, 32] [1, 1, 1, 1]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x8xf32>
// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
// CHECK-SAME:      ins(%[[TILE]] : tensor<16x8x32xf32>)
// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x16x8xf32>)
// CHECK-SAME:      permutation = [2, 0, 1]
// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME:      [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
// CHECK:         return %[[INSERT]]