Unverified Commit 8d060c02 authored by Davide Grohmann's avatar Davide Grohmann Committed by GitHub
Browse files

[mlir][spirv] Tighten SPIR-V TOSA pool constraints (#193515)



Tighten AvgPool2D and MaxPool2D verification by constraining kernel,
stride, and pad attributes and by checking the input/output NHWC
relationship.

Add verification tests for batch/channel mismatches, non-divisible
pooled shapes, pad-vs-kernel failures, and incorrect output shapes.

Signed-off-by: default avatarDavide Grohmann <davide.grohmann@arm.com>
parent 6d89cd85
Loading
Loading
Loading
Loading
+16 −8
Original line number Diff line number Diff line
@@ -285,7 +285,9 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
  TypeImpliesAccType<"input", F32, ["FP32"]>,
  TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
  TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
  AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
  AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>,
  NHWCInputOutputShapeMatch<"input", "output">,
  Pool2DPadValuesLessThanKernel<"pad", "kernel">]> {
  let summary = "Performs average pooling on the input.";

  let description = [{
@@ -308,9 +310,9 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
  }];

  let arguments = (ins
    SPIRV_I32_1DTensorArmOfLength2Attr: $kernel,
    SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
    SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
    SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
    SPIRV_TosaExtAccTypeAttr: $acc_type,
    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
@@ -337,6 +339,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
    }
  }];

  let hasVerifier = 1;
}


@@ -619,7 +623,9 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,


def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
  AllElementTypesMatch<["input", "output"]>]> {
  AllElementTypesMatch<["input", "output"]>,
  NHWCInputOutputShapeMatch<"input", "output">,
  Pool2DPadValuesLessThanKernel<"pad", "kernel">]> {
  let summary = "Performs max pooling on the input.";

  let description = [{
@@ -640,9 +646,9 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
  }];

  let arguments = (ins
    SPIRV_I32_1DTensorArmOfLength2Attr: $kernel,
    SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
    SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
    SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
    SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
    SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input
  );
@@ -665,6 +671,8 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
    }
  }];

  let hasVerifier = 1;
}


+31 −0
Original line number Diff line number Diff line
@@ -115,6 +115,14 @@ def SPIRV_I32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>
def SPIRV_I32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
def SPIRV_I32_1DTensorArmOfLength5Attr : ConfinedAttr<RankedI32ElementsAttr<[5]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
def SPIRV_I32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
class IntElementsAttrAllValuesAtLeast<int minValue> : AttrConstraint<
  CPred<"::llvm::all_of(::llvm::cast<::mlir::DenseElementsAttr>($_self).getValues<::llvm::APInt>(), "
        "[](const ::llvm::APInt &value) { return value.getSExtValue() >= " #
        minValue # "; })">,
  "all values must be >= " # minValue>;

def SPIRV_PositiveInt32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<1>]>;
def SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<0>]>;

class Is1DTensorArmAttrOfLength<list<int> allowedLengths> :
  AttrConstraint<And<[CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape().size() == 1 }]>,
@@ -217,6 +225,29 @@ class ValuesIndicesShapesMatch<string values, string indices, string tensor>:
      SameDimsOrDynamicPred<values, 2, tensor, 2>
    ]>>;

// The tensor shapes are [N,H,W,C] where N,H,W,C are the dimension values.
class NHWCInputOutputShapeMatch<string input, string output>:
  PredOpTrait<"shapes of " # input # " and " # output #
                  " must satisfy [N,*,*,C] and [N,*,*,C]",
    And<[
      SameDimsOrDynamicPred<input, 0, output, 0>,
      SameDimsOrDynamicPred<input, 3, output, 3>
    ]>>;

class FetchNthIntElementsAttr<string attrName, int idx> :
  StrFunc<"get" # snakeCaseToCamelCase<attrName>.ret # "().getValues<APInt>()[" # idx # "].getSExtValue()">;

