Commit 1b00b94f authored by Rob Suderman

[mlir][tosa] Tosa shape propagation for tosa.cond_if

We can propagate the shapes from the tosa.cond_if operands into the true/false
regions, and then onward through the connected blocks. Using the tosa.yield ops
we can then determine the set of possible return types.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105940
parent b4121b33
+2 −0 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1789,6 +1789,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
// Further described in docs/Rationale/RationaleTOSADialect.md .
//===----------------------------------------------------------------------===//
def Tosa_IfOp : Tosa_Op<"cond_if", [
      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                ["inferReturnTypeComponents"]>,
      SingleBlockImplicitTerminator<"YieldOp">,
      RecursiveSideEffects]> {
  let summary = "Conditional if operator";
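
The DeclareOpInterfaceMethods entry above is what opts tosa.cond_if into
InferShapedTypeOpInterface; the inferReturnTypeComponents override it declares
is implemented by the new IfOp::inferReturnTypeComponents in TosaOps.cpp below.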
+178 −0 mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
//===-- ShapeUtils.h - TOSA shape support declarations ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Class declarations for shape utilities meant to assist shape propagation.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
#define MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
namespace tosa {
/// Statically known information for a particular Value.
///
/// This struct currently tracks only information relevant for tensor/array-like
/// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
/// type as long as it is in the default "no knowledge" state returned by
/// `getPessimisticValueState`. The important invariant is that we cannot
/// claim to know something about a value which is false.
///
/// This class could also be called "dataflow facts", "lattice value", etc.
struct ValueKnowledge {
  ValueKnowledge() = delete;
  ValueKnowledge(bool hasRank, llvm::ArrayRef<int64_t> newSizes, Type dtype)
      : hasError(false), hasRank(hasRank), dtype(dtype) {
    sizes.reserve(newSizes.size());
    for (auto size : newSizes)
      sizes.push_back(size);
  }

  operator bool() const { return !hasError; }

  // Get the static knowledge intrinsic to `type`.
  static ValueKnowledge getKnowledgeFromType(Type type) {
    ValueKnowledge result = getPessimisticValueState();
    if (auto shapedType = type.dyn_cast<ShapedType>()) {
      if (shapedType.hasRank()) {
        result.hasRank = true;
        result.sizes.reserve(shapedType.getRank());
        for (auto dim : shapedType.getShape())
          result.sizes.push_back(dim);
      }
      result.dtype = shapedType.getElementType();
    }
    return result;
  }

  // Return a pessimistic/conservative value state without assuming any
  // knowledge about the IR.
  static ValueKnowledge getPessimisticValueState() {
    return ValueKnowledge(false, {}, Type());
  }

  Type getType() const {
    if (hasRank)
      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
    return UnrankedTensorType::get(dtype);
  }

  bool operator==(const ValueKnowledge &rhs) const {
    return hasRank == rhs.hasRank && sizes == rhs.sizes && dtype == rhs.dtype;
  }

  // Given two pieces of static knowledge, calculate conservatively the
  // information we can be sure about.
  static ValueKnowledge join(const ValueKnowledge &lhs,
                             const ValueKnowledge &rhs) {
    // Mental model: All conditions are checking how to change from the safe "no
    // knowledge" default-initialized state to a state with more knowledge
    // consistent with lhs and rhs.
    ValueKnowledge result = getPessimisticValueState();
    result.hasError = true;

    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
      return result;

    result.hasError = false;
    result.dtype = lhs.dtype;

    if (!lhs.hasRank && !rhs.hasRank)
      return result;

    if (!rhs.hasRank) {
      result.hasRank = true;
      result.sizes = lhs.sizes;
      return result;
    }

    if (!lhs.hasRank) {
      result.hasRank = true;
      result.sizes = rhs.sizes;
      return result;
    }

    if (lhs.sizes.size() != rhs.sizes.size())
      return result;

    result.hasRank = true;
    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
    for (auto i : llvm::seq<unsigned>(0, result.sizes.size())) {
      int64_t lhsSize = lhs.sizes[i];
      int64_t rhsSize = rhs.sizes[i];
      int64_t &resultSize = result.sizes[i];
      if (lhsSize == ShapedType::kDynamicSize) {
        resultSize = rhsSize;
      } else if (rhsSize == ShapedType::kDynamicSize) {
        resultSize = lhsSize;
      } else if (lhsSize == rhsSize) {
        resultSize = lhsSize;
      } else {
        result.hasError = true;
      }
    }

    return result;
  }

  // Given two pieces of knowledge, generate a new ValueKnowledge whose meet
  // covers both cases. E.g. if the ranks of the LHS and RHS differ, the
  // resulting tensor has unknown rank.
  static ValueKnowledge meet(const ValueKnowledge &lhs,
                             const ValueKnowledge &rhs) {
    ValueKnowledge result = getPessimisticValueState();
    result.hasError = true;

    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
      return result;

    result.hasError = false;
    result.dtype = lhs.dtype;

    if (!lhs.hasRank || !rhs.hasRank) {
      result.hasRank = false;
      return result;
    }

    if (lhs.sizes.size() != rhs.sizes.size()) {
      result.hasRank = false;
      return result;
    }

    result.hasRank = true;
    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
    for (int i = 0, e = lhs.sizes.size(); i < e; i++) {
      if (lhs.sizes[i] == rhs.sizes[i]) {
        result.sizes[i] = lhs.sizes[i];
      }
    }

    return result;
  }

  // Whether the value information has an error.
  bool hasError;
  // Whether the value has known rank.
  bool hasRank;
  // If `hasRank`, the sizes along each rank. Unknown sizes are represented as
  // `ShapedType::kDynamicSize`.
  llvm::SmallVector<int64_t> sizes;
  // The dtype of a tensor.
  // This is equal to nullptr if we don't know that it is a specific concrete
  // type.
  Type dtype;
};
} // namespace tosa
} // namespace mlir

#endif // MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
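
To make the lattice semantics concrete, here is a minimal sketch of using the
struct directly (not part of the commit; the `demo` helper and its
`MLIRContext` argument are illustrative assumptions):

```cpp
// Sketch only: exercising ValueKnowledge as declared in ShapeUtils.h above.
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static void demo(MLIRContext *ctx) {
  Type f32 = FloatType::getF32(ctx);

  // join refines knowledge: each dynamic dim adopts the static dim from the
  // other side, so [2, ?] join [?, 3] becomes [2, 3].
  tosa::ValueKnowledge lhs(/*hasRank=*/true, {2, ShapedType::kDynamicSize}, f32);
  tosa::ValueKnowledge rhs(/*hasRank=*/true, {ShapedType::kDynamicSize, 3}, f32);
  tosa::ValueKnowledge joined = tosa::ValueKnowledge::join(lhs, rhs);
  // joined.getType() == tensor<2x3xf32>

  // meet generalizes knowledge: conflicting static dims become dynamic, so
  // [2] meet [3] becomes [?]; differing ranks would drop rank knowledge.
  tosa::ValueKnowledge met = tosa::ValueKnowledge::meet(
      tosa::ValueKnowledge(/*hasRank=*/true, {2}, f32),
      tosa::ValueKnowledge(/*hasRank=*/true, {3}, f32));
  // met.getType() == tensor<?xf32>

  // Both operators return an error state (operator bool == false) when the
  // element types disagree.
  (void)joined;
  (void)met;
}
```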
+49 −0 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@@ -1301,6 +1302,54 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : regions) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (auto result : resultKnowledge) {
    if (result.hasRank) {
      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
    } else {
      inferredReturnShapes.push_back(ShapedTypeComponents());
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//
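
To trace the meet-fold above on a concrete case, take the @if_test_dynamic test
added at the bottom of this commit, whose branches yield tensor<2xf32> and
tensor<3xf32>. A standalone sketch of the same computation (not from the
commit; the `inferIfResult` wrapper and its `MLIRContext` argument are
illustrative assumptions):

```cpp
// Sketch only: the meet-fold across the two tosa.yield operand types.
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;
using namespace mlir::tosa;

static ShapedTypeComponents inferIfResult(MLIRContext *ctx) {
  Type f32 = FloatType::getF32(ctx);
  Type trueTy = RankedTensorType::get({2}, f32);  // true branch yield
  Type falseTy = RankedTensorType::get({3}, f32); // false branch yield

  ValueKnowledge knowledge = ValueKnowledge::getKnowledgeFromType(trueTy);
  ValueKnowledge met = ValueKnowledge::meet(
      knowledge, ValueKnowledge::getKnowledgeFromType(falseTy));
  if (met) // meet only fails when the element types disagree
    knowledge = met;

  // hasRank stays true while the conflicting dims 2 and 3 collapse to
  // kDynamicSize, matching the CHECKed result type tensor<?xf32>.
  return ShapedTypeComponents(knowledge.sizes);
}
```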
+55 −120 mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -30,137 +31,57 @@ using namespace mlir::tosa;

 namespace {
 
-// -----------------------------------------------------------------------------
-// Analysis.
-// -----------------------------------------------------------------------------
-
-static Type joinElementTypes(Type lhs, Type rhs) {
-  return lhs == rhs ? lhs : Type();
-}
-
-namespace {
-// Statically known information for a particular Value.
-//
-// This struct currently tracks only information relevant for tensor/array-like
-// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
-// type as long as it is in the default "no knowledge" state returned by
-// `getPessimisticValueState`. The important invariant is that we cannot
-// claim to know something about a value which is false.
-//
-// This class could also be called "dataflow facts", "lattice value", etc.
-struct ValueKnowledge {
-  ValueKnowledge() = delete;
-  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
-      : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
-    assert(sizes.size() == 0 || hasSizes);
-  }
-
-  // Get the static knowledge intrinsic to `type`.
-  static ValueKnowledge getKnowledgeFromType(Type type) {
-    ValueKnowledge result = getPessimisticValueState(type.getContext());
-    if (auto shapedType = type.dyn_cast<ShapedType>()) {
-      if (shapedType.hasRank()) {
-        result.hasSizes = true;
-        result.sizes = shapedType.getShape();
-      }
-      result.dtype = shapedType.getElementType();
-    }
-    return result;
-  }
-
-  // Return a pessimistic/conservative value state without assuming any
-  // knowledge about the IR.
-  static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(false, {}, Type());
-  }
-
-  Type getType() const {
-    if (hasSizes) {
-      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
-    }
-    return UnrankedTensorType::get(dtype);
-  }
-
-  bool operator==(const ValueKnowledge &rhs) const {
-    return std::make_tuple(hasSizes, sizes, dtype) ==
-           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
-  }
-
-  // Given two pieces of static knowledge, calculate conservatively the
-  // information we can be sure about.
-  static ValueKnowledge join(const ValueKnowledge &lhs,
-                             const ValueKnowledge &rhs) {
-    // Mental model: All conditions are checking how to change from the safe "no
-    // knowledge" default-initialized state to a state with more knowledge
-    // consistent with lhs and rhs.
-    ValueKnowledge result = getPessimisticValueState(nullptr);
-
-    if (lhs.hasSizes && !rhs.hasSizes) {
-      result.hasSizes = true;
-      result.sizes = lhs.sizes;
-    } else if (!lhs.hasSizes && rhs.hasSizes) {
-      result.hasSizes = true;
-      result.sizes = rhs.sizes;
-    } else if (lhs.hasSizes && rhs.hasSizes &&
-               lhs.sizes.size() == rhs.sizes.size()) {
-      result.hasSizes = true;
-      result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
-      for (int i = 0, e = result.sizes.size(); i != e; i++) {
-        int64_t lhsSize = lhs.sizes[i];
-        int64_t rhsSize = rhs.sizes[i];
-        int64_t &resultSize = result.sizes[i];
-        if (lhsSize == ShapedType::kDynamicSize) {
-          resultSize = rhsSize;
-        } else if (rhsSize == ShapedType::kDynamicSize) {
-          resultSize = lhsSize;
-        } else if (lhsSize == rhsSize) {
-          resultSize = lhsSize;
-        }
-      }
-    }
-
-    result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
-    return result;
-  }
-
-  // Whether the Value is known to have a list of sizes.
-  bool hasSizes;
-  // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
-  // `ShapedType::kDynamicSize`.
-  std::vector<int64_t> sizes;
-  // The dtype of a tensor.
-  // This is equal to nullptr if we don't know that it is a specific concrete
-  // type.
-  Type dtype;
-};
-
-} // namespace
-
-/// Pass that enables broadcast by making all input arrays have the same
-/// number of dimensions. Insert RESHAPE operations to lower rank operand
-struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
-public:
-  void runOnFunction() override {
-    FuncOp func = getOperation();
-
-    IRRewriter rewriter(func.getContext());
-
-    func.walk([&](Operation *op) {
-      if (op->getDialect()->getNamespace() !=
+void propagateShapesInRegion(Region &region);
+
+void propagateShapesToTosaIf(Operation &op) {
+  tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
+  if (!ifOp)
+    return;
+
+  for (auto &region : op.getRegions()) {
+    Block &frontBlock = region.front();
+    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+      return;
+
+    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+          ifOp.getOperand(i + 1).getType());
+      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+          frontBlock.getArgument(i).getType());
+      ValueKnowledge joinedKnowledge =
+          ValueKnowledge::join(operandKnowledge, blockKnowledge);
+      if (!joinedKnowledge)
+        continue;
+      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
+    }
+
+    propagateShapesInRegion(region);
+  }
+
+  return;
+}
+
+void propagateShapesInRegion(Region &region) {
+  for (auto &block : region) {
+    for (Operation &op : block) {
+      if (op.getDialect()->getNamespace() !=
           tosa::TosaDialect::getDialectNamespace())
-        return;
+        continue;
+
+      propagateShapesToTosaIf(op);
 
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
       if (!shapeInterface)
-        return;
+        continue;
 
       SmallVector<ShapedTypeComponents> returnedShapes;
       if (shapeInterface
               .inferReturnTypeComponents(
-                  op->getContext(), op->getLoc(), op->getOperands(),
-                  op->getAttrDictionary(), op->getRegions(), returnedShapes)
+                  op.getContext(), op.getLoc(), op.getOperands(),
+                  op.getAttrDictionary(), op.getRegions(), returnedShapes)
               .succeeded()) {
-        for (auto it : llvm::zip(op->getResults(), returnedShapes)) {
+        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
           Value result = std::get<0>(it);
           ShapedTypeComponents predictedShape = std::get<1>(it);
 
@@ -183,11 +104,10 @@ public:
               ValueKnowledge::getKnowledgeFromType(resultTy);
 
           // Compute the knowledge based on the inferred type.
-          auto inferredKnowledge =
-              ValueKnowledge::getPessimisticValueState(op->getContext());
+          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
           inferredKnowledge.dtype =
               resultTy.cast<ShapedType>().getElementType();
-          inferredKnowledge.hasSizes = predictedShape.hasRank();
+          inferredKnowledge.hasRank = predictedShape.hasRank();
           if (predictedShape.hasRank()) {
             for (auto dim : predictedShape.getDims()) {
               inferredKnowledge.sizes.push_back(dim);
@@ -200,10 +120,25 @@ public:
           // Compute the new type based on the joined version.
           auto newKnowledge =
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+          if (!newKnowledge)
+            continue;
           result.setType(newKnowledge.getType());
         }
       }
-    });
+    }
+  }
+}
+
+/// Pass that performs shape propagation across TOSA operations. This includes
+/// migrating to within the regions of if/while operations.
+struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
+public:
+  void runOnFunction() override {
+    FuncOp func = getOperation();
+
+    IRRewriter rewriter(func.getContext());
+
+    propagateShapesInRegion(func.body());
+
 
     // Insert UnrealizedConversionCasts to guarantee ReturnOp agrees with
     // the FuncOp type.
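
Note the asymmetry between the two helpers: propagateShapesToTosaIf uses
ValueKnowledge::join to refine each region argument with what is known about
the matching cond_if operand (operand i + 1, since operand 0 is the i1
condition and has no block argument), and then recurses into the region, while
IfOp::inferReturnTypeComponents uses ValueKnowledge::meet across the yields,
generalizing so the inferred result type is valid for whichever branch runs.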
+62 −4 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -774,7 +774,6 @@ func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32

// -----
// CHECK-LABEL: @conv2d_strided
func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () {
  // CHECK: -> tensor<1x5x7x1xf32>
@@ -1033,12 +1032,71 @@ func @resize_fp_vertical(%arg0: tensor<1x2x4x1xi32>) {
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @resize_fp_offsetted
func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) {
  // CHECK: -> tensor<1x4x6x1xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @if_test_simple
func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
  }, {
  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
  return
}

// -----

// CHECK-LABEL: @if_test_dynamic
func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
  // CHECK: (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb1(%arg3 : tensor<2xf32>, %arg4 : tensor<3xf32>):
    "tosa.yield"(%arg3) : (tensor<2xf32>) -> ()
  }, {
  ^bb1(%arg5 : tensor<2xf32>, %arg6 : tensor<3xf32>):
    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
  }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>)
  return
}

// -----

// CHECK-LABEL: @if_test_unranked
func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
  // CHECK: (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<3xf32>):
    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
  }, {
  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<3xf32>):
    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> (tensor<*xf32>)
  return
}

// -----

// CHECK-LABEL: @if_test_propagate
func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
    %1 = "tosa.add"(%arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
  }, {
  ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
    %1 = "tosa.sub"(%arg5, %arg6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
  return
}
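
A note on reading these tests: the CHECK lines assert the operand and result
types stamped onto tosa.cond_if after propagation, and the // ----- separators
split independent inputs. The file's RUN line is outside this excerpt;
presumably it pipes each split input through mlir-opt with the new
tosa-infer-shapes pass and verifies the output with FileCheck.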