Commit 3ff87088 authored by Jakub Kuderski's avatar Jakub Kuderski
Browse files

[mlir][arith] Add narrowing patterns for other insertion ops

Allow extension ops to be commuted over `vector.insertelement` and
`vector.insert_strided_slice`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149509
parent 0f1a8b45
Loading
Loading
Loading
Loading
+66 −15
Original line number Diff line number Diff line
@@ -306,25 +306,33 @@ struct ExtensionOverExtractStridedSlice final
  }
};

struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::InsertOp op,
                                PatternRewriter &rewriter) const override {
/// Base pattern for narrowing patterns over vector insertion ops.
template <typename InsertionOp>
struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
  using NarrowingPattern<InsertionOp>::NarrowingPattern;

  /// Derived classes must provide a function to create the matching insertion
  /// op based on the original op and new arguments.
  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
                                        InsertionOp origInsert,
                                        Value narrowValue,
                                        Value narrowDest) const = 0;

  LogicalResult matchAndRewrite(InsertionOp op,
                                PatternRewriter &rewriter) const final {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    FailureOr<vector::InsertOp> newInsert =
        createNarrowInsert(op, rewriter, *ext);
    FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
    if (failed(newInsert))
      return failure();
    ext->recreateAndReplace(rewriter, op, *newInsert);
    return success();
  }

  FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
  FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
                                            PatternRewriter &rewriter,
                                            ExtensionOp insValue) const {
    // Calculate the operand and result bitwidths. We can only apply narrowing
@@ -337,6 +345,8 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
    if (failed(origBitsRequired))
      return failure();

    // TODO: We could relax this check by disregarding bitwidth requirements of
    // elements that we know will be replaced by the insertion.
    FailureOr<unsigned> destBitsRequired =
        calculateBitsRequired(op.getDest(), insValue.getKind());
    if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
@@ -352,12 +362,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
    // both the source and the destination values.
    unsigned newInsertionBits =
        std::max(*destBitsRequired, *insertedBitsRequired);
    FailureOr<Type> newVecTy = getNarrowType(newInsertionBits, op.getType());
    FailureOr<Type> newVecTy =
        this->getNarrowType(newInsertionBits, op.getType());
    if (failed(newVecTy) || *newVecTy == op.getType())
      return failure();

    FailureOr<Type> newInsertedValueTy =
        getNarrowType(newInsertionBits, insValue.getType());
        this->getNarrowType(newInsertionBits, insValue.getType());
    if (failed(newInsertedValueTy))
      return failure();

@@ -366,8 +377,47 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
        loc, *newInsertedValueTy, insValue.getResult());
    Value narrowDest =
        rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
    return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
                                             op.getPosition());
    return createInsertionOp(rewriter, op, narrowValue, narrowDest);
  }
};

/// Instantiation of the insertion narrowing pattern for `vector.insert`.
struct ExtensionOverInsert final
    : ExtensionOverInsertionPattern<vector::InsertOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  /// Creates a new `vector.insert` that inserts `narrowValue` into
  /// `narrowDest`, reusing the location and position of `origInsert`.
  vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
                                     vector::InsertOp origInsert,
                                     Value narrowValue,
                                     Value narrowDest) const override {
    auto loc = origInsert.getLoc();
    auto position = origInsert.getPosition();
    return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
                                             position);
  }
};

/// Instantiation of the insertion narrowing pattern for
/// `vector.insertelement`.
struct ExtensionOverInsertElement final
    : ExtensionOverInsertionPattern<vector::InsertElementOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  /// Creates a new `vector.insertelement` that inserts `narrowValue` into
  /// `narrowDest`, reusing the location and position of `origInsert`.
  vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
                                            vector::InsertElementOp origInsert,
                                            Value narrowValue,
                                            Value narrowDest) const override {
    auto loc = origInsert.getLoc();
    auto position = origInsert.getPosition();
    return rewriter.create<vector::InsertElementOp>(loc, narrowValue,
                                                    narrowDest, position);
  }
};

/// Instantiation of the insertion narrowing pattern for
/// `vector.insert_strided_slice`.
struct ExtensionOverInsertStridedSlice final
    : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  /// Creates a new `vector.insert_strided_slice` that inserts `narrowValue`
  /// into `narrowDest`, reusing the location, offsets, and strides of
  /// `origInsert`.
  vector::InsertStridedSliceOp
  createInsertionOp(PatternRewriter &rewriter,
                    vector::InsertStridedSliceOp origInsert, Value narrowValue,
                    Value narrowDest) const override {
    auto loc = origInsert.getLoc();
    auto offsets = origInsert.getOffsets();
    auto strides = origInsert.getStrides();
    return rewriter.create<vector::InsertStridedSliceOp>(
        loc, narrowValue, narrowDest, offsets, strides);
  }
};

