Commit 681f929f authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][VectorOps] Introduce a `vector.fma` op that works on n-D vectors and...

[mlir][VectorOps] Introduce a `vector.fma` op that works on n-D vectors and lowers to `llvm.intrin.fmuladd`

Summary:
The `vector.fma` operation is portable enough across targets that we do not want
to keep it wrapped under `vector.outerproduct` and `llvm.intrin.fmuladd`.
This revision lifts the op into the vector dialect and implements the lowering to LLVM by using two patterns:
1. a pattern that lowers from n-D to (n-1)-D by unrolling when n > 2
2. a pattern that converts from 1-D to the proper LLVM representation

Reviewers: ftynse, stellaraccident, aartbik, dcaballe, jsetoain, tetuante

Reviewed By: aartbik

Subscribers: fhahn, dcaballe, merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74075
parent 26bf877e
Loading
Loading
Loading
Loading
+32 −0
Original line number Diff line number Diff line
@@ -388,6 +388,38 @@ def Vector_ExtractSlicesOp :
  }];
}

def Vector_FMAOp :
  Op<Vector_Dialect, "fma", [NoSideEffect,
                             AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
    Results<(outs AnyVector:$result)> {
  let summary = "vector fused multiply-add";
  let description = [{
    Multiply-add expressions operate on n-D vectors and compute a fused
    pointwise multiply-and-accumulate: `$result = `$lhs * $rhs + $acc`.
    All operands and result have the same vector type. The semantics
    of the operation correspond to those of the `llvm.fma`
    [intrinsic](https://llvm.org/docs/LangRef.html#int-fma). In the
    particular case of lowering to LLVM, this is guaranteed to lower
    to the `llvm.fma.*` intrinsic.

    Example:
    
    ```
      %3 = vector.fma %0, %1, %2: vector<8x16xf32>
    ```
  }];
  // Fully specified by traits.
  let verifier = ?;
  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)";
  let builders = [OpBuilder<
    "Builder *b, OperationState &result, Value lhs, Value rhs, Value acc",
    "build(b, result, lhs.getType(), lhs, rhs, acc);">];
  let extraClassDeclaration = [{
    VectorType getVectorType() { return lhs().getType().cast<VectorType>(); }
  }];
}

def Vector_InsertElementOp :
  Vector_Op<"insertelement", [NoSideEffect,
     PredOpTrait<"source operand and result have same element type",
+89 −4
Original line number Diff line number Diff line
@@ -410,6 +410,41 @@ public:
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fma %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public LLVMOpLowering {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::FMAOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpOperandAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return matchFailure();
    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
                                             adaptor.acc());
    return matchSuccess();
  }
};

class VectorInsertElementOpConversion : public LLVMOpLowering {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
@@ -502,6 +537,54 @@ public:
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(FMAOp op,
                                     PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return matchFailure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
@@ -969,14 +1052,16 @@ public:
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorInsertStridedSliceOpDifferentRankRewritePattern,
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                  VectorExtractElementOpConversion, VectorExtractOpConversion,
                  VectorInsertElementOpConversion, VectorInsertOpConversion,
                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
                  VectorPrintOpConversion>(ctx, converter);
                  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
                  VectorInsertOpConversion, VectorOuterProductOpConversion,
                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
      ctx, converter);
}

namespace {
+26 −0
Original line number Diff line number Diff line
@@ -637,3 +637,29 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
//      CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>">
//      CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]">
//      CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]">

// CHECK-LABEL: llvm.func @vector_fma(
//  CHECK-SAME: %[[A:.*]]: !llvm<"<8 x float>">, %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
//  CHECK-SAME: -> !llvm<"{ <8 x float>, [2 x <4 x float>] }"> {
func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
  //         CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
  //    CHECK-SAME:   (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
  %0 = vector.fma %a, %a, %a : vector<8xf32>
  
  //       CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
  //       CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
  //       CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
  //       CHECK: %[[B0:.*]] = "llvm.intr.fma"(%[[b00]], %[[b01]], %[[b02]]) :
  //  CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
  //       CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm<"[2 x <4 x float>]">
  //       CHECK: %[[b10:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
  //       CHECK: %[[b11:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
  //       CHECK: %[[b12:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
  //       CHECK: %[[B1:.*]] = "llvm.intr.fma"(%[[b10]], %[[b11]], %[[b12]]) :
  //  CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
  //       CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
  %1 = vector.fma %b, %b, %b : vector<2x4xf32>
  
  return %0, %1: vector<8xf32>, vector<2x4xf32>
}
        
+9 −0
Original line number Diff line number Diff line
@@ -268,3 +268,12 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,

  return %0, %1 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>
}

// CHECK-LABEL: @vector_fma
func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
  // CHECK: vector.fma %{{.*}} : vector<8xf32>
  vector.fma %a, %a, %a : vector<8xf32>
  // CHECK: vector.fma %{{.*}} : vector<8x4xf32>
  vector.fma %b, %b, %b : vector<8x4xf32>
  return
}