Commit 49df0688 authored by Jakub Kuderski's avatar Jakub Kuderski
Browse files

[mlir][arith][NFC] Simplify narrowing patterns with a wrapper type

Add a new wraper type that represents either of `ExtSIOp` or `ExtUIOp`.
This is to simplify the code by using a single type, so that we do not
have to use templates or branching to handle both extension kinds.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149485
parent f7627985
Loading
Loading
Loading
Loading
+97 −67
Original line number Diff line number Diff line
@@ -15,13 +15,13 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <cstdint>

@@ -100,11 +100,63 @@ FailureOr<unsigned> calculateBitsRequired(Type type) {

enum class ExtensionKind { Sign, Zero };

ExtensionKind getExtensionKind(Operation *op) {
/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
/// the exact op type. Exposes helper functions to query the types, operands,
/// and the result. This is so that we can handle both extension kinds without
/// needing to use templates or branching.
class ExtensionOp {
public:
  /// Attemps to create a new extension op from `op`. Returns an extension op
  /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
  /// otherwise.
  static FailureOr<ExtensionOp> from(Operation *op) {
    if (auto sext = dyn_cast_or_null<arith::ExtSIOp>(op))
      return ExtensionOp{op, ExtensionKind::Sign};
    if (auto zext = dyn_cast_or_null<arith::ExtUIOp>(op))
      return ExtensionOp{op, ExtensionKind::Zero};

    return failure();
  }

  ExtensionOp(const ExtensionOp &) = default;
  ExtensionOp &operator=(const ExtensionOp &) = default;

  /// Creates a new extension op of the same kind.
  Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
                      Value in) {
    if (kind == ExtensionKind::Sign)
      return rewriter.create<arith::ExtSIOp>(loc, newType, in);

    return rewriter.create<arith::ExtUIOp>(loc, newType, in);
  }

  /// Replaces `toReplace` with a new extension op of the same kind.
  void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
                          Value in) {
    assert(toReplace->getNumResults() == 1);
    Type newType = toReplace->getResult(0).getType();
    Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
    rewriter.replaceOp(toReplace, newOp->getResult(0));
  }

  ExtensionKind getKind() { return kind; }

  Value getResult() { return op->getResult(0); }
  Value getIn() { return op->getOperand(0); }

  Type getType() { return getResult().getType(); }
  Type getElementType() { return getElementTypeOrSelf(getType()); }
  Type getInType() { return getIn().getType(); }
  Type getInElementType() { return getElementTypeOrSelf(getInType()); }

private:
  ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
    assert(op);
    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
  return isa<arith::ExtSIOp>(op) ? ExtensionKind::Sign : ExtensionKind::Zero;
  }
  Operation *op = nullptr;
  ExtensionKind kind = {};
};

/// Returns the integer bitwidth required to represent `value`.
unsigned calculateBitsRequired(const APInt &value,
@@ -202,19 +254,15 @@ struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *def = op.getVector().getDefiningOp();
    if (!def)
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(def)
        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
    Value newExtract = rewriter.create<vector::ExtractOp>(
              op.getLoc(), extOp.getIn(), op.getPosition());
          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
                                                       newExtract);
        op.getLoc(), ext->getIn(), op.getPosition());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
        })
        .Default(failure());
  }
};

@@ -224,19 +272,15 @@ struct ExtensionOverExtractElement final

  LogicalResult matchAndRewrite(vector::ExtractElementOp op,
                                PatternRewriter &rewriter) const override {
    Operation *def = op.getVector().getDefiningOp();
    if (!def)
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(def)
        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
    Value newExtract = rewriter.create<vector::ExtractElementOp>(
              op.getLoc(), extOp.getIn(), op.getPosition());
          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
                                                       newExtract);
        op.getLoc(), ext->getIn(), op.getPosition());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
        })
        .Default(failure());
  }
};

@@ -246,24 +290,19 @@ struct ExtensionOverExtractStridedSlice final

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    Operation *def = op.getVector().getDefiningOp();
    if (!def)
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(def)
        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
    VectorType origTy = op.getType();
          Type inElemTy =
              cast<VectorType>(extOp.getIn().getType()).getElementType();
          VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy);
    VectorType extractTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
              op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(),
              op.getSizes(), op.getStrides());
          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
                                                       newExtract);
        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
        op.getStrides());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
        })
        .Default(failure());
  }
};

@@ -272,30 +311,22 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {

  LogicalResult matchAndRewrite(vector::InsertOp op,
                                PatternRewriter &rewriter) const override {
    Operation *def = op.getSource().getDefiningOp();
    if (!def)
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(def)
        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
          // Rewrite the insertion in terms of narrower operands
          // and later extend the result to the original bitwidth.
    FailureOr<vector::InsertOp> newInsert =
              createNarrowInsert(op, rewriter, extOp);
        createNarrowInsert(op, rewriter, *ext);
    if (failed(newInsert))
      return failure();
          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
                                                       *newInsert);
    ext->recreateAndReplace(rewriter, op, *newInsert);
    return success();
        })
        .Default(failure());
  }

  FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
                                                 PatternRewriter &rewriter,
                                                 Operation *insValue) const {
    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(insValue)));

                                                 ExtensionOp insValue) const {
    // Calculate the operand and result bitwidths. We can only apply narrowing
    // when the inserted source value and destination vector require fewer bits
    // than the result. Because the source and destination may have different
@@ -306,14 +337,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
    if (failed(origBitsRequired))
      return failure();

    ExtensionKind kind = getExtensionKind(insValue);
    FailureOr<unsigned> destBitsRequired =
        calculateBitsRequired(op.getDest(), kind);
        calculateBitsRequired(op.getDest(), insValue.getKind());
    if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
      return failure();

    FailureOr<unsigned> insertedBitsRequired =
        calculateBitsRequired(insValue->getOperands().front(), kind);
        calculateBitsRequired(insValue.getIn(), insValue.getKind());
    if (failed(insertedBitsRequired) ||
        *insertedBitsRequired >= *origBitsRequired)
      return failure();
@@ -327,13 +357,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
      return failure();

    FailureOr<Type> newInsertedValueTy =
        getNarrowType(newInsertionBits, insValue->getResultTypes().front());
        getNarrowType(newInsertionBits, insValue.getType());
    if (failed(newInsertedValueTy))
      return failure();

    Location loc = op.getLoc();
    Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
        loc, *newInsertedValueTy, insValue->getResult(0));
        loc, *newInsertedValueTy, insValue.getResult());
    Value narrowDest =
        rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
    return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,