Commit 8513ff05 authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][VectorOps][EDSC] Add EDSC for VectorOps

Summary:
This revision adds EDSC support for VectorOps to enable the creation of a `vector_matmul` declaratively. The `vector_matmul` is a simple configuration
 of the `vector.contract` op that follows the StructuredOps abstraction.

Differential Revision: https://reviews.llvm.org/D74284
parent 62ce7e65
Loading
Loading
Loading
Loading
+53 −0
Original line number Diff line number Diff line
//===- Builders.h - MLIR Declarative Vector Builders ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides intuitive composable interfaces for building structured MLIR
// snippets in a declarative fashion.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
#define MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"

namespace mlir {
namespace edsc {
namespace ops {

/// Build a generic vector contraction, that is a `vector.contract` op with
/// specified `iteratorTypes`. The client is responsible for specifying proper
/// indexings when creating the StructuredIndexed.
/// The computation represents a notional (A * B + C) where indexings specify
/// which dimensions are reduced and reordered.
/// Return the result of the `vector.contract` op
///
/// Prerequisites:
/// A, B and C capture values of proper vector types, and indexing expressions
/// that match semantics of the `vector.contract` op.
Value vector_contraction(StructuredIndexed A, StructuredIndexed B,
                         StructuredIndexed C,
                         ArrayRef<IteratorType> iteratorTypes);

/// Build a generic vector contraction that computes a matmul on vectors.
/// Return the result of C(i, j) + sum_k {A(i, k) * B(k, j)} on vectors.
///
/// Prerequisites:
/// A, B and C capture values of proper vector types. For instance
/// `A: vector<4x8xf32>`, `B: vector<8x16f32>` and `C: vector<4x16xf32>`.
Value vector_matmul(Value A, Value B, Value C);

} // namespace ops
} // namespace edsc
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
+23 −0
Original line number Diff line number Diff line
//===- Intrinsics.h - MLIR EDSC Intrinsics for VectorOps --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
#define MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_

#include "mlir/Dialect/VectorOps/EDSC/Builders.h"

namespace mlir {
namespace edsc {
namespace intrinsics {

using vector_contract = ValueBuilder<vector::ContractionOp>;

} // namespace intrinsics
} // namespace edsc
} // namespace mlir

#endif // MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
+5 −1
Original line number Diff line number Diff line
@@ -141,7 +141,11 @@ def Vector_ContractionOp :
  }];
  let builders = [OpBuilder<
    "Builder *builder, OperationState &result, Value lhs, Value rhs, "
    "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
    "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">,
    OpBuilder<
      "Builder *builder, OperationState &result, Value lhs, Value rhs, "
      "Value acc, ArrayRef<ArrayRef<AffineExpr>> indexingExprs, "
      "ArrayRef<StringRef> iteratorTypes">];
  let extraClassDeclaration = [{
    VectorType getLhsType() {
      return lhs().getType().cast<VectorType>();
+3 −2
Original line number Diff line number Diff line
@@ -436,8 +436,9 @@ struct StructuredIndexed : public ValueHandle {
  StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
      : ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
    assert((v.getType().isa<MemRefType>() ||
            v.getType().isa<RankedTensorType>()) &&
           "MemRef or RankedTensor expected");
            v.getType().isa<RankedTensorType>() ||
            v.getType().isa<VectorType>()) &&
           "MemRef, RankedTensor or Vector expected");
  }
  StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
      : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
+8 −0
Original line number Diff line number Diff line
@@ -63,6 +63,14 @@ public:
  static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                     MLIRContext *context);

  /// Returns a vector of AffineMaps; each with as many results as
  /// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
  /// symbols as the largest symbol in `exprs`.
  static SmallVector<AffineMap, 4>
  inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList);
  static SmallVector<AffineMap, 4>
  inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList);

  MLIRContext *getContext() const;

  explicit operator bool() { return map != nullptr; }
Loading