Commit 63b683a8 authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors

Summary: This op mirrors the llvm.intr counterpart and allows lowering + type conversions in a progressive fashion.

Differential Revision: https://reviews.llvm.org/D75775
parent 47caa691
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -15,6 +15,12 @@ class LLVMTypeConverter;
class ModuleOp;
template <typename T> class OpPassBase;

/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
/// will be needed when invoking LLVM.
void populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                            OwningRewritePatternList &patterns);
+2 −2
Original line number Diff line number Diff line
@@ -836,12 +836,12 @@ def LLVM_MatrixMultiplyOp
    : LLVM_OneResultOp<"intr.matrix.multiply">,
      Arguments<(
        ins LLVM_Type:$lhs, LLVM_Type:$rhs,
            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> {
            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> {
  string llvmBuilder = [{
    llvm::MatrixBuilder<decltype(builder)> mb(builder);
    $res = mb.CreateMatrixMultiply(
      $lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(),
      $rhs_rows.getZExtValue());
      $rhs_columns.getZExtValue());
  }];
  let assemblyFormat = "$lhs `,` $rhs attr-dict "
    "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
+61 −0
Original line number Diff line number Diff line
@@ -1336,4 +1336,65 @@ def Vector_PrintOp :
  let assemblyFormat = "$source attr-dict `:` type($source)";
}

//===----------------------------------------------------------------------===//
// Ops used for supporting progressive lowering and conversion type changes.
//===----------------------------------------------------------------------===//

/// Vector dialect matrix multiplication op that operates on flattened 1-D
/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
/// This may seem redundant with vector.contract but it serves the purposes of
/// more progressive lowering and localized type conversion on the path:
///   `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
        PredOpTrait<"lhs operand and result have same element type",
                    TCresVTEtIsSameAsOpBase<0, 0>>,
        PredOpTrait<"rhs operand and result have same element type",
                    TCresVTEtIsSameAsOpBase<0, 1>>]>,
      Arguments<(
        // TODO(ntv, fhahn): tighten vector element types that make sense.
        ins VectorOfRankAndType<[1],
              [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs,
            VectorOfRankAndType<[1],
              [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs,
            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
      Results<(
        outs VectorOfRankAndType<[1],
               [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)>
{
  let summary = "Vector matrix multiplication op that operates on flattened 1-D"
    " MLIR vectors";
  let description = [{
    This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
    purposes of more progressive lowering and localized type conversion.

    The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
    and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
    <rhs_columns> and multiplies them. The result matrix is returned embedded in
    the result vector.

    Example:

    ```
      %C = vector.matrix_multiply %A, %B
        { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
        (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
    ```
  }];
  let builders = [
   OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, "
             "unsigned lhsRows, unsigned lhsColumns, unsigned rhsColumns",
   [{
     result.addOperands({lhs, rhs});
     result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
     result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
     result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns));
     result.addTypes(VectorType::get(lhsRows * lhsColumns,
       lhs.getType().cast<VectorType>().getElementType()));
   }]>,
  ];
  let verifier = ?;
  let assemblyFormat = "$lhs `,` $rhs attr-dict "
    "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
}

#endif // VECTOR_OPS
+29 −0
Original line number Diff line number Diff line
@@ -275,6 +275,28 @@ private:
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpOperandAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return matchSuccess();
  }
};

class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
@@ -1141,6 +1163,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
                  VectorPrintOpConversion>(ctx, converter);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
}

namespace {
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
  void runOnModule() override;
@@ -1160,6 +1188,7 @@ void LowerVectorToLLVMPass::runOnModule() {
  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);

+12 −0
Original line number Diff line number Diff line
@@ -701,3 +701,15 @@ func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
//      CHECK: llvm.return %[[V]] : !llvm.i64


//                          4x16                16x3               4x3
func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
  %C = vector.matrix_multiply %A, %B
    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
    (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
  return %C: vector<12xf64>
}
// CHECK-LABEL: llvm.func @matrix_ops
//       CHECK:   llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
//  CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
//  CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">
Loading