Commit 499ad458 authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][VectorOps] Expose and use llvm.intrin.fma*

Summary:
This revision exposes the portable `llvm.fma` intrinsic in LLVMOps and uses it
in lieu of `llvm.fmuladd` when lowering the `vector.outerproduct` op to LLVM.
This guarantees proper `fma` instructions will be emitted if the target ISA
supports it.

`llvm.fmuladd` does not have this guarantee in its semantics, despite evidence
that the proper x86 instructions are emitted.

For more details, see https://llvm.org/docs/LangRef.html#llvm-fmuladd-intrinsic.

Reviewers: ftynse, aartbik, dcaballe, fhahn

Reviewed By: aartbik

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74219
parent e8e05de0
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -732,6 +732,7 @@ def LLVM_FAbsOp : LLVM_UnaryIntrinsicOp<"fabs">;
def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">;
def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
def LLVM_FMAOp : LLVM_TernarySameArgsIntrinsicOp<"fma">;
def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">;
def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;

+9 −5
Original line number Diff line number Diff line
@@ -567,21 +567,25 @@ def Vector_OuterProductOp :
    Results<(outs AnyVector)> {
  let summary = "vector outerproduct with optional fused add";
  let description = [{
    Takes 2 1-D vectors and returns the 2-D vector containing the outer product.
    Takes 2 1-D vectors and returns the 2-D vector containing the outer-product.

    An optional extra 2-D vector argument may be specified in which case the
    operation returns the sum of the outer product and the extra vector. When
    lowered to the LLVMIR dialect, this form emits `llvm.intr.fmuladd`, which
    can lower to actual `fma` instructions in LLVM.
    operation returns the sum of the outer-product and the extra vector. In this
    multiply-accumulate scenario, the rounding mode is that obtained by
    guaranteeing that a fused-multiply add operation is emitted. When lowered to
    the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to
    lower to actual `fma` instructions on x86.

    Examples
    Examples:

    ```
      %2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32>
      return %2: vector<4x8xf32>

      %3 = vector.outerproduct %0, %1, %2:
        vector<4xf32>, vector<8xf32>, vector<4x8xf32>
      return %3: vector<4x8xf32>
    ```
  }];
  let extraClassDeclaration = [{
    VectorType getOperandVectorTypeLHS() {
+3 −3
Original line number Diff line number Diff line
@@ -674,8 +674,8 @@ public:
            loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
      // 3. Compute aD outer b (plus accD, if relevant).
      Value aOuterbD =
          accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
                     .getResult()
          accD
              ? rewriter.create<LLVM::FMAOp>(loc, vRHS, aD, b, accD).getResult()
              : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
      // 4. Insert as value `d` in the descriptor.
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
+2 −2
Original line number Diff line number Diff line
@@ -222,11 +222,11 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
//       CHECK:   llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
//       CHECK:   llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
//       CHECK:   llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
//       CHECK:   "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
//       CHECK:   "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
//       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
//       CHECK:   llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
//       CHECK:   llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
//       CHECK:   "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
//       CHECK:   "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
//       CHECK:   llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
//       CHECK:   llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">

+1 −1
Original line number Diff line number Diff line
@@ -171,7 +171,7 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
//   LLVM-LOOPS: llvm.shufflevector {{.*}} [2 : i32, 2 : i32, 2 : i32, 2 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
//   LLVM-LOOPS: llvm.shufflevector {{.*}} [3 : i32, 3 : i32, 3 : i32, 3 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
//   LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
//   LLVM-LOOPS-NEXT: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
//   LLVM-LOOPS-NEXT: "llvm.intr.fma"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
//   LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">


Loading