Loading mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +10 −8 Original line number Diff line number Diff line Loading @@ -1047,15 +1047,17 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence); bool areDimSequencesPreserved(ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences); /// Collapses dimensions of linalg.generic operation. A precondition to /// calling this method is that for each list in `foldedIterationDim`, the /// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition /// to calling this method is that for each list in `foldedIterationDim`, the /// sequence of dimensions is contiguous in domains of all `indexing_maps` of /// the `genericOp`. This can be checked using `areDimSequencePreserved` method. /// the `linalgOp`. This can be checked using `areDimSequencePreserved` method. /// When valid, the method also collapses the operands of the op. Returns /// replacement values of the results of the original `genericOp` by inserting /// replacement values of the results of the original `linalgOp` by inserting /// reshapes to get back values of compatible types. FailureOr<SmallVector<Value>> collapseGenericOpIterationDims( GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims, template <typename LinalgType> FailureOr<SmallVector<Value>> collapseOpIterationDims(LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims, RewriterBase &rewriter); struct LowerPackResult { Loading Loading @@ -1515,7 +1517,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns); /// to return an array of `ReassociationIndices` representing dimensions that /// should be merged. using GetCollapsableDimensionsFn = std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>; std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>; /// Pattern to collapse dimensions in a linalg.generic op. This will collapse /// tensor operands when needed and expand back the result tensors. Loading mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +103 −84 Original line number Diff line number Diff line Loading @@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap, } /// Get the new value to use for a given `OpOperand` in the collapsed operation. static Value getCollapsedOpOperand(Location loc, GenericOp genericOp, static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); AffineMap indexingMap = op.getMatchingIndexingMap(opOperand); SmallVector<ReassociationIndices> operandReassociation = getOperandReassociation(indexingMap, collapsingInfo); // If the number of entries in the reassocation for the operand is same as the // number of results of the indexing map, then nothing to do for this operand. // If the number of entries in the reassociation for the operand is same as // the number of results of the indexing map, then nothing to do for this // operand. Value operand = opOperand->get(); if (operandReassociation.size() == indexingMap.getNumResults()) return operand; Loading Loading @@ -1439,20 +1440,80 @@ void generateCollapsedIndexingRegion(Location loc, Block *block, } } template <typename LinalgType> Operation *createCollapsedOp(LinalgType op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter) { static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value, "unsupported linalg op type to create"); Location loc = op->getLoc(); // Get the input operands. SmallVector<Value> inputOperands = llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) { return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo, rewriter); }); // Get the output operands and result types. SmallVector<Type> resultTypes; SmallVector<Value> outputOperands; resultTypes.reserve(op.getNumDpsInits()); outputOperands.reserve(op.getNumDpsInits()); for (OpOperand &output : op.getDpsInitsMutable()) { Value newOutput = getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter); outputOperands.push_back(newOutput); // If the op has "buffer semantics", then the init operands are ranked // memrefs and the op has no results. if (!op.hasBufferSemantics()) resultTypes.push_back(newOutput.getType()); } if (isa<linalg::CopyOp>(op)) { return rewriter.create<linalg::CopyOp>(loc, inputOperands[0], outputOperands[0]); } // Get the iterator types for the operand. SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo); // Get the indexing maps. auto indexingMaps = llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) { return getCollapsedOpIndexingMap(map, collapsingInfo); }); Operation *collapsedOp = rewriter.create<linalg::GenericOp>( loc, resultTypes, inputOperands, outputOperands, indexingMaps, iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); Block *origOpBlock = &op->getRegion(0).front(); Block *collapsedOpBlock = &collapsedOp->getRegion(0).front(); rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, collapsedOpBlock->getArguments()); return collapsedOp; } /// Implementation of fusion with reshape operation by collapsing dimensions. FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims( GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims, template <typename LinalgType> FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims( LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims, RewriterBase &rewriter) { static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value, "unsupported linalg op type to collapse"); // Bail on trivial no-op cases. if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() || if (op.getNumLoops() <= 1 || foldedIterationDims.empty() || llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { return foldedDims.size() <= 1; })) return failure(); bool hasBufferSemantics = genericOp.hasBufferSemantics(); bool hasBufferSemantics = op.hasBufferSemantics(); if (hasBufferSemantics && !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool { !llvm::all_of(op->getOperands(), [&](Value operand) -> bool { MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType()); if (!memRefToCollapse) return true; Loading @@ -1460,20 +1521,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims( return memref::CollapseShapeOp::isGuaranteedCollapsible( memRefToCollapse, foldedIterationDims); })) return rewriter.notifyMatchFailure(genericOp, return rewriter.notifyMatchFailure(op, "memref is not guaranteed collapsible"); CollapsingInfo collapsingInfo; if (failed(collapsingInfo.initialize(genericOp.getNumLoops(), foldedIterationDims))) { if (failed( collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) { return rewriter.notifyMatchFailure( genericOp, "illegal to collapse specified dimensions"); op, "illegal to collapse specified dimensions"); } // Bail on non-canonical ranges. SmallVector<Range> loopRanges = cast<LinalgOp>(genericOp.getOperation()) .createLoopRanges(rewriter, genericOp.getLoc()); cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc()); auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) return cast<IntegerAttr>(attr).getInt() == value; Loading @@ -1486,78 +1546,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims( opFoldIsConstantValue(range.stride, 1); })) { return rewriter.notifyMatchFailure( genericOp, "expected all loop ranges to have zero start and unit stride"); op, "expected all loop ranges to have zero start and unit stride"); } // Get the iterator types for the operand. SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes( genericOp.getIteratorTypesArray(), collapsingInfo); // Get the indexing maps. auto indexingMaps = llvm::to_vector( llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) { return getCollapsedOpIndexingMap(map, collapsingInfo); })); Location loc = genericOp->getLoc(); // Get the input operands. auto inputOperands = llvm::to_vector(llvm::map_range( genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) { return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo, rewriter); })); LinalgType collapsedOp = cast<LinalgType>( createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter)); // Get the output operands and result types. SmallVector<Type> resultTypes; SmallVector<Value> outputOperands; resultTypes.reserve(genericOp.getNumDpsInits()); outputOperands.reserve(genericOp.getNumDpsInits()); for (OpOperand &output : genericOp.getDpsInitsMutable()) { Value newOutput = getCollapsedOpOperand(loc, genericOp, &output, collapsingInfo, rewriter); outputOperands.push_back(newOutput); // If the op has "buffer semantics", then the init operands are ranked // memrefs and the op has no results. if (!hasBufferSemantics) resultTypes.push_back(newOutput.getType()); } // Create the generic op. auto collapsedGenericOp = rewriter.create<linalg::GenericOp>( loc, resultTypes, inputOperands, outputOperands, indexingMaps, iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); Block *origOpBlock = &genericOp->getRegion(0).front(); Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front(); rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, collapsedOpBlock->getArguments()); if (collapsedGenericOp.hasIndexSemantics()) { Location loc = op->getLoc(); if (collapsedOp.hasIndexSemantics()) { // Collect the loop range of the generic op. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(collapsedGenericOp); rewriter.setInsertionPoint(collapsedOp); SmallVector<Value> loopBound = llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) { llvm::map_to_vector(loopRanges, [&](Range range) { return getValueOrCreateConstantIndexOp(rewriter, loc, range.size); })); generateCollapsedIndexingRegion(loc, &collapsedGenericOp->getRegion(0).front(), }); generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(), collapsingInfo, loopBound, rewriter); } // Insert expanding reshape for the result to get back the original result // type. SmallVector<Value> results; for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) { Value collapsedOpResult = collapsedGenericOp->getResult(originalResult.index()); for (const auto &originalResult : llvm::enumerate(op->getResults())) { Value collapsedOpResult = collapsedOp->getResult(originalResult.index()); auto originalResultType = cast<ShapedType>(originalResult.value().getType()); auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType()); if (collapsedOpResultType.getRank() != originalResultType.getRank()) { AffineMap indexingMap = genericOp.getIndexingMapMatchingResult(originalResult.value()); op.getIndexingMapMatchingResult(originalResult.value()); SmallVector<ReassociationIndices> reassociation = getOperandReassociation(indexingMap, collapsingInfo); if (isa<MemRefType>(collapsedOpResult.getType())) { Loading Loading @@ -1606,8 +1624,8 @@ public: } std::optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims(genericOp, collapsableIterationDims, rewriter); collapseOpIterationDims<linalg::GenericOp>( genericOp, collapsableIterationDims, rewriter); if (!replacements) { return rewriter.notifyMatchFailure( genericOp, "failed to do the fusion by collapsing transformation"); Loading @@ -1624,36 +1642,36 @@ private: }; /// Pattern to collapse dimensions. class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> { template <typename LinalgType> class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> { public: CollapseLinalgDimensions(MLIRContext *context, GetCollapsableDimensionsFn collapseDimensions, PatternBenefit benefit = 1) : OpRewritePattern<GenericOp>(context, benefit), : OpRewritePattern<LinalgType>(context, benefit), controlCollapseDimension(std::move(collapseDimensions)) {} LogicalResult matchAndRewrite(GenericOp genericOp, LogicalResult matchAndRewrite(LinalgType op, PatternRewriter &rewriter) const override { SmallVector<ReassociationIndices> collapsableIterationDims = controlCollapseDimension(genericOp); controlCollapseDimension(op); if (collapsableIterationDims.empty()) return failure(); // Check if the specified list of dimensions to collapse is a valid list. if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(), if (!areDimSequencesPreserved(op.getIndexingMapsArray(), collapsableIterationDims)) { return rewriter.notifyMatchFailure( genericOp, "specified dimensions cannot be collapsed"); op, "specified dimensions cannot be collapsed"); } std::optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims(genericOp, collapsableIterationDims, collapseOpIterationDims<LinalgType>(op, collapsableIterationDims, rewriter); if (!replacements) { return rewriter.notifyMatchFailure(genericOp, "failed to collapse dimensions"); return rewriter.notifyMatchFailure(op, "failed to collapse dimensions"); } rewriter.replaceOp(genericOp, *replacements); rewriter.replaceOp(op, *replacements); return success(); } Loading Loading @@ -1884,8 +1902,9 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns( void mlir::linalg::populateCollapseDimensions( RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions) { patterns.add<CollapseLinalgDimensions>(patterns.getContext(), controlCollapseDimensions); patterns.add<CollapseLinalgDimensions<linalg::GenericOp>, CollapseLinalgDimensions<linalg::CopyOp>>( patterns.getContext(), controlCollapseDimensions); } //===---------------------------------------------------------------------===// Loading mlir/test/Dialect/Linalg/collapse-dim.mlir +37 −0 Original line number Diff line number Diff line Loading @@ -116,3 +116,40 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem } return %alloc : memref<2x6x24x48xi32> } // ----- // CHECK-LABEL: func.func @linalg_copy( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> { // CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32> // CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32> // CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32> // CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32> // CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32> // CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32> // CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64> // CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64> // CHECK: } func.func @linalg_copy( %arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> { %0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> return %0 : tensor<1x2x3x4x5xf32, 3> } // ----- // CHECK-LABEL: func.func private @memref_linalg_copy( // CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>, // CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) { // CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1> // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1> // CHECK: linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>) // CHECK: return // CHECK: } func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) { linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>) return } mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +1 −1 Original line number Diff line number Diff line Loading @@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion SmallVector<int64_t, 2> dims(collapseDimensions.begin(), collapseDimensions.end()); linalg::GetCollapsableDimensionsFn collapseFn = [&dims](linalg::GenericOp op) { [&dims](linalg::LinalgOp op) { SmallVector<ReassociationIndices> reassociations; reassociations.emplace_back(dims); return reassociations; Loading Loading
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +10 −8 Original line number Diff line number Diff line Loading @@ -1047,15 +1047,17 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence); bool areDimSequencesPreserved(ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences); /// Collapses dimensions of linalg.generic operation. A precondition to /// calling this method is that for each list in `foldedIterationDim`, the /// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition /// to calling this method is that for each list in `foldedIterationDim`, the /// sequence of dimensions is contiguous in domains of all `indexing_maps` of /// the `genericOp`. This can be checked using `areDimSequencePreserved` method. /// the `linalgOp`. This can be checked using `areDimSequencePreserved` method. /// When valid, the method also collapses the operands of the op. Returns /// replacement values of the results of the original `genericOp` by inserting /// replacement values of the results of the original `linalgOp` by inserting /// reshapes to get back values of compatible types. FailureOr<SmallVector<Value>> collapseGenericOpIterationDims( GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims, template <typename LinalgType> FailureOr<SmallVector<Value>> collapseOpIterationDims(LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims, RewriterBase &rewriter); struct LowerPackResult { Loading Loading @@ -1515,7 +1517,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns); /// to return an array of `ReassociationIndices` representing dimensions that /// should be merged. using GetCollapsableDimensionsFn = std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>; std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>; /// Pattern to collapse dimensions in a linalg.generic op. This will collapse /// tensor operands when needed and expand back the result tensors. Loading
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +103 −84 Original line number Diff line number Diff line Loading @@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap, } /// Get the new value to use for a given `OpOperand` in the collapsed operation. static Value getCollapsedOpOperand(Location loc, GenericOp genericOp, static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); AffineMap indexingMap = op.getMatchingIndexingMap(opOperand); SmallVector<ReassociationIndices> operandReassociation = getOperandReassociation(indexingMap, collapsingInfo); // If the number of entries in the reassocation for the operand is same as the // number of results of the indexing map, then nothing to do for this operand. // If the number of entries in the reassociation for the operand is same as // the number of results of the indexing map, then nothing to do for this // operand. Value operand = opOperand->get(); if (operandReassociation.size() == indexingMap.getNumResults()) return operand; Loading Loading @@ -1439,20 +1440,80 @@ void generateCollapsedIndexingRegion(Location loc, Block *block, } } template <typename LinalgType> Operation *createCollapsedOp(LinalgType op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter) { static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value, "unsupported linalg op type to create"); Location loc = op->getLoc(); // Get the input operands. SmallVector<Value> inputOperands = llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) { return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo, rewriter); }); // Get the output operands and result types. SmallVector<Type> resultTypes; SmallVector<Value> outputOperands; resultTypes.reserve(op.getNumDpsInits()); outputOperands.reserve(op.getNumDpsInits()); for (OpOperand &output : op.getDpsInitsMutable()) { Value newOutput = getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter); outputOperands.push_back(newOutput); // If the op has "buffer semantics", then the init operands are ranked // memrefs and the op has no results. if (!op.hasBufferSemantics()) resultTypes.push_back(newOutput.getType()); } if (isa<linalg::CopyOp>(op)) { return rewriter.create<linalg::CopyOp>(loc, inputOperands[0], outputOperands[0]); } // Get the iterator types for the operand. SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo); // Get the indexing maps. auto indexingMaps = llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) { return getCollapsedOpIndexingMap(map, collapsingInfo); }); Operation *collapsedOp = rewriter.create<linalg::GenericOp>( loc, resultTypes, inputOperands, outputOperands, indexingMaps, iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); Block *origOpBlock = &op->getRegion(0).front(); Block *collapsedOpBlock = &collapsedOp->getRegion(0).front(); rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, collapsedOpBlock->getArguments()); return collapsedOp; } /// Implementation of fusion with reshape operation by collapsing dimensions. FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims( GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims, template <typename LinalgType> FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims( LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims, RewriterBase &rewriter) { static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value, "unsupported linalg op type to collapse"); // Bail on trivial no-op cases. if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() || if (op.getNumLoops() <= 1 || foldedIterationDims.empty() || llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { return foldedDims.size() <= 1; })) return failure(); bool hasBufferSemantics = genericOp.hasBufferSemantics(); bool hasBufferSemantics = op.hasBufferSemantics(); if (hasBufferSemantics && !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool { !llvm::all_of(op->getOperands(), [&](Value operand) -> bool { MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType()); if (!memRefToCollapse) return true; Loading @@ -1460,20 +1521,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims( return memref::CollapseShapeOp::isGuaranteedCollapsible( memRefToCollapse, foldedIterationDims); })) return rewriter.notifyMatchFailure(genericOp, return rewriter.notifyMatchFailure(op, "memref is not guaranteed collapsible"); CollapsingInfo collapsingInfo; if (failed(collapsingInfo.initialize(genericOp.getNumLoops(), foldedIterationDims))) { if (failed( collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) { return rewriter.notifyMatchFailure( genericOp, "illegal to collapse specified dimensions"); op, "illegal to collapse specified dimensions"); } // Bail on non-canonical ranges. SmallVector<Range> loopRanges = cast<LinalgOp>(genericOp.getOperation()) .createLoopRanges(rewriter, genericOp.getLoc()); cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc()); auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) return cast<IntegerAttr>(attr).getInt() == value; Loading @@ -1486,78 +1546,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims( opFoldIsConstantValue(range.stride, 1); })) { return rewriter.notifyMatchFailure( genericOp, "expected all loop ranges to have zero start and unit stride"); op, "expected all loop ranges to have zero start and unit stride"); } // Get the iterator types for the operand. SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes( genericOp.getIteratorTypesArray(), collapsingInfo); // Get the indexing maps. auto indexingMaps = llvm::to_vector( llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) { return getCollapsedOpIndexingMap(map, collapsingInfo); })); Location loc = genericOp->getLoc(); // Get the input operands. auto inputOperands = llvm::to_vector(llvm::map_range( genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) { return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo, rewriter); })); LinalgType collapsedOp = cast<LinalgType>( createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter)); // Get the output operands and result types. SmallVector<Type> resultTypes; SmallVector<Value> outputOperands; resultTypes.reserve(genericOp.getNumDpsInits()); outputOperands.reserve(genericOp.getNumDpsInits()); for (OpOperand &output : genericOp.getDpsInitsMutable()) { Value newOutput = getCollapsedOpOperand(loc, genericOp, &output, collapsingInfo, rewriter); outputOperands.push_back(newOutput); // If the op has "buffer semantics", then the init operands are ranked // memrefs and the op has no results. if (!hasBufferSemantics) resultTypes.push_back(newOutput.getType()); } // Create the generic op. auto collapsedGenericOp = rewriter.create<linalg::GenericOp>( loc, resultTypes, inputOperands, outputOperands, indexingMaps, iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); Block *origOpBlock = &genericOp->getRegion(0).front(); Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front(); rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, collapsedOpBlock->getArguments()); if (collapsedGenericOp.hasIndexSemantics()) { Location loc = op->getLoc(); if (collapsedOp.hasIndexSemantics()) { // Collect the loop range of the generic op. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(collapsedGenericOp); rewriter.setInsertionPoint(collapsedOp); SmallVector<Value> loopBound = llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) { llvm::map_to_vector(loopRanges, [&](Range range) { return getValueOrCreateConstantIndexOp(rewriter, loc, range.size); })); generateCollapsedIndexingRegion(loc, &collapsedGenericOp->getRegion(0).front(), }); generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(), collapsingInfo, loopBound, rewriter); } // Insert expanding reshape for the result to get back the original result // type. SmallVector<Value> results; for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) { Value collapsedOpResult = collapsedGenericOp->getResult(originalResult.index()); for (const auto &originalResult : llvm::enumerate(op->getResults())) { Value collapsedOpResult = collapsedOp->getResult(originalResult.index()); auto originalResultType = cast<ShapedType>(originalResult.value().getType()); auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType()); if (collapsedOpResultType.getRank() != originalResultType.getRank()) { AffineMap indexingMap = genericOp.getIndexingMapMatchingResult(originalResult.value()); op.getIndexingMapMatchingResult(originalResult.value()); SmallVector<ReassociationIndices> reassociation = getOperandReassociation(indexingMap, collapsingInfo); if (isa<MemRefType>(collapsedOpResult.getType())) { Loading Loading @@ -1606,8 +1624,8 @@ public: } std::optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims(genericOp, collapsableIterationDims, rewriter); collapseOpIterationDims<linalg::GenericOp>( genericOp, collapsableIterationDims, rewriter); if (!replacements) { return rewriter.notifyMatchFailure( genericOp, "failed to do the fusion by collapsing transformation"); Loading @@ -1624,36 +1642,36 @@ private: }; /// Pattern to collapse dimensions. class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> { template <typename LinalgType> class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> { public: CollapseLinalgDimensions(MLIRContext *context, GetCollapsableDimensionsFn collapseDimensions, PatternBenefit benefit = 1) : OpRewritePattern<GenericOp>(context, benefit), : OpRewritePattern<LinalgType>(context, benefit), controlCollapseDimension(std::move(collapseDimensions)) {} LogicalResult matchAndRewrite(GenericOp genericOp, LogicalResult matchAndRewrite(LinalgType op, PatternRewriter &rewriter) const override { SmallVector<ReassociationIndices> collapsableIterationDims = controlCollapseDimension(genericOp); controlCollapseDimension(op); if (collapsableIterationDims.empty()) return failure(); // Check if the specified list of dimensions to collapse is a valid list. if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(), if (!areDimSequencesPreserved(op.getIndexingMapsArray(), collapsableIterationDims)) { return rewriter.notifyMatchFailure( genericOp, "specified dimensions cannot be collapsed"); op, "specified dimensions cannot be collapsed"); } std::optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims(genericOp, collapsableIterationDims, collapseOpIterationDims<LinalgType>(op, collapsableIterationDims, rewriter); if (!replacements) { return rewriter.notifyMatchFailure(genericOp, "failed to collapse dimensions"); return rewriter.notifyMatchFailure(op, "failed to collapse dimensions"); } rewriter.replaceOp(genericOp, *replacements); rewriter.replaceOp(op, *replacements); return success(); } Loading Loading @@ -1884,8 +1902,9 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns( void mlir::linalg::populateCollapseDimensions( RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions) { patterns.add<CollapseLinalgDimensions>(patterns.getContext(), controlCollapseDimensions); patterns.add<CollapseLinalgDimensions<linalg::GenericOp>, CollapseLinalgDimensions<linalg::CopyOp>>( patterns.getContext(), controlCollapseDimensions); } //===---------------------------------------------------------------------===// Loading
mlir/test/Dialect/Linalg/collapse-dim.mlir +37 −0 Original line number Diff line number Diff line Loading @@ -116,3 +116,40 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem } return %alloc : memref<2x6x24x48xi32> } // ----- // CHECK-LABEL: func.func @linalg_copy( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> { // CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32> // CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32> // CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32> // CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32> // CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32> // CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32> // CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64> // CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64> // CHECK: } func.func @linalg_copy( %arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> { %0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> return %0 : tensor<1x2x3x4x5xf32, 3> } // ----- // CHECK-LABEL: func.func private @memref_linalg_copy( // CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>, // CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) { // CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1> // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1> // CHECK: linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>) // CHECK: return // CHECK: } func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) { linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>) return }
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +1 −1 Original line number Diff line number Diff line Loading @@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion SmallVector<int64_t, 2> dims(collapseDimensions.begin(), collapseDimensions.end()); linalg::GetCollapsableDimensionsFn collapseFn = [&dims](linalg::GenericOp op) { [&dims](linalg::LinalgOp op) { SmallVector<ReassociationIndices> reassociations; reassociations.emplace_back(dims); return reassociations; Loading