Unverified Commit 5c3ed392 authored by Aviad Cohen's avatar Aviad Cohen Committed by GitHub
Browse files

[mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (#68526)

parent 86bc4867
Loading
Loading
Loading
Loading
+10 −8
Original line number Diff line number Diff line
@@ -1047,15 +1047,17 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                              ArrayRef<ReassociationIndices> dimSequences);

/// Collapses dimensions of linalg.generic operation. A precondition to
/// calling this method is that for each list in `foldedIterationDim`, the
/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
/// to calling this method is that for each list in `foldedIterationDim`, the
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
/// When valid, the method also collapses the operands of the op. Returns
/// replacement values of the results of the original `genericOp` by inserting
/// replacement values of the results of the original `linalgOp` by inserting
/// reshapes to get back values of compatible types.
FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
template <typename LinalgType>
FailureOr<SmallVector<Value>>
collapseOpIterationDims(LinalgType op,
                        ArrayRef<ReassociationIndices> foldedIterationDims,
                        RewriterBase &rewriter);

struct LowerPackResult {
@@ -1515,7 +1517,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
using GetCollapsableDimensionsFn =
    std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
    std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;

/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
/// tensor operands when needed and expand back the result tensors.
+103 −84
Original line number Diff line number Diff line
@@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap,
}

/// Get the new value to use for a given `OpOperand` in the collapsed operation.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassocation for the operand is same as the
  // number of results of the indexing map, then nothing to do for this operand.
  // If the number of entries in the reassociation for the operand is same as
  // the number of results of the indexing map, then nothing to do for this
  // operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;
@@ -1439,20 +1440,80 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
  }
}

template <typename LinalgType>
Operation *createCollapsedOp(LinalgType op,
                             const CollapsingInfo &collapsingInfo,
                             RewriterBase &rewriter) {
  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
                "unsupported linalg op type to create");
  Location loc = op->getLoc();

  // Get the input operands.
  SmallVector<Value> inputOperands =
      llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                     rewriter);
      });

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(op.getNumDpsInits());
  outputOperands.reserve(op.getNumDpsInits());
  for (OpOperand &output : op.getDpsInitsMutable()) {
    Value newOutput =
        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    // If the op has "buffer semantics", then the init operands are ranked
    // memrefs and the op has no results.
    if (!op.hasBufferSemantics())
      resultTypes.push_back(newOutput.getType());
  }

  if (isa<linalg::CopyOp>(op)) {
    return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
                                           outputOperands[0]);
  }

  // Get the iterator types for the operand.
  SmallVector<utils::IteratorType> iteratorTypes =
      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps =
      llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      });

  Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &op->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  return collapsedOp;
}

/// Implementation of fusion with reshape operation by collapsing dimensions.
FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
template <typename LinalgType>
FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter) {
  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
                "unsupported linalg op type to collapse");

  // Bail on trivial no-op cases.
  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  bool hasBufferSemantics = genericOp.hasBufferSemantics();
  bool hasBufferSemantics = op.hasBufferSemantics();
  if (hasBufferSemantics &&
      !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
      !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
        MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
        if (!memRefToCollapse)
          return true;
@@ -1460,20 +1521,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
        return memref::CollapseShapeOp::isGuaranteedCollapsible(
            memRefToCollapse, foldedIterationDims);
      }))
    return rewriter.notifyMatchFailure(genericOp,
    return rewriter.notifyMatchFailure(op,
                                       "memref is not guaranteed collapsible");

  CollapsingInfo collapsingInfo;
  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                       foldedIterationDims))) {
  if (failed(
          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        genericOp, "illegal to collapse specified dimensions");
        op, "illegal to collapse specified dimensions");
  }

  // Bail on non-canonical ranges.
  SmallVector<Range> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());
      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
      return cast<IntegerAttr>(attr).getInt() == value;
