Commit b7d50ba1 authored by Stephen Neuendorffer's avatar Stephen Neuendorffer
Browse files

[MLIR] Refactor library initialization of JitRunner.

Previously, lib/Support/JitRunner.cpp was essentially a complete application,
performing all library initialization, along with dealing with command line
arguments and actually running passes.  This differs significantly from
mlir-opt and required a dependency on InitAllDialects.h.  This dependency
is significant, since it requires a dependency on all of the resulting
libraries.

This patch refactors the code so that tools are responsible for library
initialization, including registering all dialects, prior to calling
JitRunnerMain.  This places the concern about what dialect to support
with the end application, enabling more extensibility at the cost of
a small amount of code duplication between tools.  It also fixes
BUILD_SHARED_LIBS=on.

Differential Revision: https://reviews.llvm.org/D75272
parent c07fb9e0
Loading
Loading
Loading
Loading
+0 −15
Original line number Diff line number Diff line
@@ -22,7 +22,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"

@@ -34,10 +33,8 @@
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include <numeric>

@@ -108,12 +105,6 @@ static OwningModuleRef parseMLIRInput(StringRef inputFilename,
  return OwningModuleRef(parseSourceFile(sourceMgr, context));
}

// Initialize the relevant subsystems of LLVM.
static void initializeLLVM() {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
}

static inline Error make_string_error(const Twine &message) {
  return llvm::make_error<llvm::StringError>(message.str(),
                                             llvm::inconvertibleErrorCode());
@@ -210,12 +201,6 @@ static Error compileAndExecuteSingleFloatReturnFunction(
int mlir::JitRunnerMain(
    int argc, char **argv,
    function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
  registerAllDialects();
  llvm::InitLLVM y(argc, argv);

  initializeLLVM();
  mlir::initializeLLVMPasses();

  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");

  Optional<unsigned> optLevel = getCommandLineOptLevel();
+10 −0
Original line number Diff line number Diff line
@@ -12,8 +12,18 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/InitAllDialects.h"
#include "mlir/Support/JitRunner.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "mlir/ExecutionEngine/OptUtils.h"

int main(int argc, char **argv) {
  mlir::registerAllDialects();
  llvm::InitLLVM y(argc, argv);
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::initializeLLVMPasses();

  return mlir::JitRunnerMain(argc, argv, nullptr);
}
+9 −0
Original line number Diff line number Diff line
@@ -22,13 +22,17 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/JitRunner.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"

#include "cuda.h"

@@ -118,5 +122,10 @@ static LogicalResult runMLIRPasses(ModuleOp m) {

int main(int argc, char **argv) {
  registerPassManagerCLOptions();
  mlir::registerAllDialects();
  llvm::InitLLVM y(argc, argv);
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::initializeLLVMPasses();
  return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
}
+11 −0
Original line number Diff line number Diff line
@@ -19,9 +19,13 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/JitRunner.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"

using namespace mlir;

@@ -42,5 +46,12 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
int main(int argc, char **argv) {
  llvm::llvm_shutdown_obj x;
  registerPassManagerCLOptions();

  mlir::registerAllDialects();
  llvm::InitLLVM y(argc, argv);
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::initializeLLVMPasses();

  return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
}