Commit e0f4e9e4 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

update the mlir api to accept an llvm opt level

parent b44adf06
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -41,3 +41,15 @@ add_executable(qasm3CompilerTester_Subroutine test_subroutine.cpp)
add_test(NAME qcor_qasm3_test_subroutine COMMAND qasm3CompilerTester_Subroutine)
target_include_directories(qasm3CompilerTester_Subroutine PRIVATE . ../../ ${XACC_ROOT}/include/gtest)
target_link_libraries(qasm3CompilerTester_Subroutine qcor-mlir-api gtest gtest_main)


#add_executable(qasm3CompilerTester_GlobalConstInSubroutine test_use_global_const_in_subroutine.cpp)
#add_test(NAME qcor_qasm3_test_global_const_subroutine COMMAND qasm3CompilerTester_GlobalConstInSubroutine)
#target_include_directories(qasm3CompilerTester_GlobalConstInSubroutine PRIVATE . ../../ ${XACC_ROOT}/include/gtest)
#target_link_libraries(qasm3CompilerTester_GlobalConstInSubroutine qcor-mlir-api gtest gtest_main)


add_executable(qasm3CompilerTester_Arithmetic test_complex_arithmetic.cpp)
add_test(NAME qcor_qasm3_test_arithmetic COMMAND qasm3CompilerTester_Arithmetic)
target_include_directories(qasm3CompilerTester_Arithmetic PRIVATE . ../../ ${XACC_ROOT}/include/gtest)
target_link_libraries(qasm3CompilerTester_Arithmetic qcor-mlir-api gtest gtest_main)
+25 −0
Original line number Diff line number Diff line
#include "gtest/gtest.h"
#include "qcor_mlir_api.hpp"

TEST(qasm3VisitorTester, checkArithmetic) {
  const std::string global_const = R"#(OPENQASM 3;
include "qelib1.inc";
const shots = 1024.0;
float[64] num_parity_ones = 508.0;
float[64] result, test;

result = (shots - num_parity_ones) / shots - num_parity_ones / shots;
test = result - .007812;
QCOR_EXPECT_TRUE(test < .01);
)#";
  auto mlir = qcor::mlir_compile("qasm3", global_const, "global_const",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  EXPECT_FALSE(qcor::execute("qasm3", global_const, "global_const"));
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
  return ret;
}
+28 −7
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ const std::string mlir_compile(const std::string &src_language_type,
                               const std::string &src,
                               const std::string &kernel_name,
                               const OutputType &output_type,
                               bool add_entry_point) {
                               bool add_entry_point, int opt_level) {
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();

@@ -84,7 +84,7 @@ const std::string mlir_compile(const std::string &src_language_type,
  // Optimize the LLVM IR
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  auto optPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
  auto optPipeline = mlir::makeOptimizingTransformer(opt_level, 0, nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return "";
@@ -103,7 +103,7 @@ const std::string mlir_compile(const std::string &src_language_type,
}

int execute(const std::string &src_language_type, const std::string &src,
             const std::string &kernel_name) {
            const std::string &kernel_name, int opt_level) {
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();

@@ -129,6 +129,27 @@ int execute(const std::string &src_language_type, const std::string &src,
  unique_function_names = mlir_generator->seen_function_names();
  auto module = mlir_generator->get_module();

DiagnosticEngine& engine = context.getDiagEngine();

/// Handle the reported diagnostic.
  // Return success to signal that the diagnostic has either been fully
  // processed, or failure if the diagnostic should be propagated to the
  // previous handlers.
  DiagnosticEngine::HandlerID id =
      engine.registerHandler([&](Diagnostic &diag) -> LogicalResult {
        std::cout << "Dumping Module after error.\n";
        module->dump();
        for (auto &n : diag.getNotes()) {
          std::string s;
          llvm::raw_string_ostream os(s);
          n.print(os);
          os.flush();
          std::cout << "DiagnosticEngine Note: " << s << "\n";
        }
        bool should_propagate_diagnostic = true;
        return failure(should_propagate_diagnostic);
      });

  // Create the PassManager for lowering to LLVM MLIR and run it
  mlir::PassManager pm(&context);
  pm.addPass(
@@ -146,7 +167,7 @@ int execute(const std::string &src_language_type, const std::string &src,
  // Optimize the LLVM IR
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  auto optPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
  auto optPipeline = mlir::makeOptimizingTransformer(opt_level, 0, nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return 1;
+3 −4
Original line number Diff line number Diff line
@@ -10,9 +10,8 @@ const std::string mlir_compile(const std::string& src_language_type,
                               const std::string& src,
                               const std::string& kernel_name,
                               const OutputType& output_type,
                               bool add_entry_point);
                               bool add_entry_point, int opt_level = 3);

int execute(const std::string& src_language_type,
                               const std::string& src,
                               const std::string& kernel_name);
int execute(const std::string& src_language_type, const std::string& src,
            const std::string& kernel_name, int opt_level = 3);
}  // namespace qcor
 No newline at end of file
+21 −0
Original line number Diff line number Diff line
@@ -104,6 +104,27 @@ int main(int argc, char **argv) {
    return 0;
  }

  DiagnosticEngine& engine = context.getDiagEngine();

  /// Handle the reported diagnostic.
  // Return success to signal that the diagnostic has either been fully
  // processed, or failure if the diagnostic should be propagated to the
  // previous handlers.
  DiagnosticEngine::HandlerID id =
      engine.registerHandler([&](Diagnostic &diag) -> LogicalResult {
        std::cout << "Dumping Module after error.\n";
        module->dump();
        for (auto &n : diag.getNotes()) {
          std::string s;
          llvm::raw_string_ostream os(s);
          n.print(os);
          os.flush();
          std::cout << "DiagnosticEngine Note: " << s << "\n";
        }
        bool should_propagate_diagnostic = true;
        return failure(should_propagate_diagnostic);
      });

  // Create the PassManager for lowering to LLVM MLIR and run it
  mlir::PassManager pm(&context);
  applyPassManagerCLOptions(pm);