@@ -1486,78 +1546,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
               opFoldIsConstantValue(range.stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        genericOp,
        "expected all loop ranges to have zero start and unit stride");
        op, "expected all loop ranges to have zero start and unit stride");
  }

  // Get the iterator types for the operand.
  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.getIteratorTypesArray(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc = genericOp->getLoc();

  // Get the input operands.
  auto inputOperands = llvm::to_vector(llvm::map_range(
      genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));
  LinalgType collapsedOp = cast<LinalgType>(
      createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumDpsInits());
  outputOperands.reserve(genericOp.getNumDpsInits());
  for (OpOperand &output : genericOp.getDpsInitsMutable()) {
    Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
                                            collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    // If the op has "buffer semantics", then the init operands are ranked
    // memrefs and the op has no results.
    if (!hasBufferSemantics)
      resultTypes.push_back(newOutput.getType());
  }

  // Create the generic op.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
  Location loc = op->getLoc();
  if (collapsedOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    rewriter.setInsertionPoint(collapsedOp);
    SmallVector<Value> loopBound =
        llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
        llvm::map_to_vector(loopRanges, [&](Range range) {
          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
        }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
        });
    generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
    Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
    auto originalResultType =
        cast<ShapedType>(originalResult.value().getType());
    auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getIndexingMapMatchingResult(originalResult.value());
          op.getIndexingMapMatchingResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      if (isa<MemRefType>(collapsedOpResult.getType())) {
@@ -1606,8 +1624,8 @@ public:
      }

      std::optional<SmallVector<Value>> replacements =
          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
                                         rewriter);
          collapseOpIterationDims<linalg::GenericOp>(
              genericOp, collapsableIterationDims, rewriter);
      if (!replacements) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
@@ -1624,36 +1642,36 @@ private:
};

/// Pattern to collapse dimensions.
class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
  CollapseLinalgDimensions(MLIRContext *context,
                           GetCollapsableDimensionsFn collapseDimensions,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
      : OpRewritePattern<LinalgType>(context, benefit),
        controlCollapseDimension(std::move(collapseDimensions)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
  LogicalResult matchAndRewrite(LinalgType op,
                                PatternRewriter &rewriter) const override {
    SmallVector<ReassociationIndices> collapsableIterationDims =
        controlCollapseDimension(genericOp);
        controlCollapseDimension(op);
    if (collapsableIterationDims.empty())
      return failure();

    // Check if the specified list of dimensions to collapse is a valid list.
    if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                  collapsableIterationDims)) {
      return rewriter.notifyMatchFailure(
          genericOp, "specified dimensions cannot be collapsed");
          op, "specified dimensions cannot be collapsed");
    }

    std::optional<SmallVector<Value>> replacements =
        collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
                                            rewriter);
    if (!replacements) {
      return rewriter.notifyMatchFailure(genericOp,
                                         "failed to collapse dimensions");
      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
    }
    rewriter.replaceOp(genericOp, *replacements);
    rewriter.replaceOp(op, *replacements);
    return success();
  }

@@ -1884,8 +1902,9 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
void mlir::linalg::populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
                                         controlCollapseDimensions);
  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
               CollapseLinalgDimensions<linalg::CopyOp>>(
      patterns.getContext(), controlCollapseDimensions);
}

//===---------------------------------------------------------------------===//
+37 −0
Original line number Diff line number Diff line
@@ -116,3 +116,40 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
  }
  return %alloc : memref<2x6x24x48xi32>
}

// -----

// CHECK-LABEL:   func.func @linalg_copy(
// CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
// CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK:         }

func.func @linalg_copy(
    %arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> {
  %0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3>
  return %0 : tensor<1x2x3x4x5xf32, 3>
}

// -----

// CHECK-LABEL:   func.func private @memref_linalg_copy(
// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>,
// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) {
// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
// CHECK:           linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>)
// CHECK:           return
// CHECK:         }

func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) {
  linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
  return
}
+1 −1
Original line number Diff line number Diff line
@@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion
      SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
                                   collapseDimensions.end());
      linalg::GetCollapsableDimensionsFn collapseFn =
          [&dims](linalg::GenericOp op) {
          [&dims](linalg::LinalgOp op) {
            SmallVector<ReassociationIndices> reassociations;
            reassociations.emplace_back(dims);
            return reassociations;