Commit 166f83f4 authored by Feng Liu's avatar Feng Liu
Browse files

[QuantOps] Add the quant region definition

Summary:
This regional op in the QuantOps dialect will be used to wrap
high-precision ops into atomic units for quantization. All the values
used by the internal ops are captured explicitly by the op inputs. The
quantization parameters of the inputs and outputs are stored in the
attributes.

Subscribers: jfb, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75972
parent 378b1e60
Loading
Loading
Loading
Loading
+30 −0
Original line number Diff line number Diff line
@@ -83,6 +83,36 @@ def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
  let hasFolder = 1;
}

// A QuantizeRegion (region) represents a quantization unit which wraps
// high-precision ops with quantization specifications for all the inputs
// and outputs. Some quantization specifications can be undetermined and
// derived from other ports by the target specification of the kernel.
def quant_QuantizeRegionOp : quant_Op<"region", [
    NoSideEffect,
    IsolatedFromAbove,
    SingleBlockImplicitTerminator<"ReturnOp">]> {
  let summary = [{
    The `region operation wraps high-precision ops as a logical low-precision
    quantized kernel.
  }];

  let arguments = (ins Variadic<AnyType>:$inputs,
                    TypeArrayAttr:$input_specs,
                    TypeArrayAttr:$output_specs,
                    StrAttr:$logical_kernel);
  let results = (outs Variadic<AnyType>:$outputs);
  let regions = (region SizedRegion<1>:$body);
  let verifier = [{ return verifyRegionOp(*this); }];
}

def quant_ReturnOp : quant_Op<"return", [Terminator]> {
  let summary = [{
    The `return` operation terminates a quantize region and returns values.
  }];

  let arguments = (ins Variadic<AnyTensor>:$results);
}

//===----------------------------------------------------------------------===//
// Training integration and instrumentation ops
//===----------------------------------------------------------------------===//
+52 −2
Original line number Diff line number Diff line
@@ -34,13 +34,63 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
}

OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
  /// Matches x -> [scast -> scast] -> y, replacing the second scast with the
  /// value of x if the casts invert each other.
  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
  // value of x if the casts invert each other.
  auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp());
  if (!srcScastOp || srcScastOp.arg().getType() != getType())
    return OpFoldResult();
  return srcScastOp.arg();
}

/// The quantization specification should match the expressed type.
static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
  if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
    Type spec = typeAttr.getValue();
    if (spec.isa<TensorType>() || spec.isa<VectorType>())
      return false;

    // The spec should be either a quantized type which is compatible to the
    // expressed type, or a primitive type which is as same as the
    // (element type of) the expressed type.
    if (auto quantizedType = spec.dyn_cast<QuantizedType>())
      return quantizedType.isCompatibleExpressedType(expressed);

    if (auto tensorType = expressed.dyn_cast<TensorType>())
      return spec == tensorType.getElementType();

    if (auto vectorType = expressed.dyn_cast<VectorType>())
      return spec == vectorType.getElementType();
  }
  return false;
}

static LogicalResult verifyRegionOp(QuantizeRegionOp op) {
  // There are specifications for both inputs and outputs.
  if (op.getNumOperands() != op.input_specs().size() ||
      op.getNumResults() != op.output_specs().size())
    return op.emitOpError(
        "has unmatched operands/results number and spec attributes number");

  // Verify that quantization specifications are valid.
  for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) {
    Type inputType = std::get<0>(input);
    Attribute inputSpec = std::get<1>(input);
    if (!isValidQuantizationSpec(inputSpec, inputType)) {
      return op.emitOpError() << "has incompatible specification " << inputSpec
                              << " and input type " << inputType;
    }
  }

  for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) {
    Type outputType = std::get<0>(result);
    Attribute outputSpec = std::get<1>(result);
    if (!isValidQuantizationSpec(outputSpec, outputType)) {
      return op.emitOpError() << "has incompatible specification " << outputSpec
                              << " and output type " << outputType;
    }
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
+1 −1
Original line number Diff line number Diff line
@@ -60,7 +60,7 @@ struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
    // Op handlers.
    addOpHandler<ConstantOp>(
        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
    addOpHandler<ReturnOp>(
    addOpHandler<mlir::ReturnOp>(
        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
    addOpHandler<quant::StatisticsOp>(
        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));
+101 −0
Original line number Diff line number Diff line
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @source
func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
    ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
      %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      "quant.return"(%14) : (tensor<4xf32>) -> ()
  }) {input_specs = [f32, f32, f32], output_specs = [f32], logical_kernel = "xyz"}
    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: @annotated
func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
    ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
      %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      "quant.return"(%14) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, f32],
      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: @quantized
func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
    ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
      %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      "quant.return"(%14) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f32, 2.0>],
      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
  return %0 : tensor<4xf32>
}

// -----

func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  // @expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform<i32:f16, 3.000000e+00> and input type 'tensor<4xf32>'}}
  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
    ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
      %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      "quant.return"(%14) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f16, 3.0>],
      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
  return %0 : tensor<4xf32>
}

// -----

func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  // @expected-error @+1 {{'quant.region' op has incompatible specification i32 and input type 'tensor<4xf32>'}}
  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
    ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
      %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      "quant.return"(%14) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, i32],
      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
  return %0 : tensor<4xf32>
}

// -----

func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  // @expected-error @+1 {{'quant.region' op has unmatched operands/results number and spec attributes number}}
  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
    ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
      %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      "quant.return"(%14) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>],
      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
  return %0 : tensor<4xf32>
}

// -----

func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  // @expected-note @+1 {{required by region isolation constraints}}
  %0 = "quant.region"(%arg0, %arg1) ({
    ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>):
      %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      // @expected-error @+1 {{'bar' op using value defined outside the region}}
      %14 = "bar"(%13, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      "quant.return"(%14) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>],
      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
    : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
  return %0 : tensor<4xf32>
}