@@ -400,7 +450,8 @@ void populateArithIntNarrowingPatterns(
  // Add commute patterns with a higher benefit. This is to expose more
  // optimization opportunities to narrowing patterns.
  patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
               ExtensionOverExtractStridedSlice, ExtensionOverInsert>(
               ExtensionOverExtractStridedSlice, ExtensionOverInsert,
               ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
      patterns.getContext(), options, PatternBenefit(2));

  patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
+114 −0
Original line number Diff line number Diff line
@@ -328,3 +328,117 @@ func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
  return %e : vector<3xi32>
}

// Both operands of `vector.insertelement` are `arith.extsi` results widened
// from i16. The expected output performs the insertion on the i16 values and
// sign-extends the inserted-into vector once.
// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16
// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
func.func @extsi_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
  %d = arith.extsi %b : i16 to i32
  %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
  return %e : vector<3xi32>
}

// Both operands of `vector.insertelement` are `arith.extui` results widened
// from i16. The expected output performs the insertion on the i16 values and
// zero-extends the inserted-into vector once.
// CHECK-LABEL: func.func @extui_over_insertelement_3xi16
// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
func.func @extui_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
  %c = arith.extui %a : vector<3xi16> to vector<3xi32>
  %d = arith.extui %b : i16 to i32
  %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
  return %e : vector<3xi32>
}

// The destination is an i32 vector constant and the inserted value is
// sign-extended from i8. The expected output narrows the constant to
// vector<3xi16>, truncates the inserted value to i16, inserts in i16, and
// sign-extends the result back to vector<3xi32>.
// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16_cst_i16
// CHECK-SAME:    (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
// CHECK-NEXT:    %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32
// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
func.func @extsi_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
  %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
  %d = arith.extsi %a : i8 to i32
  %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
  return %e : vector<3xi32>
}

// The destination is an i32 vector constant and the inserted value is
// zero-extended from i8. The expected output narrows the constant to
// vector<3xi16>, truncates the inserted value to i16, inserts in i16, and
// zero-extends the result back to vector<3xi32>.
// CHECK-LABEL: func.func @extui_over_insertelement_3xi16_cst_i16
// CHECK-SAME:    (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[1, 256, 0]> : vector<3xi16>
// CHECK-NEXT:    %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
func.func @extui_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
  %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
  %d = arith.extui %a : i8 to i32
  %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
  return %e : vector<3xi32>
}

// Both operands of the 1-D `vector.insert_strided_slice` are `arith.extsi`
// results widened from i16. The expected output performs the insertion on the
// i16 vectors with the same offsets and strides, then sign-extends once.
// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_1d
// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
// CHECK-SAME:                    {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
func.func @extsi_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
  %d = arith.extsi %b : vector<2xi16> to vector<2xi32>
  %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
  return %e : vector<3xi32>
}

// Both operands of the 1-D `vector.insert_strided_slice` are `arith.extui`
// results widened from i16. The expected output performs the insertion on the
// i16 vectors with the same offsets and strides, then zero-extends once.
// CHECK-LABEL: func.func @extui_over_insert_strided_slice_1d
// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
// CHECK-SAME:                    {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
func.func @extui_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
  %c = arith.extui %a : vector<3xi16> to vector<3xi32>
  %d = arith.extui %b : vector<2xi16> to vector<2xi32>
  %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
  return %e : vector<3xi32>
}

// 2-D case: the destination is an i32 vector constant and the inserted slice
// is sign-extended from i8. The expected output narrows the constant to
// vector<2x3xi16>, truncates the slice to i16, inserts in i16 with the same
// offsets and strides, and sign-extends the result back to vector<2x3xi32>.
// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_cst_2d
// CHECK-SAME:    (%[[ARG:.+]]: vector<1x2xi8>)
// CHECK-NEXT:    %[[CST:.+]]  = arith.constant
// CHECK-SAME{LITERAL}:            dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi16>
// CHECK-NEXT:    %[[SRCE:.+]] = arith.extsi %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
// CHECK-SAME:                    {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
func.func @extsi_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
  %cst = arith.constant dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi32>
  %d = arith.extsi %a : vector<1x2xi8> to vector<1x2xi32>
  %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
  return %e : vector<2x3xi32>
}

// 2-D case: the destination is an i32 vector constant and the inserted slice
// is zero-extended from i8. The expected output narrows the constant to
// vector<2x3xi16>, truncates the slice to i16, inserts in i16 with the same
// offsets and strides, and zero-extends the result back to vector<2x3xi32>.
// CHECK-LABEL: func.func @extui_over_insert_strided_slice_cst_2d
// CHECK-SAME:    (%[[ARG:.+]]: vector<1x2xi8>)
// CHECK-NEXT:    %[[CST:.+]]  = arith.constant
// CHECK-SAME{LITERAL}:            dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi16>
// CHECK-NEXT:    %[[SRCE:.+]] = arith.extui %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
// CHECK-SAME:                    {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
  %cst = arith.constant dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi32>
  %d = arith.extui %a : vector<1x2xi8> to vector<1x2xi32>
  %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
  return %e : vector<2x3xi32>
}