Commit 3ce8095c authored by Andy Davis's avatar Andy Davis
Browse files

[mlir][VectorOps] Add ShapeCastOp to the vector ops dialect.

Summary:
Add ShapeCastOp to the vector ops dialect.

The shape_cast operation casts between an n-D source vector shape and a k-D result vector shape (the element type remains the same).

Reviewers: nicolasvasilache, aartbik

Reviewed By: nicolasvasilache

Subscribers: Joonsoo, merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73635
parent 801857c5
Loading
Loading
Loading
Loading
+43 −0
Original line number Diff line number Diff line
@@ -963,6 +963,49 @@ def Vector_TransferWriteOp :
  }];
}

// ODS definition of vector.shape_cast: a cast between vector shapes (applied
// element-wise when the operands are tuples of vectors) that preserves the
// element type and the total element count, and is assumed to move no data.
def Vector_ShapeCastOp :
  Vector_Op<"shape_cast", [NoSideEffect]>,
    Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
    Results<(outs AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$result)> {
  let summary = "shape_cast casts between vector shapes";
  let description = [{
    The shape_cast operation casts between an n-D source vector shape and
    a k-D result vector shape (the element type remains the same).

    If reducing rank (n > k), result dimension sizes must be a product
    of contiguous source dimension sizes.
    If expanding rank (n < k), source dimensions must factor into a
    contiguous sequence of destination dimension sizes.
    Each source dim is expanded (or contiguous sequence of source dims combined)
    in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
    sequence of result dims (or a single result dim), in result dimension list
    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
    result dimension sizes must match.

    If the source/result types are a tuple of vectors, the casting operation
    described above is applied to each source/result tuple element pair.

    It is currently assumed that this operation does not require moving data,
    and that it will be canonicalized away before lowering vector operations.

    Examples:

    ```mlir
    // Example casting to a lower vector rank.
    %1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>

    // Example casting to a higher vector rank.
    %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>

    // Example casting a tuple of vectors of same rank, where tuple elements
    // may have different shapes.
    %5 = vector.shape_cast %4 : tuple<vector<3x4x2xf32>, vector<3x3x2xf32>> to
                                tuple<vector<12x2xf32>, vector<9x2xf32>>
    ```
  }];
  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
}

def Vector_TypeCastOp :
  Vector_Op<"type_cast", [NoSideEffect]>,
    Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
+85 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/StringSet.h"
#include <numeric>

using namespace mlir;
using namespace mlir::vector;
@@ -1389,6 +1390,90 @@ static LogicalResult verify(TransferWriteOp op) {
                              [&op](Twine t) { return op.emitOpError(t); });
}

//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//

/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
///
/// 'a' is the lower-rank shape and 'b' the higher-rank shape (rankA < rankB).
/// Trailing unit dimensions contribute nothing to an element-count product,
/// so once a group has been matched, an all-ones remainder on either side is
/// folded into that group. Without this, valid casts such as
/// vector<2xf32> -> vector<2x1xf32> would be rejected because 'j' never
/// reaches rankB even though the contiguous factorization is legal.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  // True if every dimension in 'dims' has size 1.
  auto allOnes = [](ArrayRef<int64_t> dims) {
    for (int64_t dim : dims)
      if (dim != 1)
        return false;
    return true;
  };

  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    // Accumulate a contiguous run of 'b' dims until it reaches (or
    // overshoots) the current 'a' dim.
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle trailing dimensions of size 1: if all remaining dims on either
    // side are unit-sized, fold them into the group just matched.
    if (i < rankA && allOnes(a.slice(i)))
      i = rankA;
    if (j < rankB && allOnes(b.slice(j)))
      j = rankB;
  }

  // Valid only if both shapes were consumed completely.
  return i == rankA && j == rankB;
}

/// Verifies that 'sourceVectorType' can legally be shape_cast to
/// 'resultVectorType': identical element types, identical total element
/// counts, and — when the ranks differ — a valid contiguous grouping of the
/// higher-rank shape's dimensions into the lower-rank shape's dimensions.
static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {
  // The element type must be preserved by the cast.
  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
    return op->emitOpError("source/result vectors must have same element type");

  auto srcShape = sourceVectorType.getShape();
  auto resShape = resultVectorType.getShape();

  // The total number of elements must be preserved by the cast.
  auto numElements = [](ArrayRef<int64_t> shape) {
    return std::accumulate(shape.begin(), shape.end(), 1LL,
                           std::multiplies<int64_t>{});
  };
  if (numElements(srcShape) != numElements(resShape))
    return op->emitOpError("source/result number of elements must match");

  // When the ranks differ, the lower-rank shape must be expressible as
  // products of contiguous runs of the higher-rank shape's dimensions.
  unsigned srcRank = sourceVectorType.getRank();
  unsigned resRank = resultVectorType.getRank();
  if (srcRank != resRank) {
    bool reducingRank = srcRank > resRank;
    ArrayRef<int64_t> lowShape = reducingRank ? resShape : srcShape;
    ArrayRef<int64_t> highShape = reducingRank ? srcShape : resShape;
    if (!isValidShapeCast(lowShape, highShape))
      return op->emitOpError("invalid shape cast");
  }
  return success();
}

/// Verifies a ShapeCastOp: either a plain vector-to-vector cast, or a
/// tuple-of-vectors to tuple-of-vectors cast checked element-wise.
static LogicalResult verify(ShapeCastOp op) {
  Type sourceType = op.source().getType();
  Type resultType = op.result().getType();

  // Case 1: both source and result are plain vectors.
  if (auto srcVecType = sourceType.dyn_cast_or_null<VectorType>())
    if (auto resVecType = resultType.dyn_cast_or_null<VectorType>())
      return verifyVectorShapeCast(op, srcVecType, resVecType);

  // Case 2: both must then be tuples of vectors; anything else is a
  // source/result kind mismatch.
  auto srcTupleType = sourceType.dyn_cast_or_null<TupleType>();
  auto resTupleType = resultType.dyn_cast_or_null<TupleType>();
  if (!srcTupleType || !resTupleType)
    return op.emitOpError("source/result must be of same type");

  // Tuples are cast element-wise, so their arities must agree.
  if (srcTupleType.size() != resTupleType.size())
    return op.emitOpError("source/result tuples must be the same size");

  // Apply the vector-to-vector rules to each corresponding element pair.
  for (unsigned idx = 0, numElems = srcTupleType.size(); idx < numElems; ++idx)
    if (failed(verifyVectorShapeCast(
            op, srcTupleType.getType(idx).cast<VectorType>(),
            resTupleType.getType(idx).cast<VectorType>())))
      return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//
+82 −0
Original line number Diff line number Diff line
@@ -889,3 +889,85 @@ func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) {
  %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
    : vector<3x2x4xf32> to vector<2x3x5xf32>
}

// -----

// Negative tests for the vector.shape_cast verifier; each function below
// exercises one failure mode, pinned by the diagnostic annotation inside it.

// Element type changes across the cast (f32 -> i32).
func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
  // expected-error@+1 {{op source/result vectors must have same element type}}
  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
}

// -----

// Element type changes within a tuple element pair.
func @shape_cast_wrong_element_type_tuple(%arg0 : tuple<vector<5x4x2xf32>,
                                                        vector<3x4x2xf32>>) {
  // expected-error@+1 {{op source/result vectors must have same element type}}
  %0 = vector.shape_cast %arg0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                                 tuple<vector<20x2xi32>, vector<12x2xi32>>
}

// -----

// Total element counts differ (30 vs 20).
func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
  // expected-error@+1 {{op source/result number of elements must match}}
  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
}

// -----

// Total element counts differ within a tuple element pair.
func @shape_cast_wrong_num_elements_tuple(%arg0 : tuple<vector<5x4x2xf32>,
                                                        vector<3x4x2xf32>>) {
  // expected-error@+1 {{op source/result number of elements must match}}
  %0 = vector.shape_cast %arg0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                                 tuple<vector<21x2xf32>, vector<13x2xf32>>
}

// -----

// Same element count, but 2x15 is not a contiguous grouping of 5x1x3x2.
func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
  // expected-error@+1 {{invalid shape cast}}
  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
}

// -----

// Invalid contiguous grouping when reducing rank of tuple elements.
func @shape_cast_invalid_rank_reduction_tuple(%arg0
  : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>) {
  // expected-error@+1 {{invalid shape cast}}
  %0 = vector.shape_cast %arg0: tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                                tuple<vector<10x4xf32>, vector<6x4xf32>>
}

// -----

// Same element count, but 15x2 does not factor into 5x2x3x1 contiguously.
func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
  // expected-error@+1 {{invalid shape cast}}
  %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
}

// -----

// Invalid contiguous factorization when expanding rank of tuple elements.
func @shape_cast_invalid_rank_expansion_tuple(%arg0 : tuple<vector<20x2xf32>,
                                                            vector<12x2xf32>>) {
  // expected-error@+1 {{invalid shape cast}}
  %0 = vector.shape_cast %arg0 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
                                 tuple<vector<5x2x4xf32>, vector<4x3x2xf32>>
}

// -----

// Source is a tuple but result is a plain vector: kind mismatch.
func @shape_cast_source_result_different_types(
  %arg1 : tuple<vector<20x2xf32>, vector<12x2xf32>>) {
  // expected-error@+1 {{source/result must be of same type}}
  %1 = vector.shape_cast %arg1 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
                                 vector<5x2x4xf32>
}

// -----

// Source and result tuples have different arities (2 vs 1).
func @shape_cast_different_tuple_sizes(
  %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>) {
  // expected-error@+1 {{op source/result tuples must be the same size}}
  %1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                                 tuple<vector<20x2xf32>>
}
+15 −0
Original line number Diff line number Diff line
@@ -233,3 +233,18 @@ func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {

  return %1 : vector<2x3x4xf32>
}

// Round-trip (parse/print) test for vector.shape_cast on a plain vector and
// on a tuple of vectors with per-element shapes.
// CHECK-LABEL: shape_cast
func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
                 %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>)
  -> (vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>) {

  // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<15x2xf32>
  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>

  // CHECK-NEXT: vector.shape_cast %{{.*}} : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to tuple<vector<20x2xf32>, vector<12x2xf32>>
  %1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                                 tuple<vector<20x2xf32>, vector<12x2xf32>>

  return %0, %1 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>
}