Unverified Commit c7be2dee authored by Benjamin Maxwell's avatar Benjamin Maxwell Committed by GitHub
Browse files

[mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (#69456)

This allows folding extracts from `vector.create_mask` ops that have a
known value. Currently, there's no fold for this, but you get the same
effect from the unrolling in LowerVectorMask (part of
-convert-vector-to-llvm), then folds after that. However, for a future
patch, this simplification needs to be done before lowering to LLVM,
hence the need for this fold.

E.g.:

```
%0 = vector.create_mask %c1, %dimA, %dimB : vector<1x[4]x[4]xi1>
%1 = vector.extract %mask[0] : vector<[4]x[4]xi1>
```
->
```
%0 = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
```
parent febf5c97
Loading
Loading
Loading
Loading
+80 −1
Original line number Diff line number Diff line
@@ -100,6 +100,20 @@ static MaskFormat getMaskFormat(Value mask) {
      return MaskFormat::AllTrue;
    if (allFalse)
      return MaskFormat::AllFalse;
  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
    // Finds all-false create_masks. An all-true create_mask requires all
    // dims to be constants, so that'll be folded to a constant_mask, then
    // detected in the constant_mask case.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
    return MaskFormat::Unknown;
  }
  return MaskFormat::Unknown;
}
@@ -1942,6 +1956,71 @@ public:
  }
};

// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());

    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the range from the `create_mask`, then the
        // extracted mask will be all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dim is not all-true and since this is a dynamic index we don't
        // know if the extraction is within the true or false region.
        // Note: Zero dims have already handled via getMaskFormat().
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
  }
};

// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -1968,7 +2047,7 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
              ExtractOpFromBroadcast>(context);
              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
}

+106 −0
Original line number Diff line number Diff line
@@ -67,6 +67,112 @@ func.func @create_mask_transpose_to_transposed_create_mask(

// -----

// CHECK-LABEL: extract_from_create_mask
//  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
  %c2 = arith.constant 2 : index
  %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
  // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[4]x[4]xi1>
  // CHECK-NOT: vector.extract
  %extract = vector.extract %mask[1] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
  return %extract : vector<[4]x[4]xi1>
}

// -----

// CHECK-LABEL: extract_from_create_mask_all_false
func.func @extract_from_create_mask_all_false(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
  %c2 = arith.constant 2 : index
  %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
  // CHECK: arith.constant dense<false> : vector<[4]x[4]xi1>
  // CHECK-NOT: vector.extract
  %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
  return %extract : vector<[4]x[4]xi1>
}

// -----

// CHECK-LABEL: extract_from_create_mask_leading_scalable
//  CHECK-SAME: %[[DIM0:.*]]: index
func.func @extract_from_create_mask_leading_scalable(%dim0: index) -> vector<8xi1> {
  %c3 = arith.constant 3 : index
  %mask = vector.create_mask %c3, %dim0 : vector<[4]x8xi1>
  // CHECK: vector.create_mask %[[DIM0]] : vector<8xi1>
  // CHECK-NOT: vector.extract
  %extract = vector.extract %mask[1] : vector<8xi1> from vector<[4]x8xi1>
  return %extract : vector<8xi1>
}

// -----

// CHECK-LABEL: extract_from_create_mask_dynamic_position
//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index) -> vector<6xi1> {
  %c4 = arith.constant 4 : index
  %c3 = arith.constant 3 : index
  %mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1>
  // CHECK: vector.create_mask %[[DIM0]] : vector<6xi1>
  // CHECK-NOT: vector.extract
  %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1>
  return %extract : vector<6xi1>
}

// -----

// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1>
  // CHECK: arith.constant dense<false> : vector<6xi1>
  // CHECK-NOT: vector.extract
  %extract = vector.extract %mask[0, %index] : vector<6xi1> from vector<1x4x6xi1>
  return %extract : vector<6xi1>
}

// -----

// CHECK-LABEL: extract_from_create_mask_dynamic_position_unknown
//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
func.func @extract_from_create_mask_dynamic_position_unknown(%dim0: index, %index: index) -> vector<6xi1> {
  %c2 = arith.constant 2 : index
  %mask = vector.create_mask %c2, %dim0 : vector<4x6xi1>
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[DIM0]] : vector<4x6xi1>
  // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]]] : vector<6xi1> from vector<4x6xi1>
  %extract = vector.extract %mask[%index] : vector<6xi1> from vector<4x6xi1>
  return %extract : vector<6xi1>
}

// -----

// CHECK-LABEL: extract_from_create_mask_mixed_position_unknown
//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
func.func @extract_from_create_mask_mixed_position_unknown(%dim0: index, %index0: index) -> vector<4xi1> {
  %c2 = arith.constant 2 : index
  %mask = vector.create_mask %c2, %c2, %dim0 : vector<2x4x4xi1>
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[C2]], %[[DIM0]] : vector<2x4x4xi1>
  // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]]] : vector<4xi1> from vector<2x4x4xi1>
  %extract = vector.extract %mask[1, %index0] : vector<4xi1> from vector<2x4x4xi1>
  return %extract : vector<4xi1>
}

// -----

// CHECK-LABEL: extract_from_non_constant_create_mask
//  CHECK-SAME: %[[DIM0:.*]]: index
func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1> {
  %mask = vector.create_mask %dim0, %dim0 : vector<[2]x[2]xi1>
  // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM0]] : vector<[2]x[2]xi1>
  // CHECK-NEXT: vector.extract %[[MASK]][0] : vector<[2]xi1> from vector<[2]x[2]xi1>
  %extract = vector.extract %mask[0] : vector<[2]xi1> from vector<[2]x[2]xi1>
  return %extract : vector<[2]xi1>
}

// -----

// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
  //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>