Unverified Commit 07707707 authored by Krzysztof Drewniak's avatar Krzysztof Drewniak Committed by GitHub
Browse files

[mlir][Vector] Add load, store, etc. to dropleadunitdim (#195686)

Discussions on improvements to fold-memref-alias-ops changes revealed
that the patterns meant to drop leading unit dimensions from vector
operations weren't handling load, store, and other "terminal" vector
dialect operations. This PR adds the patterns to fix that.

Assisted-by: Claude 4.7
parent a6470d6d
Loading
Loading
Loading
Loading
+105 −1
Original line number Diff line number Diff line
@@ -537,6 +537,101 @@ public:
    return success();
  }
};
} // namespace

// Strips the first `nDropped` dimensions from the vector-typed `operand`.
//
// When every one of those dimensions is a fixed-size (non-scalable) unit
// dimension, the result is a vector.extract at the all-zero position, which is
// the cheap, structural form. If any of them is not 1 or is scalable, the
// result is instead a vector.shape_cast to the vector type with those leading
// dimensions removed.
static Value dropLeadingOneDimsFromOperand(OpBuilder &b, Location loc,
                                           Value operand, int64_t nDropped) {
  auto vecType = cast<VectorType>(operand.getType());
  ArrayRef<int64_t> shape = vecType.getShape();
  ArrayRef<bool> scalableDims = vecType.getScalableDims();

  // A non-unit or scalable leading dimension rules out vector.extract.
  bool needsShapeCast =
      llvm::any_of(shape.take_front(nDropped),
                   [](int64_t size) { return size != 1; }) ||
      llvm::any_of(scalableDims.take_front(nDropped),
                   [](bool isScalable) { return isScalable; });
  if (needsShapeCast) {
    VectorType trimmedType =
        VectorType::get(shape.drop_front(nDropped), vecType.getElementType(),
                        scalableDims.drop_front(nDropped));
    return vector::ShapeCastOp::create(b, loc, trimmedType, operand);
  }

  return vector::ExtractOp::create(b, loc, operand, splatZero(nDropped));
}

namespace {

// Drops leading 1 dimensions from load-like memory operations. Removes leading
// unit dimensions from the result type and then broadcasts those 1s back in,
// while also extracting (or shape_cast-ing) the same number of leading
// dimensions from every vector-typed operand. Non-vector operands are
// forwarded unchanged, and all attributes of the original op are preserved.
//
// The pattern does not apply when a non-vector operand (i.e. the buffer being
// read) has a vector element type: the loaded vector type is then tied to the
// buffer's element type, so trimming only the result would produce an op that
// no longer verifies.
template <typename OpTy>
struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    VectorType oldResultType = op.getVectorType();
    VectorType newResultType = trimLeadingOneDims(oldResultType);
    // Nothing to trim: leave the op alone.
    if (newResultType == oldResultType)
      return failure();

    // Bail out if the accessed buffer stores vectors as elements; the result
    // type cannot be changed independently of the buffer's element type.
    bool hasVectorElementBase =
        llvm::any_of(op->getOperandTypes(), [](Type type) {
          if (isa<VectorType>(type))
            return false;
          auto shapedType = dyn_cast<ShapedType>(type);
          return shapedType && isa<VectorType>(shapedType.getElementType());
        });
    if (hasVectorElementBase)
      return rewriter.notifyMatchFailure(
          op, "base has vector element type; result type cannot be trimmed");

    int64_t nDropped = oldResultType.getRank() - newResultType.getRank();

    // Trim the same leading dimensions from every vector operand; scalars and
    // the base buffer pass through untouched.
    Location loc = op.getLoc();
    SmallVector<Value> newOperands;
    newOperands.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      if (isa<VectorType>(operand.getType())) {
        newOperands.push_back(
            dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped));
      } else {
        newOperands.push_back(operand);
      }
    }

    // Re-create the same kind of op on the trimmed types, then broadcast the
    // lower-rank result back to the original vector type for existing users.
    Operation *newOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        TypeRange{newResultType}, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, oldResultType,
                                                     newOp->getResult(0));
    return success();
  }
};

// Drops leading 1 dimensions from store-like memory ops. Extracts or
// `shape_cast`s away those leading unit dimensions from every vector-typed
// operand and leaves all other operands (scalars, the base buffer) alone. The
// replacement op keeps the original op's result types and attributes.
//
// The pattern does not apply when a non-vector operand (i.e. the buffer being
// written) has a vector element type: the stored vector type is then tied to
// the buffer's element type, so trimming the stored value would produce an op
// that no longer verifies.
template <typename OpTy>
struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    VectorType oldVecType = op.getVectorType();
    VectorType newVecType = trimLeadingOneDims(oldVecType);
    // Nothing to trim: leave the op alone.
    if (newVecType == oldVecType)
      return failure();

    // Bail out if the accessed buffer stores vectors as elements; the stored
    // vector type cannot be changed independently of that element type.
    bool hasVectorElementBase =
        llvm::any_of(op->getOperandTypes(), [](Type type) {
          if (isa<VectorType>(type))
            return false;
          auto shapedType = dyn_cast<ShapedType>(type);
          return shapedType && isa<VectorType>(shapedType.getElementType());
        });
    if (hasVectorElementBase)
      return rewriter.notifyMatchFailure(
          op, "base has vector element type; stored type cannot be trimmed");

    int64_t nDropped = oldVecType.getRank() - newVecType.getRank();

    // Trim the same leading dimensions from every vector operand.
    Location loc = op.getLoc();
    SmallVector<Value> newOperands;
    newOperands.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      if (isa<VectorType>(operand.getType())) {
        newOperands.push_back(
            dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped));
      } else {
        newOperands.push_back(operand);
      }
    }

    // Re-create the same kind of op on the trimmed operands and forward
    // whatever results it has to the users of the original op.
    Operation *newOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        op->getResultTypes(), op->getAttrs());
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};

// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
@@ -578,5 +673,14 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
           CastAwayContractionLeadingOneDim,
           CastAwayLoadLikeLeadingOneDim<vector::LoadOp>,
           CastAwayLoadLikeLeadingOneDim<vector::MaskedLoadOp>,
           CastAwayLoadLikeLeadingOneDim<vector::ExpandLoadOp>,
           CastAwayLoadLikeLeadingOneDim<vector::GatherOp>,
           CastAwayStoreLikeLeadingOneDim<vector::StoreOp>,
           CastAwayStoreLikeLeadingOneDim<vector::MaskedStoreOp>,
           CastAwayStoreLikeLeadingOneDim<vector::CompressStoreOp>,
           CastAwayStoreLikeLeadingOneDim<vector::ScatterOp>>(
          patterns.getContext(), benefit);
}
+95 −0
Original line number Diff line number Diff line
@@ -693,3 +693,98 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
  %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
  return %sel : vector<1x16xi1>
}

// -----

// A vector.load of vector<1x4xf32> should become a load of vector<4xf32>
// whose result is broadcast back to vector<1x4xf32>.
// CHECK-LABEL: func.func @cast_away_load_leading_one_dims
// CHECK:         %[[L:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32>
// CHECK:         %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
// CHECK:         return %[[B]] : vector<1x4xf32>
func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, %j: index) -> vector<1x4xf32> {
  %0 = vector.load %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32>
  return %0 : vector<1x4xf32>
}

// -----

// A vector.maskedload on 1x4 vectors should have its mask and pass-through
// operands extracted down to rank 1, with the loaded vector<4xf32> broadcast
// back to vector<1x4xf32>.
// CHECK-LABEL: func.func @cast_away_maskedload_leading_one_dims
// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK:         %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK:         %[[L:.+]] = vector.maskedload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK:         %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
// CHECK:         return %[[B]] : vector<1x4xf32>
func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
  %0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
  return %0 : vector<1x4xf32>
}

// -----

// A vector.expandload on 1x4 vectors should have its mask and pass-through
// operands extracted down to rank 1, with the loaded vector<4xf32> broadcast
// back to vector<1x4xf32>.
// CHECK-LABEL: func.func @cast_away_expandload_leading_one_dims
// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK:         %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK:         %[[L:.+]] = vector.expandload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK:         %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
// CHECK:         return %[[B]] : vector<1x4xf32>
func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
  %0 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
  return %0 : vector<1x4xf32>
}

// -----

// A vector.gather on 1x4 vectors should have its index vector, mask, and
// pass-through operands extracted down to rank 1, with the gathered
// vector<4xf32> broadcast back to vector<1x4xf32>.
// CHECK-LABEL: func.func @cast_away_gather_leading_one_dims
// CHECK:         %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32>
// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK:         %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK:         %[[G:.+]] = vector.gather %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[P]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK:         %[[B:.+]] = vector.broadcast %[[G]] : vector<4xf32> to vector<1x4xf32>
// CHECK:         return %[[B]] : vector<1x4xf32>
func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
  %0 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
  return %0 : vector<1x4xf32>
}

// -----

// A vector.store of vector<1x4xf32> should become a store of the
// vector<4xf32> extracted from the original value.
// CHECK-LABEL: func.func @cast_away_store_leading_one_dims
// CHECK:         %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK:         vector.store %[[V]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32>
func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref<8x16xf32>, %i: index, %j: index) {
  vector.store %val, %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32>
  return
}

// -----

// A vector.maskedstore on 1x4 vectors should have its mask and stored value
// extracted down to rank 1 before the store.
// CHECK-LABEL: func.func @cast_away_maskedstore_leading_one_dims
// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK:         %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK:         vector.maskedstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32>
func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
  vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32>
  return
}

// -----

// A vector.compressstore on 1x4 vectors should have its mask and stored value
// extracted down to rank 1 before the store.
// CHECK-LABEL: func.func @cast_away_compressstore_leading_one_dims
// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK:         %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK:         vector.compressstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32>
func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
  vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32>
  return
}

// -----

// A vector.scatter on 1x4 vectors should have its index vector, mask, and
// stored value extracted down to rank 1 before the scatter.
// CHECK-LABEL: func.func @cast_away_scatter_leading_one_dims
// CHECK:         %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32>
// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK:         %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK:         vector.scatter %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[V]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
func.func @cast_away_scatter_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
  vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32>
  return
}