Unverified Commit aeb89698 authored by Erick Ochoa Lopez, committed by GitHub

[mlir][vector] Skip redundant affine.apply when unrolling transfer ops. (#192700)



Unrolling transfer ops generates an affine.apply op for every
non-broadcast dimension, even when the offset being added is zero.
Skip creating these redundant ops.

Also export sliceTransferIndices so that downstream projects can call
this function.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 8b8c271f
+6 −0
@@ -322,6 +322,12 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                    const UnrollVectorOptions &options,
                                    PatternBenefit benefit = 1);
 
+/// Compute indices for a transfer op slice.
+SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
+                                        ArrayRef<Value> indices,
+                                        AffineMap permutationMap, Location loc,
+                                        OpBuilder &builder);
+
 /// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
 /// outermost dimension of the operand.
 void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
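With sliceTransferIndices exported in the mlir::vector namespace, a downstream unrolling pattern can reuse it instead of re-deriving the per-slice index arithmetic. The following is a minimal sketch of such a caller, not part of this change: the helper name, the surrounding `readOp`/`elementOffsets`/`builder` values, and the include path (assumed to be the upstream VectorRewritePatterns.h header where the declaration above is added) are illustrative.

```cpp
// Hypothetical downstream helper (sketch only): recompute the indices of one
// unrolled slice of a vector.transfer_read. Only vector::sliceTransferIndices
// comes from this change; everything else is assumed caller context.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" // assumed location of the declaration
#include "mlir/IR/Builders.h"

using namespace mlir;

static SmallVector<Value> computeSliceIndices(vector::TransferReadOp readOp,
                                              ArrayRef<int64_t> elementOffsets,
                                              OpBuilder &builder) {
  // The op's indices come back as an operand range; materialize them as Values.
  SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                     readOp.getIndices().end());
  // Broadcast dimensions and zero offsets now produce no affine.apply at all.
  return vector::sliceTransferIndices(elementOffsets, originalIndices,
                                      readOp.getPermutationMap(),
                                      readOp.getLoc(), builder);
}
```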
+6 −8
@@ -26,12 +26,9 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// Compute the indices of the slice `index` for a transfer op.
-static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
-                                               ArrayRef<Value> indices,
-                                               AffineMap permutationMap,
-                                               Location loc,
-                                               OpBuilder &builder) {
+SmallVector<Value> mlir::vector::sliceTransferIndices(
+    ArrayRef<int64_t> elementOffsets, ArrayRef<Value> indices,
+    AffineMap permutationMap, Location loc, OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   auto isBroadcast = [](AffineExpr expr) {
     if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
@@ -41,11 +38,12 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
   SmallVector<Value> slicedIndices(indices);
   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
-    if (isBroadcast(dim.value()))
+    int64_t elementOffset = elementOffsets[dim.index()];
+    if (isBroadcast(dim.value()) || elementOffset == 0)
       continue;
     unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
     auto expr = getAffineDimExpr(0, builder.getContext()) +
-                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
+                getAffineConstantExpr(elementOffset, ctx);
     auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
     slicedIndices[pos] =
         affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
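The reason a zero offset needs no affine.apply: the map built for an offset c is (d0) -> (d0 + c), and for c == 0 the expression folds to the identity map, so the emitted op would merely forward the original index. A standalone sketch of that folding, assuming only an MLIRContext (the helper name is illustrative):

```cpp
// Sketch: show that the per-slice offset map degenerates to the identity
// when the offset is zero, which is why the loop above skips it.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static bool offsetMapIsIdentity(MLIRContext *ctx, int64_t offset) {
  // (d0) -> (d0 + offset); for offset == 0 the add folds away and the map
  // is (d0) -> (d0), i.e. applying it reproduces the incoming index.
  AffineExpr expr =
      getAffineDimExpr(0, ctx) + getAffineConstantExpr(offset, ctx);
  AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
  return map.isIdentity();
}
```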
+33 −0
@@ -383,3 +383,36 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
   %res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref<?x?x?xf32>, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32>
   return %res : vector<6x4xf32>
 }
+
+// -----
+
+// Verify that no redundant affine.apply ops are generated for zero offsets
+// when the base indices are dynamic.
+
+// ALL-LABEL: func @transfer_read_unroll_dynamic_index(
+//  ALL-SAME:   %[[MEM:.*]]: memref<4x4xf32>,
+//  ALL-SAME:   %[[IDX:.*]]: index
+//       ALL:   vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//   ALL-NOT:   affine.apply
+//       ALL:   %[[MAP:.*]] = affine.apply {{.*}}(%[[IDX]])
+//       ALL:   vector.transfer_read %[[MEM]][%[[MAP]], %[[IDX]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+func.func @transfer_read_unroll_dynamic_index(%mem : memref<4x4xf32>, %idx : index) -> vector<4x2xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0 : memref<4x4xf32>, vector<4x2xf32>
+  return %res : vector<4x2xf32>
+}
+
+// -----
+
+// ALL-LABEL: func @transfer_write_unroll_dynamic_index(
+//  ALL-SAME:   %[[MEM:.*]]: memref<4x4xf32>,
+//  ALL-SAME:   %[[VEC:.*]]: vector<4x2xf32>,
+//  ALL-SAME:   %[[IDX:.*]]: index
+//   ALL-NOT:   affine.apply
+//       ALL:   vector.transfer_write %{{.*}}, %[[MEM]][%[[IDX]], %[[IDX]]] : vector<2x2xf32>, memref<4x4xf32>
+//       ALL:   %[[MAP:.*]] = affine.apply {{.*}}(%[[IDX]])
+//       ALL:   vector.transfer_write %{{.*}}, %[[MEM]][%[[MAP]], %[[IDX]]] : vector<2x2xf32>, memref<4x4xf32>
+func.func @transfer_write_unroll_dynamic_index(%mem : memref<4x4xf32>, %vec : vector<4x2xf32>, %idx : index) {
+  vector.transfer_write %vec, %mem[%idx, %idx] : vector<4x2xf32>, memref<4x4xf32>
+  return
+}