Commit 7f3b0e58 authored by Jakub Kuderski's avatar Jakub Kuderski
Browse files

[mlir][arith] Add narrowing patterns to commute more vector ops

This commutes the extension (`arith.extsi`, `arith.extui`) over the
following vector ops: `vector.broadcast`, `vector.shape_cast`,
`vector.transpose`, `vector.flat_transpose`.

I focused on these as I saw them getting created by vector unroll
patterns. Maybe except `vector.flat_transpose`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149534
parent 3ff87088
Loading
Loading
Loading
Loading
+87 −3
Original line number Diff line number Diff line
@@ -249,6 +249,26 @@ using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//

struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newBroadcast =
        rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, newBroadcast);
    return success();
  }
};

struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
  using NarrowingPattern::NarrowingPattern;

@@ -421,6 +441,68 @@ struct ExtensionOverInsertStridedSlice final
  }
};

struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newCast =
        rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, newCast);
    return success();
  }
};

struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        op.getLoc(), newTy, ext->getIn(), op.getTransp());
    ext->recreateAndReplace(rewriter, op, newTranspose);
    return success();
  }
};

struct ExtensionOverFlatTranspose final
    : NarrowingPattern<vector::FlatTransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getMatrix().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
        op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
        op.getColumnsAttr());
    ext->recreateAndReplace(rewriter, op, newTranspose);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//
@@ -449,9 +531,11 @@ void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
  // Add commute patterns with a higher benefit. This is to expose more
  // optimization opportunities to narrowing patterns.
  patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
               ExtensionOverExtractStridedSlice, ExtensionOverInsert,
               ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
               ExtensionOverInsert, ExtensionOverInsertElement,
               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
      patterns.getContext(), options, PatternBenefit(2));

  patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
+88 −0
Original line number Diff line number Diff line
@@ -442,3 +442,91 @@ func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<
  %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
  return %e : vector<2x3xi32>
}

// CHECK-LABEL: func.func @extsi_over_broadcast_3xi16
// CHECK-SAME:    (%[[ARG:.+]]: i16)
// CHECK-NEXT:    %[[BCST:.+]] = vector.broadcast %[[ARG]] : i16 to vector<3xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[BCST]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
func.func @extsi_over_broadcast_3xi16(%a: i16) -> vector<3xi32> {
  %b = arith.extsi %a : i16 to i32
  %r = vector.broadcast %b : i32 to vector<3xi32>
  return %r : vector<3xi32>
}

// CHECK-LABEL: func.func @extui_over_broadcast_2x3xi16
// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
// CHECK-NEXT:    %[[BCST:.+]] = vector.broadcast %[[ARG]] : vector<3xi16> to vector<2x3xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[BCST]] : vector<2x3xi16> to vector<2x3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
func.func @extui_over_broadcast_2x3xi16(%a: vector<3xi16>) -> vector<2x3xi32> {
  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
  %r = vector.broadcast %b : vector<3xi32> to vector<2x3xi32>
  return %r : vector<2x3xi32>
}

// CHECK-LABEL: func.func @extsi_over_shape_cast_2x3xi16
// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
// CHECK-NEXT:    %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<2x3xi16> to vector<3x2xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[CAST]] : vector<3x2xi16> to vector<3x2xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3x2xi32>
func.func @extsi_over_shape_cast_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
  %r = vector.shape_cast %b : vector<2x3xi32> to vector<3x2xi32>
  return %r : vector<3x2xi32>
}

// CHECK-LABEL: func.func @extui_over_shape_cast_5x2x3xi16
// CHECK-SAME:    (%[[ARG:.+]]: vector<5x2x3xi16>)
// CHECK-NEXT:    %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<5x2x3xi16> to vector<2x3x5xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[CAST]] : vector<2x3x5xi16> to vector<2x3x5xi32>
// CHECK-NEXT:    return %[[RET]] : vector<2x3x5xi32>
func.func @extui_over_shape_cast_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
  %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
  %r = vector.shape_cast %b : vector<5x2x3xi32> to vector<2x3x5xi32>
  return %r : vector<2x3x5xi32>
}

// CHECK-LABEL: func.func @extsi_over_transpose_2x3xi16
// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
// CHECK-NEXT:    %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 0] : vector<2x3xi16> to vector<3x2xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[TRAN]] : vector<3x2xi16> to vector<3x2xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3x2xi32>
func.func @extsi_over_transpose_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
  %r = vector.transpose %b, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
  return %r : vector<3x2xi32>
}

// CHECK-LABEL: func.func @extui_over_transpose_5x2x3xi16
// CHECK-SAME:    (%[[ARG:.+]]: vector<5x2x3xi16>)
// CHECK-NEXT:    %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 2, 0] : vector<5x2x3xi16> to vector<2x3x5xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[TRAN]] : vector<2x3x5xi16> to vector<2x3x5xi32>
// CHECK-NEXT:    return %[[RET]] : vector<2x3x5xi32>
func.func @extui_over_transpose_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
  %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
  %r = vector.transpose %b, [1, 2, 0] : vector<5x2x3xi32> to vector<2x3x5xi32>
  return %r : vector<2x3x5xi32>
}

// CHECK-LABEL: func.func @extsi_over_flat_transpose_16xi16
// CHECK-SAME:    (%[[ARG:.+]]: vector<16xi16>)
// CHECK-NEXT:    %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 4 : i32, rows = 4 : i32} : vector<16xi16> -> vector<16xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[TRAN]] : vector<16xi16> to vector<16xi32>
// CHECK-NEXT:    return %[[RET]] : vector<16xi32>
func.func @extsi_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
  %b = arith.extsi %a : vector<16xi16> to vector<16xi32>
  %r = vector.flat_transpose %b {columns = 4 : i32, rows = 4 : i32} : vector<16xi32> -> vector<16xi32>
  return %r : vector<16xi32>
}

// CHECK-LABEL: func.func @extui_over_flat_transpose_16xi16
// CHECK-SAME:    (%[[ARG:.+]]: vector<16xi16>)
// CHECK-NEXT:    %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 8 : i32, rows = 2 : i32} : vector<16xi16> -> vector<16xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[TRAN]] : vector<16xi16> to vector<16xi32>
// CHECK-NEXT:    return %[[RET]] : vector<16xi32>
func.func @extui_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
  %b = arith.extui %a : vector<16xi16> to vector<16xi32>
  %r = vector.flat_transpose %b {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
  return %r : vector<16xi32>
}