Commit 16e82d85 authored by Ahmed Taei's avatar Ahmed Taei
Browse files

[mlir] Add primitive transform pattern to rewrite linalg.fill into vector.broadcast form.

Summary:
This diff adds a transformation pattern to rewrite linalg.fill as broadcasting a scalar into a vector.
It uses the same preconditioning as matmul (memory is contiguous).

Reviewers: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73391
parent 60b88420
Loading
Loading
Loading
Loading
+33 −18
Original line number Diff line number Diff line
@@ -16,10 +16,12 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
@@ -156,8 +158,8 @@ static bool isMatmul(linalg::GenericOp genericOp) {
         genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
}

// TODO(ntv): This is in fact much more general than just vectorization for
// matmul ops.
// TODO(ntv, ataei): This is in fact much more general than just vectorization
// for matmul and fill ops.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must be static shape to go to vector.
@@ -167,7 +169,7 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();
  if (isa<linalg::MatmulOp>(op))
  if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
@@ -189,21 +191,33 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {

SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
                                                      Operation *op) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                       "]: Rewrite linalg op as vector.contract: "
                    << *op << ":\n");
  using edsc::intrinsics::std_load;
  using edsc::intrinsics::std_store;
  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
  using vector_broadcast = edsc::intrinsics::ValueBuilder<vector::BroadcastOp>;
  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;

  assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
         "DRR failure case must be a precondition");

  auto linalgOp = cast<linalg::LinalgOp>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  edsc::ScopedContext scope(rewriter, op->getLoc());
  using edsc::intrinsics::std_load;
  using edsc::intrinsics::std_store;
  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;

  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                         "]: Rewrite linalg.fill as vector.broadcast: "
                      << *op << ":\n");
    auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0));
    auto dstVec = std_load(dstMemrefVec);
    auto resVec = vector_broadcast(dstVec, fillOp.value());
    std_store(resVec, dstMemrefVec);
  } else {
    // Vectorize other ops as vector contraction (currently only matmul).
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                         "]: Rewrite linalg op as vector.contract: "
                      << *op << ":\n");
    auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
    auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
    auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
@@ -211,6 +225,7 @@ SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
    auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
                                linalgOp.iterator_types());
    std_store(vRes, vectorMemRefC);
  }
  return {};
}

+7 −0
Original line number Diff line number Diff line
@@ -205,6 +205,13 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
//       CHECK: vector.contract {{.*}} :
//                vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>

// Fills the statically shaped memref<8x16xf32> %A with the f32 scalar %arg0.
// The linalg.fill carries the "VECTORIZE" transform marker; the FileCheck
// expectations that follow look for a vector.broadcast from f32 to
// vector<8x16xf32> in the output.
func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
  linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<8x16xf32>, f32
  return
}
// CHECK-LABEL: func @test_vectorize_fill
//       CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>

func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
          %d = mulf %a, %b: f32
          %e = addf %c, %d: f32
+7 −0
Original line number Diff line number Diff line
@@ -105,6 +105,12 @@ def : Pattern<(MatmulOp:$op $_, $_, $_),
                HasLinalgTransformMarker<"VECTORIZE">,
                PreconditionVectorizeLinalgOp
               ]>>)]>;
// Matches a FillOp (both operands are ignored via `$_`) and rewrites it with
// VectorizeLinalgOp. The rewrite applies only when both constraints hold: the
// op carries the "VECTORIZE" linalg transform marker and
// PreconditionVectorizeLinalgOp is satisfied.
def : Pattern<(FillOp:$op $_, $_),
              [(VectorizeLinalgOp)],
              [(Constraint<And<[
                HasLinalgTransformMarker<"VECTORIZE">,
                PreconditionVectorizeLinalgOp
               ]>>)]>;
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
              [(VectorizeLinalgOp)],
              [(Constraint<And<[
@@ -112,6 +118,7 @@ def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
                PreconditionVectorizeLinalgOp
               ]>>)]>;


//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//