Commit 137415ad authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][EDSC][Linalg] Compose linalg_matmul and vector.contract

Summary:
This revision allows the model builder to create a linalg_matmul whose body
is a vector.contract. This shows that the abstractions compose nicely.

Differential Revision: https://reviews.llvm.org/D74457
parent 5ed15ff6
Loading
Loading
Loading
Loading
+21 −7
Original line number Diff line number Diff line
@@ -18,6 +18,9 @@
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"

namespace mlir {
class AffineForOp;
@@ -127,8 +130,12 @@ using edsc::ValueHandle;
// EDSC builders for linalg generic operations.
//===----------------------------------------------------------------------===//

/// Build the body of a region to compute a multiply-accumulate, under the
/// current ScopedContext, at the current insert point.
/// Build the body of a region to compute a scalar multiply, under the current
/// ScopedContext, at the current insert point.
void mulRegionBuilder(ArrayRef<BlockArgument> args);

/// Build the body of a region to compute a scalar multiply-accumulate, under
/// the current ScopedContext, at the current insert point.
void macRegionBuilder(ArrayRef<BlockArgument> args);

/// TODO(ntv): In the future we should tie these implementations to something in
@@ -182,6 +189,8 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,

// TODO(ntv): Implement more useful pointwise operations on a per-need basis.

using MatmulRegionBuilder = function_ref<void(ArrayRef<BlockArgument> args)>;

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
@@ -189,7 +198,8 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
///    |
///    |  C(m, n) += A(m, k) * B(k, n)
/// ```
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
                         MatmulRegionBuilder regionBuilder = macRegionBuilder);

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
@@ -199,7 +209,8 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
///    |  C(m, n) = sum_k(A(m, k) * B(k, n))
/// ```
/// and returns the tensor `C`.
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC);
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
                         MatmulRegionBuilder regionBuilder = mulRegionBuilder);

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
@@ -210,11 +221,14 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC);
/// ```
/// and returns the tensor `D`.
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
                         RankedTensorType tD);
                         RankedTensorType tD,
                         MatmulRegionBuilder regionBuilder = macRegionBuilder);

template <typename Container> Operation *linalg_matmul(Container values) {
template <typename Container>
Operation *linalg_matmul(Container values,
                         MatmulRegionBuilder regionBuilder = macRegionBuilder) {
  assert(values.size() == 3 && "Expected exactly 3 values");
  return linalg_matmul(values[0], values[1], values[2]);
  return linalg_matmul(values[0], values[1], values[2], regionBuilder);
}

/// Build a linalg.generic, under the current ScopedContext, at the current
+17 −6
Original line number Diff line number Diff line
@@ -212,6 +212,14 @@ static void mulRegionBuilder(ArrayRef<BlockArgument> args) {
  linalg_yield((a * b).getValue());
}

void mlir::edsc::ops::mulRegionBuilder(ArrayRef<BlockArgument> args) {
  using edsc::op::operator+;
  using edsc::op::operator*;
  assert(args.size() == 2 && "expected 2 block arguments");
  ValueHandle a(args[0]), b(args[1]);
  linalg_yield((a * b).getValue());
}

void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
  using edsc::op::operator+;
  using edsc::op::operator*;
@@ -291,7 +299,8 @@ Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
}

Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
                                          ValueHandle vC) {
                                          ValueHandle vC,
                                          MatmulRegionBuilder regionBuilder) {
  // clang-format off
  AffineExpr m, n, k;
  bindDims(ScopedContext::getContext(), m, n, k);
@@ -300,12 +309,13 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
    {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
    {A({m, k}), B({k, n})},
    {C({m, n})},
    macRegionBuilder);
    regionBuilder);
  // clang-format on
}

Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
                                          RankedTensorType tC) {
                                          RankedTensorType tC,
                                          MatmulRegionBuilder regionBuilder) {
  // clang-format off
  AffineExpr m, n, k;
  bindDims(ScopedContext::getContext(), m, n, k);
@@ -314,12 +324,13 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
    {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
    {A({m, k}), B({k, n})},
    {C({m, n})},
    mulRegionBuilder);
    regionBuilder);
  // clang-format on
}

Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
                                          ValueHandle vC, RankedTensorType tD) {
                                          ValueHandle vC, RankedTensorType tD,
                                          MatmulRegionBuilder regionBuilder) {
  // clang-format off
  AffineExpr m, n, k;
  bindDims(ScopedContext::getContext(), m, n, k);
@@ -328,7 +339,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
    {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
    {A({m, k}), B({k, n}), C({m, n})},
    {D({m, n})},
    macRegionBuilder);
    regionBuilder);
  // clang-format on
}

+30 −14
Original line number Diff line number Diff line
@@ -990,17 +990,20 @@ TEST_FUNC(linalg_tensors_test) {
  f.erase();
}

// CHECK-LABEL: func @vector_matmul_test(
//  CHECK-SAME:   %[[A:.*]]: vector<4x16xf32>,
//  CHECK-SAME:   %[[B:.*]]: vector<16x8xf32>,
//  CHECK-SAME:   %[[C:.*]]: vector<4x8xf32>)
//  CHECK:   vector.contract {{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
// CHECK-LABEL: func @memref_vector_matmul_test(
//  CHECK-SAME:   %[[A:.*]]: memref<?x?xvector<4x16xf32>>,
//  CHECK-SAME:   %[[B:.*]]: memref<?x?xvector<16x8xf32>>,
//  CHECK-SAME:   %[[C:.*]]: memref<?x?xvector<4x8xf32>>)
//       CHECK:   linalg.generic {{.*}} %[[A]], %[[B]], %[[C]]
//       CHECK:     vector.contract{{.*}}[affine_map<(d0, d1, d2) -> (d0,
//  d2)>,
//  CHECK-SAME:                       affine_map<(d0, d1, d2) -> (d2, d1)>,
//  CHECK-SAME:                       affine_map<(d0, d1, d2) -> (d0, d1)>],
//  CHECK-SAME:                {{.*}}["parallel", "parallel", "reduction"]
//  CHECK-SAME: %[[A]], %[[B]], %[[C]]
//  CHECK-SAME:     vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
TEST_FUNC(vector_matmul_test) {
//       CHECK:   memref<?x?xvector<4x16xf32>>, memref<?x?xvector<16x8xf32>>,
//  CHECK-SAME:   memref<?x?xvector<4x8xf32>>
TEST_FUNC(memref_vector_matmul_test) {
  using namespace edsc;
  using namespace edsc::ops;

@@ -1009,13 +1012,26 @@ TEST_FUNC(vector_matmul_test) {
  auto mkVectorType = VectorType::get({M, K}, f32Type);
  auto knVectorType = VectorType::get({K, N}, f32Type);
  auto mnVectorType = VectorType::get({M, N}, f32Type);
  auto f = makeFunction("vector_matmul_test", {},
                        {mkVectorType, knVectorType, mnVectorType});
  auto typeA =
      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
                      mkVectorType, {}, 0);
  auto typeB =
      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
                      knVectorType, {}, 0);
  auto typeC =
      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
                      mnVectorType, {}, 0);
  auto f = makeFunction("memref_vector_matmul_test", {}, {typeA, typeB, typeC});

  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());
  ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
  vector_matmul(A, B, C);
  auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
    assert(args.size() == 3 && "expected 3 block arguments");
    (linalg_yield(vector_matmul(args[0], args[1], args[2])));
  };
  linalg_matmul(A, B, C, contractionBuilder);

  f.print(llvm::outs());
  f.erase();
}