Commit 7d1452d8 authored by River Riddle's avatar River Riddle
Browse files

[mlir] Refactor OpInterface internals to be faster and factor out common bits.

This revision adds a new support header, InterfaceSupport, to contain various generic bits of functionality for implementing "Interfaces". Interfaces embody a mechanism for attaching concept-based polymorphism to a type system. With this refactoring a new InterfaceMap type is added to allow for efficient interface lookups without going through an indirect call. This should provide a decent performance speedup without changing the size of AbstractOperation.

In a future revision, this functionality will also be used to bring Interface like functionality to Attributes and Types.

Differential Revision: https://reviews.llvm.org/D81882
parent 93bc571d
......@@ -131,14 +131,14 @@ struct ExampleOpInterfaceTraits {
/// to be overridden.
struct Concept {
virtual ~Concept();
virtual unsigned getNumInputs(Operation *op) = 0;
virtual unsigned getNumInputs(Operation *op) const = 0;
};
/// Define a model class that specializes a concept on a given operation type.
template <typename OpT>
struct Model : public Concept {
/// Override the method to dispatch on the concrete operation.
unsigned getNumInputs(Operation *op) final {
unsigned getNumInputs(Operation *op) const final {
return llvm::cast<OpT>(op).getNumInputs();
}
};
......@@ -151,7 +151,7 @@ public:
using OpInterface<ExampleOpInterface, ExampleOpInterfaceTraits>::OpInterface;
/// The interface dispatches to 'getImpl()', an instance of the concept.
unsigned getNumInputs() {
unsigned getNumInputs() const {
return getImpl()->getNumInputs(getOperation());
}
};
......
......@@ -1348,120 +1348,39 @@ private:
traitID);
}
/// Returns an opaque pointer to a concept instance of the interface with the
/// given ID if one was registered to this operation.
static void *getRawInterface(TypeID id) {
return InterfaceLookup::template lookup<Traits<ConcreteType>...>(id);
}
struct InterfaceLookup {
/// Trait to check if T provides a static 'getInterfaceID' method.
template <typename T, typename... Args>
using has_get_interface_id = decltype(T::getInterfaceID());
/// If 'T' is the same interface as 'interfaceID' return the concept
/// instance.
template <typename T>
static typename std::enable_if<
llvm::is_detected<has_get_interface_id, T>::value, void *>::type
lookup(TypeID interfaceID) {
return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
}
/// 'T' is known to not be an interface, return nullptr.
template <typename T>
static typename std::enable_if<
!llvm::is_detected<has_get_interface_id, T>::value, void *>::type
lookup(TypeID) {
return nullptr;
}
template <typename T, typename T2, typename... Ts>
static void *lookup(TypeID interfaceID) {
auto *concept = lookup<T>(interfaceID);
return concept ? concept : lookup<T2, Ts...>(interfaceID);
}
};
/// Returns an interface map for the interfaces registered to this operation.
static detail::InterfaceMap getInterfaceMap() {
return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
}
/// Allow access to 'hasTrait' and 'getRawInterface'.
/// Allow access to 'hasTrait' and 'getInterfaceMap'.
friend AbstractOperation;
};
/// This class represents the base of an operation interface. Operation
/// interfaces provide access to derived *Op properties through an opaquely
/// Operation instance. Derived interfaces must also provide a 'Traits' class
/// that defines a 'Concept' and a 'Model' class. The 'Concept' class defines an
/// abstract virtual interface, where as the 'Model' class implements this
/// interface for a specific derived *Op type. Both of these classes *must* not
/// contain non-static data. A simple example is shown below:
///
/// struct ExampleOpInterfaceTraits {
/// struct Concept {
/// virtual unsigned getNumInputs(Operation *op) = 0;
/// };
/// template <typename OpT> class Model {
/// unsigned getNumInputs(Operation *op) final {
/// return cast<OpT>(op).getNumInputs();
/// }
/// };
/// };
///
/// This class represents the base of an operation interface. See the definition
/// of `detail::Interface` for requirements on the `Traits` type.
template <typename ConcreteType, typename Traits>
class OpInterface : public Op<ConcreteType> {
class OpInterface
: public detail::Interface<ConcreteType, Operation *, Traits,
Op<ConcreteType>, OpTrait::TraitBase> {
public:
using Concept = typename Traits::Concept;
template <typename T> using Model = typename Traits::template Model<T>;
using Base = OpInterface<ConcreteType, Traits>;
using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
Op<ConcreteType>, OpTrait::TraitBase>;
OpInterface(Operation *op = nullptr)
: Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {
assert((!op || impl) &&
"instantiating an interface with an unregistered operation");
}
/// Support 'classof' by checking if the given operation defines the concrete
/// interface.
static bool classof(Operation *op) { return getInterfaceFor(op); }
/// Define an accessor for the ID of this interface.
static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
/// This is a special trait that registers a given interface with an
/// operation.
template <typename ConcreteOp>
struct Trait : public OpTrait::TraitBase<ConcreteOp, Trait> {
/// Define an accessor for the ID of this interface.
static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
/// Provide an accessor to a static instance of the interface model for the
/// concrete operation type.
/// The implementation is inspired from Sean Parent's concept-based
/// polymorphism. A key difference is that the set of classes erased is
/// statically known, which alleviates the need for using dynamic memory
/// allocation.
/// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
/// virtual table and generate a singleton object for each instantiation of
/// this class.
static Concept &instance() {
static Model<ConcreteOp> singleton;
return singleton;
}
};
protected:
/// Get the raw concept in the correct derived concept type.
Concept *getImpl() { return impl; }
/// Inherit the base class constructor.
using InterfaceBase::InterfaceBase;
private:
/// Returns the impl interface instance for the given operation.
static Concept *getInterfaceFor(Operation *op) {
static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
// Access the raw interface from the abstract operation.
auto *abstractOp = op->getAbstractOperation();
return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
}
/// A pointer to the impl concept object.
Concept *impl;
/// Allow access to `getInterfaceFor`.
friend InterfaceBase;
};
//===----------------------------------------------------------------------===//
......
......@@ -19,7 +19,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/InterfaceSupport.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
......@@ -136,8 +136,7 @@ public:
/// was registered to this operation, null otherwise. This should not be used
/// directly.
template <typename T> typename T::Concept *getInterface() const {
return reinterpret_cast<typename T::Concept *>(
getRawInterface(T::getInterfaceID()));
return interfaceMap.lookup<T>();
}
/// Returns if the operation has a particular trait.
......@@ -157,7 +156,7 @@ public:
T::getOperationName(), dialect, T::getOperationProperties(),
TypeID::get<T>(), T::parseAssembly, T::printAssembly,
T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns,
T::getRawInterface, T::hasTrait);
T::getInterfaceMap(), T::hasTrait);
}
private:
......@@ -171,22 +170,19 @@ private:
SmallVectorImpl<OpFoldResult> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context),
void *(&getRawInterface)(TypeID interfaceID),
bool (&hasTrait)(TypeID traitID))
detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID))
: name(name), dialect(dialect), typeID(typeID),
parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), foldHook(foldHook),
getCanonicalizationPatterns(getCanonicalizationPatterns),
opProperties(opProperties), getRawInterface(getRawInterface),
opProperties(opProperties), interfaceMap(std::move(interfaceMap)),
hasRawTrait(hasTrait) {}
/// The properties of the operation.
const OperationProperties opProperties;
/// Returns a raw instance of the concept for the given interface id if it is
/// registered to this operation, nullptr otherwise. This should not be used
/// directly.
void *(&getRawInterface)(TypeID interfaceID);
/// A map of interfaces that were registered to this operation.
detail::InterfaceMap interfaceMap;
/// This hook returns if the operation contains the trait corresponding
/// to the given TypeID.
......
//===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines several support classes for defining interfaces.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
#define MLIR_SUPPORT_INTERFACESUPPORT_H
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/TypeName.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace detail {
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
/// This class represents an abstract interface. An interface is a simplified
/// mechanism for attaching concept based polymorphism to a class hierarchy. An
/// interace is comprised of two components:
/// * The derived interface class: This is what users interact with, and invoke
/// methods on.
/// * An interface `Trait` class: This is the class that is attached to the
/// object implementing the interface. It is the mechanism with which models
/// are specialized.
///
/// Derived interfaces types must provide the following template types:
/// * ConcreteType: The CRTP derived type.
/// * ValueT: The opaque type the derived interface operates on. For example
/// `Operation*` for operation interfaces, or `Attribute` for
/// attribute interfaces.
/// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
/// class. The 'Concept' class defines an abstract virtual interface,
/// where as the 'Model' class implements this interface for a
/// specific derived T type. Both of these classes *must* not contain
/// non-static data. A simple example is shown below:
///
/// ```c++
/// struct ExampleInterfaceTraits {
/// struct Concept {
/// virtual unsigned getNumInputs(T t) const = 0;
/// };
/// template <typename DerivedT> class Model {
/// unsigned getNumInputs(T t) const final {
/// return cast<DerivedT>(t).getNumInputs();
/// }
/// };
/// };
/// ```
///
/// * BaseType: A desired base type for the interface. This is a class that
/// provides that provides specific functionality for the `ValueT`
/// value. For instance the specific `Op` that will wrap the
/// `Operation*` for an `OpInterface`.
/// * BaseTrait: The base type for the interface trait. This is the base class
/// to use for the interface trait that will be attached to each
/// instance of `ValueT` that implements this interface.
///
template <typename ConcreteType, typename ValueT, typename Traits,
typename BaseType,
template <typename, template <typename> class> class BaseTrait>
class Interface : public BaseType {
public:
using Concept = typename Traits::Concept;
template <typename T> using Model = typename Traits::template Model<T>;
using InterfaceBase =
Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
Interface(ValueT t = ValueT())
: BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
assert((!t || impl) &&
"instantiating an interface with an unregistered operation");
}
/// Support 'classof' by checking if the given object defines the concrete
/// interface.
static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
/// Define an accessor for the ID of this interface.
static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
/// This is a special trait that registers a given interface with an object.
template <typename ConcreteT>
struct Trait : public BaseTrait<ConcreteT, Trait> {
/// Define an accessor for the ID of this interface.
static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
/// Provide an accessor to a static instance of the interface model for the
/// concrete T type.
/// The implementation is inspired from Sean Parent's concept-based
/// polymorphism. A key difference is that the set of classes erased is
/// statically known, which alleviates the need for using dynamic memory
/// allocation.
/// We use a zero-sized templated class `Model<ConcreteT>` to emit the
/// virtual table and generate a singleton object for each instantiation of
/// this class.
static Concept &instance() {
static Model<ConcreteT> singleton;
return singleton;
}
};
protected:
/// Get the raw concept in the correct derived concept type.
const Concept *getImpl() const { return impl; }
Concept *getImpl() { return impl; }
private:
/// A pointer to the impl concept object.
Concept *impl;
};
//===----------------------------------------------------------------------===//
// InterfaceMap
//===----------------------------------------------------------------------===//
/// This class provides an efficient mapping between a given `Interface` type,
/// and a particular implementation of its concept.
class InterfaceMap {
public:
/// Construct an InterfaceMap with the given set of template types. For
/// convenience given that object trait lists may contain other non-interface
/// types, not all of the types need to be interfaces. The provided types that
/// do not represent interfaces are not added to the interface map.
template <typename... Types> static InterfaceMap get() {
return InterfaceMap(MapBuilder::create<Types...>());
}
/// Returns an instance of the concept object for the given interface if it
/// was registered to this map, null otherwise.
template <typename T> typename T::Concept *lookup() const {
if (!interfaces)
return nullptr;
return reinterpret_cast<typename T::Concept *>(
interfaces->lookup(T::getInterfaceID()));
}
private:
/// This struct provides support for building a map of interfaces.
class MapBuilder {
public:
template <typename... Types>
static std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> create() {
// Filter the provided types for those that are interfaces. This reduces
// the amount of maps that are generated.
return createImpl((typename FilterTypes<detect_get_interface_id,
Types...>::type *)nullptr);
}
private:
/// Trait to check if T provides a static 'getInterfaceID' method.
template <typename T, typename... Args>
using has_get_interface_id = decltype(T::getInterfaceID());
template <typename T>
using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
/// Utility to filter a given sequence of types base upon a predicate.
template <bool> struct FilterTypeT {
template <class E> using type = std::tuple<E>;
};
template <> struct FilterTypeT<false> {
template <class E> using type = std::tuple<>;
};
template <template <class> class Pred, class... Es> struct FilterTypes {
using type = decltype(std::tuple_cat(
std::declval<
typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
};
template <typename... Ts>
static std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>>
createImpl(std::tuple<Ts...> *) {
// Only create an instance of the map if there are any interface types.
if (sizeof...(Ts) == 0)
return std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>>();
auto map = std::make_unique<llvm::SmallDenseMap<TypeID, void *>>();
(void)std::initializer_list<int>{
0, (map->try_emplace(Ts::getInterfaceID(), &Ts::instance()), 0)...};
return map;
}
};
private:
InterfaceMap(std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces)
: interfaces(std::move(interfaces)) {}
/// The internal map of interfaces. This is constructed statically for each
/// set of interfaces.
std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces;
};
} // end namespace detail
} // end namespace mlir
#endif
......@@ -560,9 +560,10 @@ void Dialect::addOperation(AbstractOperation opInfo) {
auto &impl = context->getImpl();
// Lock access to the context registry.
StringRef opName = opInfo.name;
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) {
llvm::errs() << "error: operation named '" << opInfo.name
if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
llvm::errs() << "error: operation named '" << opName
<< "' is already registered.\n";
abort();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment