Unverified Commit 95acb33b authored by Cullen Rhodes's avatar Cullen Rhodes Committed by GitHub
Browse files

[mlir][vector] Move transpose with unit-dim to shape_cast pattern (#72493)

Moved from lowering to canonicalization.
parent e77af7e1
Loading
Loading
Loading
Loading
+40 −1
Original line number Diff line number Diff line
@@ -5564,12 +5564,51 @@ public:
  }
};

/// Folds transpose with non-scalable unit dims into a shape_cast.
///
/// Replace:
///   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
///                                 vector<1xnxelty>
/// with:
///   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
///
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
/// be fixed. Non-unit dims can be scalable.
class FoldTransposeWithNonScalableUnitDimsToShapeCast final
    : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value input = transpOp.getVector();
    VectorType resType = transpOp.getResultVectorType();

    SmallVector<int64_t> permutation;
    transpOp.getTransp(permutation);

    if (resType.getRank() == 2 &&
        ((resType.getShape().front() == 1 &&
          !resType.getScalableDims().front()) ||
         (resType.getShape().back() == 1 &&
          !resType.getScalableDims().back())) &&
        permutation == ArrayRef<int64_t>({1, 0})) {
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resType,
                                                       input);
      return success();
    }

    return failure();
  }
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
              TransposeFolder, FoldTransposeSplat>(context);
              TransposeFolder, FoldTransposeSplat,
              FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
}

void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
+0 −18
Original line number Diff line number Diff line
@@ -336,24 +336,6 @@ public:
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // Replace:
    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
    //                                 vector<1xnxelty>
    // with:
    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
    //
    // Source with leading unit dim (inverse) is also replaced. Unit dim must
    // be fixed. Non-unit can be scalable.
    if (resType.getRank() == 2 &&
        ((resType.getShape().front() == 1 &&
          !resType.getScalableDims().front()) ||
         (resType.getShape().back() == 1 &&
          !resType.getScalableDims().back())) &&
        transp == ArrayRef<int64_t>({1, 0})) {
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
      return success();
    }

    if (inputType.isScalable())
      return failure();

+51 −0
Original line number Diff line number Diff line
@@ -2524,3 +2524,54 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
      tensor<4x4x4xf32>, vector<1x100x4x5xf32>
  return %r : vector<1x100x4x5xf32>
}

// -----

/// Transpose of rank-2 vector with leading or trailing non-scalable unit dim to shape_cast.

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
  return %0 : vector<1x4xf32>
}

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
  return %0 : vector<1x[4]xf32>
}

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
  return %0 : vector<4x1xf32>
}

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
  return %0 : vector<[4]x1xf32>
}

/// Scalable unit dim should not be lowered to shape_cast.

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
  // CHECK-NOT: vector.shape_cast
  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
  return %0 : vector<[1]x4xf32>
}

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
  // CHECK-NOT: vector.shape_cast
  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>

  return %0 : vector<[1]x4xf32>
}
+0 −51
Original line number Diff line number Diff line
@@ -790,57 +790,6 @@ module attributes {transform.with_named_sequence} {
  }
}

// -----

/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.

// CHECK-LABEL: func @transpose10_4x1xf32
func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
  return %0 : vector<1x4xf32>
}

// CHECK-LABEL: func @transpose10_nx4x1xf32
func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
  return %0 : vector<1x[4]xf32>
}

// CHECK-LABEL: func @transpose10_1x4xf32
func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
  return %0 : vector<4x1xf32>
}

// CHECK-LABEL: func @transpose10_1xnx4xf32
func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
  return %0 : vector<[4]x1xf32>
}

/// Scalable unit dim should not be lowered to shape_cast.

// CHECK-LABEL: func @transpose10_4xnx1xf32
func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
  // CHECK-NOT: vector.shape_cast
  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
  return %0 : vector<[1]x4xf32>
}

// CHECK-LABEL: func @transpose10_nx4xnx1xf32
func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
  // CHECK-NOT: vector.shape_cast
  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>

  return %0 : vector<[1]x4xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
    transform.apply_patterns to %func_op {