class ElementsAttrValueLessThan<string leftAttrName, int leftIdx,string rightAttrName, int rightIdx> :
   CPred<FetchNthIntElementsAttr<leftAttrName, leftIdx>.result # " < " # FetchNthIntElementsAttr<rightAttrName, rightIdx>.result>;

class Pool2DPadValuesLessThanKernel<string padAttr, string kernelAttr> :
    PredOpTrait<"op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x",
    And<[ElementsAttrValueLessThan<padAttr, 0, kernelAttr, 0>,
         ElementsAttrValueLessThan<padAttr, 1, kernelAttr, 0>,
         ElementsAttrValueLessThan<padAttr, 2, kernelAttr, 1>,
         ElementsAttrValueLessThan<padAttr, 3, kernelAttr, 1>]
    >>;

class TableSizeConstraint<string input, Type type, int size>:
  PredOpTrait<"table must have size " # size # " if " # input # " has element type " # type.summary,
      Implies<ElementTypeIsPred<input, type>, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]>
+69 −0
Original line number Diff line number Diff line
@@ -54,6 +54,75 @@ void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *,
// SPIRV Tosa Custom verifiers
//===----------------------------------------------------------------------===//

namespace {

int64_t getIntValue(DenseIntElementsAttr attr, size_t idx) {
  return attr.getValues<APInt>()[idx].getSExtValue();
}

LogicalResult verifyPool2DOutputDim(Operation *op, int64_t inputSize,
                                    int64_t outputSize, int64_t kernelSize,
                                    int64_t strideSize, int64_t padBefore,
                                    int64_t padAfter, StringRef dimName,
                                    StringRef dimAxis, StringRef padBeforeName,
                                    StringRef padAfterName) {
  if (ShapedType::isDynamic(inputSize))
    return success();

  const int64_t numerator = inputSize + padBefore + padAfter - kernelSize;
  if (numerator % strideSize != 0)
    return op->emitOpError("expected input_")
           << dimName << " + pad_" << padBeforeName << " + pad_" << padAfterName
           << " - kernel_" << dimAxis << " to be wholly divisible by stride_"
           << dimAxis << ", got (" << inputSize << " + " << padBefore << " + "
           << padAfter << " - " << kernelSize << ") / " << strideSize;

  const int64_t calculatedOutput = numerator / strideSize + 1;
  if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
    return op->emitOpError("failed to verify that shapes of input and output "
                           "must satisfy [N,IH,IW,C] and [N,OH,OW,C], with "
                           "OH = ((IH + pad_top + pad_bottom - kernel_y) / "
                           "stride_y) + 1 and OW = ((IW + pad_left + "
                           "pad_right - kernel_x) / stride_x) + 1");

  return success();
}

LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel,
                             DenseIntElementsAttr stride,
                             DenseIntElementsAttr pad, TensorArmType inputType,
                             TensorArmType outputType) {

  if (!inputType.hasRank() || !outputType.hasRank())
    return success();

  if (failed(verifyPool2DOutputDim(
          op, inputType.getDimSize(1), outputType.getDimSize(1),
          getIntValue(kernel, 0), getIntValue(stride, 0), getIntValue(pad, 0),
          getIntValue(pad, 1), "height", "y", "top", "bottom")))
    return failure();

  if (failed(verifyPool2DOutputDim(
          op, inputType.getDimSize(2), outputType.getDimSize(2),
          getIntValue(kernel, 1), getIntValue(stride, 1), getIntValue(pad, 2),
          getIntValue(pad, 3), "width", "x", "left", "right")))
    return failure();

  return success();
}

} // namespace

LogicalResult TosaAvgPool2DOp::verify() {
  return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(),
                        getInputType(), getResultType());
}

LogicalResult TosaMaxPool2DOp::verify() {
  return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(),
                        getInputType(), getResultType());
}

LogicalResult TosaSelectOp::verify() {
  TensorArmType condType = getConditionType();
  TensorArmType trueValType = getTrueValueType();
+56 −0
Original line number Diff line number Diff line
@@ -91,6 +91,38 @@ spirv.ARM.Graph @avgpool2d_accumulator_must_be_FP16_for_f8e5m2_element_type(%arg
  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x2x2xf8E5M2>
}

spirv.ARM.Graph @avgpool2d_input_output_batch_or_channel_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x2xi8>) -> (!spirv.arm.tensor<2x2x32768x1xi8>) {
  %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
  %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
  // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,*,*,C] and [N,*,*,C]}}
  %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x2xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x2x32768x1xi8>
  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<2x2x32768x1xi8>
}

spirv.ARM.Graph @avgpool2d_input_shape_not_wholly_divisible_by_stride(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x1x1x1xi8>) {
  %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
  %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
  // expected-error @+1 {{op expected input_height + pad_top + pad_bottom - kernel_y to be wholly divisible by stride_y}}
  %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [2, 2], pad = [0, 0, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x4x4x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x1x1x1xi8>
  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x1x1x1xi8>
}

spirv.ARM.Graph @avgpool2d_pad_values_must_be_less_than_kernel(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x2x1x1xi8>) {
  %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
  %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
  // expected-error @+1 {{op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x}}
  %6 = spirv.Tosa.AvgPool2D kernel = [2, 3], stride = [1, 2], pad = [2, 0, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x4x4x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x1x1xi8>
  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x1x1xi8>
}

spirv.ARM.Graph @avgpool2d_input_output_height_width_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi8>) {
  %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
  %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
  // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,IH,IW,C] and [N,OH,OW,C], with OH = ((IH + pad_top + pad_bottom - kernel_y) / stride_y) + 1 and OW = ((IW + pad_left + pad_right - kernel_x) / stride_x) + 1}}
  %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32769x1xi8>
}

//===----------------------------------------------------------------------===//
// spirv.TOSA.Conv2D
//===----------------------------------------------------------------------===//
@@ -537,6 +569,30 @@ spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.ar
  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
}

spirv.ARM.Graph @maxpool2d_input_output_batch_or_channel_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x2xi8>) -> (!spirv.arm.tensor<2x2x32769x1xi8>) {
  // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,*,*,C] and [N,*,*,C]}}
  %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x2xi8> -> !spirv.arm.tensor<2x2x32769x1xi8>
  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<2x2x32769x1xi8>
}

spirv.ARM.Graph @maxpool2d_input_shape_not_wholly_divisible_by_stride(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x1x1x1xi8>) {
  // expected-error @+1 {{op expected input_height + pad_top + pad_bottom - kernel_y to be wholly divisible by stride_y}}
  %4 = spirv.Tosa.MaxPool2D kernel = [3, 3], stride = [2, 2], pad = [0, 0, 0, 0], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x4x4x1xi8> -> !spirv.arm.tensor<1x1x1x1xi8>
  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x1x1x1xi8>
}

spirv.ARM.Graph @maxpool2d_pad_values_must_be_less_than_kernel(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x2x1x1xi8>) {
  // expected-error @+1 {{op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x}}
  %4 = spirv.Tosa.MaxPool2D kernel = [2, 3], stride = [1, 2], pad = [2, 0, 0, 0], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x4x4x1xi8> -> !spirv.arm.tensor<1x2x1x1xi8>
  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x1x1xi8>
}

spirv.ARM.Graph @maxpool2d_input_output_height_width_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
  // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,IH,IW,C] and [N,OH,OW,C], with OH = ((IH + pad_top + pad_bottom - kernel_y) / stride_y) + 1 and OW = ((IW + pad_left + pad_right - kernel_x) / stride_x) + 1}}
  %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32768x1xi8>
}

//===----------------------------------------------------------------------===//
// spirv.TOSA.TransposeConv2D
//===----------------------------------------------------------------------===//