Commit f9efce1d authored by Andy Davis's avatar Andy Davis
Browse files

[mlir][VectorOps] Support vector transfer_read/write unrolling for memrefs...

[mlir][VectorOps] Support vector transfer_read/write unrolling for memrefs with vector element type.

Summary:
[mlir][VectorOps] Support vector transfer_read/write unrolling for memrefs with vector element type.  When unrolling vector transfer read/write on memrefs with vector element type, the indices used to index the memref argument must be updated to reflect the unrolled operation.   However, in the case of memrefs with vector element type, we need to be careful to only update the relevant memref indices.

For example, a vector transfer read with the following source/result types, memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>, should only update memref indices 1 and 2 during unrolling.

Reviewers: nicolasvasilache, aartbik

Reviewed By: nicolasvasilache, aartbik

Subscribers: lebedev.ri, Joonsoo, merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72965
parent 643dee90
Loading
Loading
Loading
Loading
+66 −17
Original line number Diff line number Diff line
@@ -460,12 +460,13 @@ SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
      op, iterationBounds, vectors, resultIndex, targetShape, builder)};
}

// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
// calls 'fn' with linear index and indices for each slice.
/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
/// calls 'fn' with the linear index and indices for each slice.
static void
generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
                         ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides,
                         ArrayRef<Value> indices, PatternRewriter &rewriter,
generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
                         TupleType tupleType, ArrayRef<int64_t> sizes,
                         ArrayRef<int64_t> strides, ArrayRef<Value> indices,
                         PatternRewriter &rewriter,
                         function_ref<void(unsigned, ArrayRef<Value>)> fn) {
  // Compute strides w.r.t. to slice counts in each dimension.
  auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
@@ -475,6 +476,25 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,

  int64_t numSlices = tupleType.size();
  unsigned numSliceIndices = indices.size();
  // Compute 'indexOffset' at which to update 'indices', which is equal
  // to the memref rank (indices.size) minus the effective 'vectorRank'.
  // The effective 'vectorRank', is equal to the rank of the vector type
  // minus the rank of the memref vector element type (if it has one).
  //
  // For example:
  //
  //   Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
  //   transfer_read/write ops which read/write vectors of type
  //   'vector<2x1x2x4xf32>'. The memref rank is 3, and the effective
  //   vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
  //
  unsigned vectorRank = vectorType.getRank();
  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
    assert(vectorRank >= memrefVectorElementType.getRank());
    vectorRank -= memrefVectorElementType.getRank();
  }
  unsigned indexOffset = numSliceIndices - vectorRank;

  auto *ctx = rewriter.getContext();
  for (unsigned i = 0; i < numSlices; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
@@ -482,18 +502,41 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
    SmallVector<Value, 4> sliceIndices(numSliceIndices);
    for (auto it : llvm::enumerate(indices)) {
    for (unsigned j = 0; j < numSliceIndices; ++j) {
      if (j < indexOffset) {
        sliceIndices[j] = indices[j];
      } else {
        auto expr = getAffineDimExpr(0, ctx) +
                  getAffineConstantExpr(elementOffsets[it.index()], ctx);
                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
      sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
          it.value().getLoc(), map, ArrayRef<Value>(it.value()));
        sliceIndices[j] = rewriter.create<AffineApplyOp>(
            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
      }
    }
    // Call 'fn' to generate slice 'i' at 'sliceIndices'.
    fn(i, sliceIndices);
  }
}

/// Returns true if 'map' is a suffix of an identity affine map, false
/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
static bool isIdentitySuffix(AffineMap map) {
  // A suffix cannot contain more results than the map has dimensions.
  if (map.getNumDims() < map.getNumResults())
    return false;
  // Each result must be a plain dimension expression, and consecutive
  // results must reference consecutive dimension positions.
  Optional<int> prevPos;
  for (AffineExpr result : map.getResults()) {
    auto dimExpr = result.dyn_cast<AffineDimExpr>();
    if (!dimExpr)
      return false;
    int currPos = static_cast<int>(dimExpr.getPosition());
    if (prevPos.hasValue() && currPos != prevPos.getValue() + 1)
      return false;
    prevPos = currPos;
  }
  return true;
}

