Commit 303fddee authored by aartbik's avatar aartbik
Browse files

[mlir] [VectorOps] Rewriting of vector.extract/insert_slices to other vector ops

Summary:
Rewrites the extract/insert_slices operation in terms of
strided_slice/insert_strided_slice ops with intermediate
tuple uses (that should get optimized away with typical
usage). This is done in a separate "pass" to enable testing
this particular rewriting in isolation.

Reviewers: nicolasvasilache, andydavis1, ftynse

Reviewed By: nicolasvasilache

Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73295
parent e3a7c771
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -43,6 +43,17 @@ void populateVectorToVectorCanonicalizationPatterns(
void populateVectorToVectorTransformationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context);

/// Collect a set of vector slices transformation patterns:
///    ExtractSlicesOpLowering, InsertSlicesOpLowering
/// Useful for clients that want to express all vector "slices"
/// ops in terms of more elementary vector "slice" ops. If all
/// "produced" tuple values are "consumed" (the most common
/// use for "slices" ops), this lowering removes all tuple related
/// operations as well (through DCE and folding). If tuple values
/// "leak" coming in, however, some tuple related ops will remain.
void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                          MLIRContext *context);

/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

+132 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <type_traits>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
#include "mlir/Dialect/VectorOps/VectorUtils.h"
@@ -28,6 +29,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"

#include "llvm/Support/CommandLine.h"
@@ -657,6 +659,131 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
  }
};

/// Progressive lowering of ExtractSlicesOp to a tuple of StridedSliceOp.
/// One:
///   %x = vector.extract_slices %0
/// is replaced by:
///   %a = vector.strided_slice %0
///   %b = vector.strided_slice %0
///   ..
///   %x = vector.tuple %a, %b, ..
class ExtractSlicesOpLowering
    : public OpRewritePattern<vector::ExtractSlicesOp> {
public:
  using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;

  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
  PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getSourceVectorType();
    int64_t rank = srcType.getRank();
    auto srcShape = srcType.getShape();

    // Requested slice sizes and strides (strides are all-ones at the moment).
    SmallVector<int64_t, 4> sliceSizes;
    op.getSizes(sliceSizes);
    SmallVector<int64_t, 4> sliceStrides;
    op.getStrides(sliceStrides);

    // Number of slices taken along each dimension.
    SmallVector<int64_t, 4> numSlicesPerDim(rank);
    for (int64_t d = 0; d < rank; ++d)
      numSlicesPerDim[d] = ceilDiv(srcShape[d], sliceSizes[d]);

    // Materialize one strided slice per tuple element.
    auto basis = computeStrides(numSlicesPerDim);
    TupleType resultType = op.getResultTupleType();
    int64_t numSlices = resultType.size();
    SmallVector<Value, 4> slices(numSlices);
    for (int64_t idx = 0; idx < numSlices; ++idx) {
      // De-linearize the tuple index w.r.t. 'basis' to obtain the
      // unrolled vector-space offsets.
      auto vectorOffsets = delinearize(idx, basis);
      // Scale by the slice sizes to obtain element-space offsets.
      auto elementOffsets = mlir::functional::zipMap(
          [](int64_t a, int64_t b) { return a * b; }, vectorOffsets,
          sliceSizes);
      // Clamp each slice size so the final slice stays in bounds.
      SmallVector<int64_t, 4> actualSizes(rank);
      for (int64_t d = 0; d < rank; ++d)
        actualSizes[d] =
            std::min(sliceSizes[d], srcShape[d] - elementOffsets[d]);
      slices[idx] = rewriter.create<vector::StridedSliceOp>(
          loc, op.vector(), elementOffsets, actualSizes, sliceStrides);
    }

    rewriter.replaceOpWithNewOp<vector::TupleOp>(op, resultType, slices);
    return matchSuccess();
  }
};

/// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
/// One:
///   %x = vector.insert_slices %0
/// is replaced by:
///   %r0 = vector.splat 0
///   %t1 = vector.tuple_get %0, 0
///   %r1 = vector.insert_strided_slice %r0, %t1
///   %t2 = vector.tuple_get %0, 1
///   %r2 = vector.insert_strided_slice %r1, %t2
///   ..
///   %x  = ..
class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
public:
  using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;

  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
  PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType vectorType = op.getResultVectorType();
    int64_t rank = vectorType.getRank();
    auto shape = vectorType.getShape();

    SmallVector<int64_t, 4> sizes;
    op.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    op.getStrides(strides); // all-ones at the moment

    // Compute the number of slices in each dimension.
    SmallVector<int64_t, 4> sliceDimCounts(rank);
    for (int64_t r = 0; r < rank; ++r)
      sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);

    // Prepare the result: a zero splat into which each slice is
    // progressively inserted.
    auto elemType = vectorType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value result = rewriter.create<SplatOp>(loc, vectorType, zero);

    // For each element in the tuple, extract the proper strided slice.
    auto basis = computeStrides(sliceDimCounts);
    TupleType tupleType = op.getSourceTupleType();
    int64_t tupleSize = tupleType.size();
    for (int64_t i = 0; i < tupleSize; ++i) {
      // De-linearize w.r.t. 'basis'.
      auto vectorOffsets = delinearize(i, basis);
      // Convert from unrolled vector-space offsets to element-space offsets.
      auto elementOffsets = mlir::functional::zipMap(
          [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
      // Extract from tuple into the result. No explicit slice sizes are
      // needed here: the inserted shape is carried by the type of each
      // tuple element (presumably enforced by op verification — the
      // original code computed per-slice sizes but never used them).
      auto index = rewriter.getI64IntegerAttr(i);
      auto tupleGet = rewriter.create<vector::TupleGetOp>(
          loc, tupleType.getType(i), op.getOperand(), index);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, tupleGet, result, elementOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return matchSuccess();
  }
};

} // namespace

// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -666,3 +793,8 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
  patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
      context);
}

void mlir::vector::populateVectorSlicesLoweringPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Register the progressive lowerings for vector.extract_slices and
  // vector.insert_slices (same registration order as a single insert<>).
  patterns.insert<ExtractSlicesOpLowering>(context);
  patterns.insert<InsertSlicesOpLowering>(context);
}
+63 −0
Original line number Diff line number Diff line
// RUN: mlir-opt %s -test-vector-slices-conversion | FileCheck %s

// CHECK-LABEL: func @extract_slices(%arg0: vector<3x3xf32>)
//       CHECK: %[[SS:.*]] = vector.strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
//       CHECK: return %[[SS]]

// Only tuple element 0 is consumed, so after lowering and folding the CHECK
// lines expect just the single consumed strided_slice to survive.
func @extract_slices(%arg0: vector<3x3xf32>) -> vector<2x2xf32> {
  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
    : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
  %1 = vector.tuple_get %0, 0 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
  return %1 : vector<2x2xf32>
}

// CHECK-LABEL: func @insert_slices(%arg0: vector<2x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<1x2xf32>, %arg3: vector<1x1xf32>)
//       CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32>
//       CHECK: %[[I0:.*]] = vector.insert_strided_slice %arg0, %[[C0]] {offsets = [0, 0], strides = [1, 1]}
//       CHECK: %[[I1:.*]] = vector.insert_strided_slice %arg1, %[[I0]] {offsets = [0, 2], strides = [1, 1]}
//       CHECK: %[[I2:.*]] = vector.insert_strided_slice %arg2, %[[I1]] {offsets = [2, 0], strides = [1, 1]}
//       CHECK: %[[I3:.*]] = vector.insert_strided_slice %arg3, %[[I2]] {offsets = [2, 2], strides = [1, 1]}
//       CHECK: return %[[I3]]

// The four slices cover a 3x3 vector in 2x2 tiles; the lowering builds the
// result by chaining insert_strided_slice ops starting from a zero splat.
func @insert_slices(%arg0: vector<2x2xf32>,
                    %arg1: vector<2x1xf32>,
                    %arg2: vector<1x2xf32>,
                    %arg3: vector<1x1xf32>) -> vector<3x3xf32> {
  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3
    : vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>
  %1 = vector.insert_slices %0, [2, 2], [1, 1]
    : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32>
  return %1 : vector<3x3xf32>
}

// CHECK-LABEL: func @extract_insert_slices(%arg0: vector<3x3xf32>)
//       CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32>
//       CHECK: %[[X0:.*]] = vector.strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
//       CHECK: %[[X1:.*]] = vector.strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
//       CHECK: %[[X2:.*]] = vector.strided_slice %arg0 {offsets = [2, 0], sizes = [1, 2], strides = [1, 1]}
//       CHECK: %[[X3:.*]] = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [1, 1], strides = [1, 1]}
//       CHECK: %[[X4:.*]] = vector.insert_strided_slice %[[X0]], %[[C0]] {offsets = [0, 0], strides = [1, 1]}
//       CHECK: %[[X5:.*]] = vector.insert_strided_slice %[[X1]], %[[X4]] {offsets = [0, 2], strides = [1, 1]}
//       CHECK: %[[X6:.*]] = vector.insert_strided_slice %[[X2]], %[[X5]] {offsets = [2, 0], strides = [1, 1]}
//       CHECK: %[[X7:.*]] = vector.insert_strided_slice %[[X3]], %[[X6]] {offsets = [2, 2], strides = [1, 1]}
//       CHECK: return %[[X7]]

// All tuple values are produced and consumed, so folding removes the
// intermediate tuple/tuple_get ops entirely.
func @extract_insert_slices(%arg0: vector<3x3xf32>) -> vector<3x3xf32> {
  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
    : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
  %1 = vector.insert_slices %0, [2, 2], [1, 1]
    : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32>
  return %1 : vector<3x3xf32>
}

// CHECK-LABEL: func @extract_slices_tuple_leaks(%arg0: vector<4xf32>)
//       CHECK: %[[X0:.*]] = vector.strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]}
//       CHECK: %[[X1:.*]] = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]}
//       CHECK: %[[X2:.*]] = vector.tuple %[[X0]], %[[X1]]
//       CHECK: return %[[X2]]

// The tuple value escapes through the function return ("leaks"), so the
// vector.tuple op cannot be folded away and must remain after lowering.
func @extract_slices_tuple_leaks(%arg0: vector<4xf32>) -> tuple<vector<2xf32>, vector<2xf32>> {
  %0 = vector.extract_slices %arg0, [2], [1] : vector<4xf32> into tuple<vector<2xf32>, vector<2xf32>>
  return %0 : tuple<vector<2xf32>, vector<2xf32>>
}
+15 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ using namespace mlir;
using namespace mlir::vector;

namespace {

#include "TestVectorTransformPatterns.h.inc"

struct TestVectorToVectorConversion
@@ -31,8 +32,22 @@ struct TestVectorToVectorConversion
    applyPatternsGreedily(getFunction(), patterns);
  }
};

/// Test pass that lowers vector.extract_slices/insert_slices by greedily
/// applying the slices lowering patterns to each function.
struct TestVectorSlicesConversion
    : public FunctionPass<TestVectorSlicesConversion> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    // Collect ExtractSlicesOpLowering/InsertSlicesOpLowering and apply
    // them until fixpoint (folding cleans up intermediate tuple ops).
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    applyPatternsGreedily(getFunction(), patterns);
  }
};

} // end anonymous namespace

// Registers the vector-to-vector conversion test pass under
// -test-vector-to-vector-conversion.
static PassRegistration<TestVectorToVectorConversion>
    pass("test-vector-to-vector-conversion",
         "Test conversion patterns between ops in the vector dialect");

// Registers the slices lowering test pass under
// -test-vector-slices-conversion (used by the lit tests above).
static PassRegistration<TestVectorSlicesConversion> slices_pass(
    "test-vector-slices-conversion",
    "Test conversion patterns that lower slices ops in the vector dialect");