Unverified Commit 03d1c99d authored by Benjamin Maxwell's avatar Benjamin Maxwell Committed by GitHub
Browse files

[mlir][ODS] Add `OptionalTypesMatchWith` and remove a custom assemblyFormat (#68876)

This is just a slight specialization of `TypesMatchWith` that returns
success if an optional parameter is missing.

There may be other places this could help e.g.:

https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58-L59
...but I'm leaving those to avoid some churn.

This constraint will be handy for us in some later patches, it's a
formalization of a short circuiting trick with the `comparator` of the
`TypesMatchWith` constraint (devised for #69195).

```
TypesMatchWith<
  "padding type matches element type of result (if present)",
  "result", "padding",
  "::llvm::cast<VectorType>($_self).getElementType()",
  // This returns true if no padding is present, or it's present with a type that matches the element type of `result`.
  "!getPadding() || std::equal_to<>()">
```

This is a little non-obvious, so after this patch you can instead do:
```
OptionalTypesMatchWith<
  "padding type matches element type of result (if present)",
  "result", "padding",
  "::llvm::cast<VectorType>($_self).getElementType()">
```
parent e880e8ae
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -215,6 +215,8 @@ def Vector_ReductionOp :
  Vector_Op<"reduction", [Pure,
     PredOpTrait<"source operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     OptionalTypesMatchWith<"dest and acc have the same type",
                            "dest", "acc", "::llvm::cast<Type>($_self)">,
     DeclareOpInterfaceMethods<ArithFastMathInterface>,
     DeclareOpInterfaceMethods<MaskableOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
@@ -263,9 +265,8 @@ def Vector_ReductionOp :
                         "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
  ];

  // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
  // operands.
  let hasCustomAssemblyFormat = 1;
  let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)?"
                       " attr-dict `:` type($vector) `into` type($dest)";
  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}
+8 −0
Original line number Diff line number Diff line
@@ -568,6 +568,14 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
  string transformer = transform;
}

// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
// and not present returns success.
class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
                     string transform, string comparator = "std::equal_to<>()">
  : TypesMatchWith<summary, lhsArg, rhsArg, transform,
     "!get" # snakeCaseToCamelCase<lhsArg>.ret # "()"
     # " || !get" # snakeCaseToCamelCase<rhsArg>.ret # "() || " # comparator>;

// Special variant of `TypesMatchWith` that provides a comparator suitable for
// ranged arguments.
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
+24 −0
Original line number Diff line number Diff line
@@ -66,4 +66,28 @@ class CArg<string ty, string value = ""> {
  string defaultValue = value;
}

// Helper which makes the first letter of a string uppercase.
// e.g. cat -> Cat
class firstCharToUpper<string str>
{
  string ret = !if(!gt(!size(str), 0),
    !toupper(!substr(str, 0, 1)) # !substr(str, 1),
    "");
}

class _snakeCaseHelper<string str> {
  int idx = !find(str, "_");
  string ret = !if(!ge(idx, 0),
    !substr(str, 0, idx) # firstCharToUpper<!substr(str, !add(idx, 1))>.ret,
    str);
}

// Converts a snake_case string to CamelCase.
// TODO: Replace with a !tocamelcase bang operator.
class snakeCaseToCamelCase<string str>
{
  string ret = !foldl(firstCharToUpper<str>.ret,
    !range(0, !size(str)), acc, idx, _snakeCaseHelper<acc>.ret);
}

#endif // UTILS_TD
+0 −41
Original line number Diff line number Diff line
@@ -524,47 +524,6 @@ LogicalResult ReductionOp::verify() {
  return success();
}

ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
  Type redType;
  Type resType;
  CombiningKindAttr kindAttr;
  arith::FastMathFlagsAttr fastMathAttr;
  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
                                              result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      (succeeded(parser.parseOptionalKeyword("fastmath")) &&
       parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
                                               result.attributes)) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (!operandsInfo.empty() &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.empty() || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

void ReductionOp::print(OpAsmPrinter &p) {
  p << " ";
  getKindAttr().print(p);
  p << ", " << getVector();
  if (getAcc())
    p << ", " << getAcc();

  if (getFastmathAttr() &&
      getFastmathAttr().getValue() != arith::FastMathFlags::none) {
    p << ' ' << getFastmathAttrName().getValue();
    p.printStrippedAttrOrType(getFastmathAttr());
  }
  p << " : " << getVector().getType() << " into " << getDest().getType();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
+1 −1
Original line number Diff line number Diff line
@@ -1169,7 +1169,7 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
// -----

func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
  // expected-error@+1 {{'vector.reduction' unsupported number of operands}}
  // expected-error@+1 {{expected ':'}}
  %0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
}

Loading