namespace {
// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
// scheme of its unique ExtractSlicesOp user.
@@ -504,7 +547,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
                                     PatternRewriter &rewriter) const override {
    // TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
    if (!xferReadOp.permutation_map().isIdentity())
    if (!isIdentitySuffix(xferReadOp.permutation_map()))
      return matchFailure();
    // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
    Value xferReadResult = xferReadOp.getResult();
@@ -523,6 +566,8 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
    assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));

    Location loc = xferReadOp.getLoc();
    auto memrefElementType =
        xferReadOp.memref().getType().cast<MemRefType>().getElementType();
    int64_t numSlices = resultTupleType.size();
    SmallVector<Value, 4> vectorTupleValues(numSlices);
    SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
@@ -535,8 +580,9 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
          loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
          xferReadOp.permutation_map(), xferReadOp.padding());
    };
    generateTransferOpSlices(sourceVectorType, resultTupleType, sizes, strides,
                             indices, rewriter, createSlice);
    generateTransferOpSlices(memrefElementType, sourceVectorType,
                             resultTupleType, sizes, strides, indices, rewriter,
                             createSlice);

    // Create tuple of slice xfer read operations.
    Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
@@ -557,7 +603,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
                                     PatternRewriter &rewriter) const override {
    // TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
    if (!xferWriteOp.permutation_map().isIdentity())
    if (!isIdentitySuffix(xferWriteOp.permutation_map()))
      return matchFailure();
    // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
    auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
@@ -580,6 +626,8 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
    insertSlicesOp.getStrides(strides);

    Location loc = xferWriteOp.getLoc();
    auto memrefElementType =
        xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
    SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                  xferWriteOp.indices().end());
    auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
@@ -588,8 +636,9 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
          loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
          xferWriteOp.permutation_map());
    };
    generateTransferOpSlices(resultVectorType, sourceTupleType, sizes, strides,
                             indices, rewriter, createSlice);
    generateTransferOpSlices(memrefElementType, resultVectorType,
                             sourceTupleType, sizes, strides, indices, rewriter,
                             createSlice);

    // Erase old 'xferWriteOp'.
    rewriter.eraseOp(xferWriteOp);
+35 −0
Original line number Diff line number Diff line
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s

// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>

// CHECK-LABEL: func @add4x2
//      CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
@@ -311,3 +312,37 @@ func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>
  return %1 : vector<8xf32>
}

// CHECK-LABEL: func @vector_transfers_vector_element_type
//      CHECK: %[[C0:.*]] = constant 0 : index
//      CHECK: %[[C1:.*]] = constant 1 : index
//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[MAP1]]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {permutation_map = #[[MAP1]]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>

// Exercises unrolling of vector.transfer_read/write on a memref whose element
// type is itself a vector (memref<6x2x1xvector<2x4xf32>>).
func @vector_transfers_vector_element_type() {
  %c0 = constant 0 : index
  %cf0 = constant 0.000000e+00 : f32
  // Padding value for the transfer_read, splat of the memref's vector
  // element type.
  %vf0 = splat %cf0 : vector<2x4xf32>

  %0 = alloc() : memref<6x2x1xvector<2x4xf32>>

  // Read a vector<2x1x2x4xf32>; the permutation map selects memref dims d1
  // and d2, so unrolling must only update those memref indices (1 and 2).
  %1 = vector.transfer_read %0[%c0, %c0, %c0], %vf0
      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
        : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>

  // Split along the outermost dimension into two vector<1x1x2x4xf32> slices
  // and reassemble, giving the split patterns a slicing scheme to match.
  %2 = vector.extract_slices %1, [1, 1, 2, 4], [1, 1, 1, 1]
    : vector<2x1x2x4xf32> into tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %3 = vector.tuple_get %2, 0 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %4 = vector.tuple_get %2, 1 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %5 = vector.tuple %3, %4 : vector<1x1x2x4xf32>, vector<1x1x2x4xf32>
  %6 = vector.insert_slices %5, [1, 1, 2, 4], [1, 1, 1, 1]
    : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>> into vector<2x1x2x4xf32>

  // Write the reassembled vector back with the same permutation map; the
  // same index-update constraint applies to the write unrolling.
  vector.transfer_write %6, %0[%c0, %c0, %c0]
    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
      : vector<2x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>

  return
}