Commit e83b7b99 authored by aartbik's avatar aartbik
Browse files

[mlir] [VectorOps] Implement vector.reduce operation

Summary:
This new operation operates on 1-D vectors and
forms the bridge between vector.contract and
llvm intrinsics for vector reductions.

Reviewers: nicolasvasilache, andydavis1, ftynse

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74370
parent 0cecafd6
Loading
Loading
Loading
Loading
+33 −0
Original line number Diff line number Diff line
@@ -183,6 +183,39 @@ def Vector_ContractionOp :
  }];
}

def Vector_ReductionOp :
  Vector_Op<"reduction", [NoSideEffect,
     PredOpTrait<"source operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
    Arguments<(ins StrAttr:$kind, AnyVector:$vector)>,
    Results<(outs AnyType:$dest)> {
  let summary = "reduction operation";
  let description = [{
    Reduces an 1-D vector "horizontally" into a scalar using the given
    operation (add/mul/min/max for int/fp and and/or/xor for int only).
    Note that these operations are restricted to 1-D vectors to remain
    close to the corresponding LLVM intrinsics:

    http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics

    Examples:
    ```
      %1 = vector.reduction "add", %0 : vector<16xf32> into f32

      %3 = vector.reduction "xor", %2 : vector<4xi32> into i32
    ```
  }];
  let verifier = [{ return ::verify(*this); }];
  let assemblyFormat = [{
    $kind `,` $vector attr-dict `:` type($vector) `into` type($dest)
  }];
  let extraClassDeclaration = [{
    VectorType getVectorType() {
      return vector().getType().cast<VectorType>();
    }
  }];
}

def Vector_BroadcastOp :
  Vector_Op<"broadcast", [NoSideEffect,
     PredOpTrait<"source operand and result have same element type",
+74 −6
Original line number Diff line number Diff line
@@ -124,6 +124,7 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
}

namespace {

class VectorBroadcastOpConversion : public LLVMOpLowering {
public:
  explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -272,6 +273,73 @@ private:
  }
};

class VectorReductionOpConversion : public LLVMOpLowering {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::ReductionOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = lowering.convertType(eltType);
    if (eltType.isInteger(32) || eltType.isInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return matchFailure();
      return matchSuccess();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max
      if (kind == "add") {
        Value zero = rewriter.create<LLVM::ConstantOp>(
            op->getLoc(), llvmType, rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, zero, operands[0]);
      } else if (kind == "mul") {
        Value one = rewriter.create<LLVM::ConstantOp>(
            op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, one, operands[0]);
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return matchFailure();
      return matchSuccess();
    }
    return matchFailure();
  }
};

class VectorShuffleOpConversion : public LLVMOpLowering {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1056,12 +1124,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                  VectorExtractElementOpConversion, VectorExtractOpConversion,
                  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
                  VectorInsertOpConversion, VectorOuterProductOpConversion,
                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
      ctx, converter);
  patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
                  VectorShuffleOpConversion, VectorExtractElementOpConversion,
                  VectorExtractOpConversion, VectorFMAOp1DConversion,
                  VectorInsertElementOpConversion, VectorInsertOpConversion,
                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
                  VectorPrintOpConversion>(ctx, converter);
}

namespace {
+27 −0
Original line number Diff line number Diff line
@@ -60,6 +60,33 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReductionOp op) {
  // Verify for 1-D vector.
  int64_t rank = op.getVectorType().getRank();
  if (rank != 1)
    return op.emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  auto kind = op.kind();
  Type eltType = op.dest().getType();
  if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
    if (eltType.isF32() || eltType.isF64() || eltType.isInteger(32) ||
        eltType.isInteger(64))
      return success();
    return op.emitOpError("unsupported reduction type");
  }
  if (kind == "and" || kind == "or" || kind == "xor") {
    if (eltType.isInteger(32) || eltType.isInteger(64))
      return success();
    return op.emitOpError("unsupported reduction type");
  }
  return op.emitOpError("unknown reduction kind: ") << kind;
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
+41 −3
Original line number Diff line number Diff line
@@ -663,3 +663,41 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
  return %0, %1: vector<8xf32>, vector<2x4xf32>
}

func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
  %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
  return %0 : f32
}
// CHECK-LABEL: llvm.func @reduce_f32
// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">
//      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
//      CHECK: llvm.return %[[V]] : !llvm.float

func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
  %0 = vector.reduction "add", %arg0 : vector<16xf64> into f64
  return %0 : f64
}
// CHECK-LABEL: llvm.func @reduce_f64
// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">
//      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double
//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
//      CHECK: llvm.return %[[V]] : !llvm.double

func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
  %0 = vector.reduction "add", %arg0 : vector<16xi32> into i32
  return %0 : i32
}
// CHECK-LABEL: llvm.func @reduce_i32
// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>">
//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
//      CHECK: llvm.return %[[V]] : !llvm.i32

func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
  %0 = vector.reduction "add", %arg0 : vector<16xi64> into i64
  return %0 : i64
}
// CHECK-LABEL: llvm.func @reduce_i64
// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>">
//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
//      CHECK: llvm.return %[[V]] : !llvm.i64
+28 −0
Original line number Diff line number Diff line
@@ -990,3 +990,31 @@ func @shape_cast_different_tuple_sizes(
  %1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                                 tuple<vector<20x2xf32>>
}

// -----

func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
  // expected-error@+1 {{'vector.reduction' op unknown reduction kind: joho}}
  %0 = vector.reduction "joho", %arg0 : vector<16xf32> into f32
}

// -----

func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 {
  // expected-error@+1 {{'vector.reduction' op failed to verify that source operand and result have same element type}}
  %0 = vector.reduction "add", %arg0 : vector<16xf32> into i32
}

// -----

func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 {
  // expected-error@+1 {{'vector.reduction' op unsupported reduction type}}
  %0 = vector.reduction "xor", %arg0 : vector<16xf32> into f32
}

// -----

func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
  // expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}}
  %0 = vector.reduction "add", %arg0 : vector<4x16xf32> into f32
}
Loading