Commit ed971c76 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

updated mlir work with latest from mlir master

parent ef1ce545
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ mkdir llvm_mlir/build
cd llvm_mlir/build
cmake -G Ninja ../llvm \
   -DLLVM_ENABLE_PROJECTS=mlir \
   -DBUILD_SHARED_LIBS=TRUE
   -DLLVM_BUILD_EXAMPLES=ON \
   -DLLVM_TARGETS_TO_BUILD="X86" \
   -DCMAKE_BUILD_TYPE=Release \
+5 −5
Original line number Diff line number Diff line
@@ -52,14 +52,14 @@ void OpenQasmMLIRGenerator::visit(Program &prog) {
  llvm::StringRef qubit_type_name("Qubit"), array_type_name("Array"),
      result_type_name("Result");
  mlir::Identifier dialect = mlir::Identifier::get("quantum", &context);
  qubit_type = mlir::OpaqueType::get(dialect, qubit_type_name, &context);
  array_type = mlir::OpaqueType::get(dialect, array_type_name, &context);
  result_type = mlir::OpaqueType::get(dialect, result_type_name, &context);
  qubit_type = mlir::OpaqueType::get(&context, dialect, qubit_type_name);
  array_type = mlir::OpaqueType::get(&context, dialect, array_type_name);
  result_type = mlir::OpaqueType::get(&context, dialect, result_type_name);
  auto int_type = builder.getI32Type();
  auto argv_type =
      mlir::OpaqueType::get(dialect, llvm::StringRef("ArgvType"), &context);
      mlir::OpaqueType::get(&context, dialect, llvm::StringRef("ArgvType"));
  auto qreg_type =
      mlir::OpaqueType::get(dialect, llvm::StringRef("qreg"), &context);
      mlir::OpaqueType::get(&context, dialect, llvm::StringRef("qreg"));

  if (add_main) {
    std::vector<mlir::Type> arg_types_vec2{};
+77 −68
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ namespace {
using namespace mlir;
std::map<std::string, std::string> inst_map{{"cx", "cnot"}, {"measure", "mz"}};

mlir::LLVM::LLVMType get_quantum_type(std::string type,
mlir::Type get_quantum_type(std::string type,
                                      mlir::MLIRContext *context) {
  return LLVM::LLVMStructType::getOpaque(type, context);
}
@@ -71,10 +71,11 @@ class QallocOpLowering : public ConversionPattern {
      symbol_ref = SymbolRefAttr::get(qir_qubit_array_allocate, context);
    } else {
      // prototype is (size : int64) -> Array* : qubit_array_ptr
      auto qubit_type = LLVM::LLVMType::getInt64Ty(context);
      auto array_qbit_type = get_quantum_type("Array", context).getPointerTo();
      auto qubit_type = IntegerType::get(context, 64);
      auto array_qbit_type =
          LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
      auto qalloc_ftype =
          LLVM::LLVMType::getFunctionTy(array_qbit_type, qubit_type, false);
          LLVM::LLVMFunctionType::get(array_qbit_type, qubit_type, false);

      // Insert the function declaration
      PatternRewriter::InsertionGuard insertGuard(rewriter);
@@ -94,9 +95,10 @@ class QallocOpLowering : public ConversionPattern {
    // size_value = constantop (size)
    // qubit_array_ptr = callop ( size_value )
    Value create_size_int = rewriter.create<LLVM::ConstantOp>(
        loc, LLVM::LLVMType::getInt64Ty(rewriter.getContext()),
        loc, IntegerType::get(rewriter.getContext(), 64),
        rewriter.getIntegerAttr(rewriter.getI64Type(), size));
    auto array_qbit_type = get_quantum_type("Array", context).getPointerTo();
    auto array_qbit_type =
        LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
    auto qalloc_qir_call = rewriter.create<mlir::CallOp>(
        loc, symbol_ref, array_qbit_type, ArrayRef<Value>({create_size_int}));

@@ -150,10 +152,11 @@ class DeallocOpLowering : public ConversionPattern {
      symbol_ref = SymbolRefAttr::get(qir_qubit_array_deallocate, context);
    } else {
      // prototype is (Array*) -> void
      auto void_type = LLVM::LLVMType::getVoidTy(context);
      auto array_qbit_type = get_quantum_type("Array", context).getPointerTo();
      auto void_type = LLVM::LLVMVoidType::get(context);
      auto array_qbit_type =
          LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
      auto dealloc_ftype =
          LLVM::LLVMType::getFunctionTy(void_type, array_qbit_type, false);
          LLVM::LLVMFunctionType::get(void_type, array_qbit_type, false);

      // Insert the function declaration
      PatternRewriter::InsertionGuard insertGuard(rewriter);
@@ -174,7 +177,7 @@ class DeallocOpLowering : public ConversionPattern {
    // create a CallOp for the new quantum runtime de-allocation
    // function.
    rewriter.create<mlir::CallOp>(loc, symbol_ref,
                                  LLVM::LLVMType::getVoidTy(context),
                                  LLVM::LLVMVoidType::get(context),
                                  ArrayRef<Value>({qubits}));

    // Remove the old QuantumDialect QallocOp
@@ -219,11 +222,12 @@ class QRTInitOpLowering : public ConversionPattern {
      symbol_ref = SymbolRefAttr::get(qir_qrt_initialize, context);
    } else {
      // prototype is (Array*) -> void
      auto int_type = LLVM::LLVMType::getInt32Ty(context);
      std::vector<LLVM::LLVMType> arg_types{
          LLVM::LLVMType::getInt32Ty(context),
          LLVM::LLVMType::getInt8PtrTy(context).getPointerTo()};
      auto init_ftype = LLVM::LLVMType::getFunctionTy(
      auto int_type = IntegerType::get(context, 32);
      std::vector<mlir::Type> arg_types{
          IntegerType::get(context, 32),
          LLVM::LLVMPointerType::get(
              LLVM::LLVMPointerType::get(IntegerType::get(context, 8)))};
      auto init_ftype = LLVM::LLVMFunctionType::get(
          int_type, llvm::makeArrayRef(arg_types), false);

      // Insert the function declaration
@@ -244,7 +248,7 @@ class QRTInitOpLowering : public ConversionPattern {
    // create a CallOp for the new quantum runtime initialize
    // function.
    rewriter.create<mlir::CallOp>(
        loc, symbol_ref, LLVM::LLVMType::getInt32Ty(context),
        loc, symbol_ref, IntegerType::get(context, 32),
        ArrayRef<Value>({variables["main_argc"], variables["main_argv"]}));

    // Remove the old QuantumDialect QallocOp
@@ -288,9 +292,9 @@ class QRTFinalizeOpLowering : public ConversionPattern {
      symbol_ref = SymbolRefAttr::get(qir_qrt_finalize, context);
    } else {
      // prototype is () -> void
      auto void_type = LLVM::LLVMType::getVoidTy(context);
      std::vector<LLVM::LLVMType> arg_types;
      auto init_ftype = LLVM::LLVMType::getFunctionTy(
      auto void_type = LLVM::LLVMVoidType::get(context);
      std::vector<mlir::Type> arg_types;
      auto init_ftype = LLVM::LLVMFunctionType::get(
          void_type, llvm::makeArrayRef(arg_types), false);

      // Insert the function declaration
@@ -310,9 +314,8 @@ class QRTFinalizeOpLowering : public ConversionPattern {
    // args[1].dump();
    // create a CallOp for the new quantum runtime initialize
    // function.
    rewriter.create<mlir::CallOp>(loc, symbol_ref,
                                  LLVM::LLVMType::getVoidTy(context),
                                  ArrayRef<Value>({}));
    rewriter.create<mlir::CallOp>(
        loc, symbol_ref, LLVM::LLVMVoidType::get(context), ArrayRef<Value>({}));

    // Remove the old QuantumDialect QallocOp
    rewriter.eraseOp(op);
@@ -356,10 +359,10 @@ class SetQregOpLowering : public ConversionPattern {
      symbol_ref = SymbolRefAttr::get(qir_qrt_finalize, context);
    } else {
      // prototype is () -> void
      auto void_type = LLVM::LLVMType::getVoidTy(context);
      std::vector<LLVM::LLVMType> arg_types{
          get_quantum_type("qreg", context).getPointerTo()};
      auto init_ftype = LLVM::LLVMType::getFunctionTy(
      auto void_type = LLVM::LLVMVoidType::get(context);
      std::vector<mlir::Type> arg_types{
          LLVM::LLVMPointerType::get(get_quantum_type("qreg", context))};
      auto init_ftype = LLVM::LLVMFunctionType::get(
          void_type, llvm::makeArrayRef(arg_types), false);

      // Insert the function declaration
@@ -380,7 +383,7 @@ class SetQregOpLowering : public ConversionPattern {
    // create a CallOp for the new quantum runtime initialize
    // function.
    rewriter.create<mlir::CallOp>(
        loc, symbol_ref, LLVM::LLVMType::getVoidTy(context),
        loc, symbol_ref, LLVM::LLVMVoidType::get(context),
        ArrayRef<Value>({variables["_incoming_qreg_variable"]}));

    // Remove the old QuantumDialect QallocOp
@@ -407,14 +410,17 @@ class QuantumFuncArgConverter : public ConversionPattern {
      if (type.isa<mlir::OpaqueType>()) {
        auto casted = type.cast<mlir::OpaqueType>();
        if (casted.getTypeData() == "Qubit") {
          return get_quantum_type("Qubit", this->context).getPointerTo();
          return LLVM::LLVMPointerType::get(
              get_quantum_type("Qubit", this->context));
        } else if (casted.getTypeData() == "ArgvType") {
          return LLVM::LLVMType::getInt8PtrTy(context).getPointerTo();
          return LLVM::LLVMPointerType::get(
              LLVM::LLVMPointerType::get(IntegerType::get(context, 8)));
        } else if (casted.getTypeData() == "qreg") {
          return get_quantum_type("qreg", this->context).getPointerTo();
          return LLVM::LLVMPointerType::get(
              get_quantum_type("qreg", this->context));
        }
      } else if (type.isa<mlir::IntegerType>()) {
        return LLVM::LLVMType::getInt32Ty(this->context);
        return IntegerType::get(this->context, 32);
      }
      return llvm::None;
    });
@@ -432,12 +438,13 @@ class QuantumFuncArgConverter : public ConversionPattern {
    auto func_name = funcOp.getName().str();

    if (func_name == "main") {
      auto charstarstar = LLVM::LLVMType::getInt8PtrTy(context).getPointerTo();
      std::vector<LLVM::LLVMType> tmp_arg_types{
          LLVM::LLVMType::getInt32Ty(context), charstarstar};
      auto charstarstar = LLVM::LLVMPointerType::get(
          LLVM::LLVMPointerType::get(IntegerType::get(context, 8)));
      std::vector<Type> tmp_arg_types{IntegerType::get(context, 32),
                                      charstarstar};

      auto new_main_signature = LLVM::LLVMType::getFunctionTy(
          LLVM::LLVMType::getInt32Ty(context),
      auto new_main_signature =
          LLVM::LLVMFunctionType::get(IntegerType::get(context, 32),
                                      llvm::makeArrayRef(tmp_arg_types), false);

      auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, funcOp.sym_name(),
@@ -461,12 +468,12 @@ class QuantumFuncArgConverter : public ConversionPattern {
    if (ftype.getNumInputs() == 1 &&
        ftype.getInput(0).isa<mlir::OpaqueType>() &&
        ftype.getInput(0).cast<mlir::OpaqueType>().getTypeData() == "qreg") {
      std::vector<LLVM::LLVMType> tmp_arg_types{
          get_quantum_type("qreg", context).getPointerTo()};
      std::vector<mlir::Type> tmp_arg_types{
          LLVM::LLVMPointerType::get(get_quantum_type("qreg", context))};

      auto new_func_signature = LLVM::LLVMType::getFunctionTy(
          LLVM::LLVMType::getVoidTy(context), llvm::makeArrayRef(tmp_arg_types),
          false);
      auto new_func_signature =
          LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
                                 llvm::makeArrayRef(tmp_arg_types), false);

      auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, funcOp.sym_name(),
                                                         new_func_signature);
@@ -488,15 +495,14 @@ class QuantumFuncArgConverter : public ConversionPattern {

    // Not main, sub quantum kernel, convert Qubit to Qubit*
    if (ftype.getNumInputs() > 0) {
      std::vector<LLVM::LLVMType> tmp_arg_types;
      std::vector<mlir::Type> tmp_arg_types;
      for (unsigned i = 0; i < ftype.getNumInputs(); i++) {
        tmp_arg_types.push_back(
            get_quantum_type("Qubit", context).getPointerTo());
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context)));
      }

      auto new_func_signature = LLVM::LLVMType::getFunctionTy(
          LLVM::LLVMType::getVoidTy(context), llvm::makeArrayRef(tmp_arg_types),
          false);
      auto new_func_signature = LLVM::LLVMFunctionType::get(
          LLVM::LLVMVoidType::get(context), llvm::makeArrayRef(tmp_arg_types), false);

      auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, funcOp.sym_name(),
                                                         new_func_signature);
@@ -606,19 +612,19 @@ class InstOpLowering : public ConversionPattern {
      q_symbol_ref = SymbolRefAttr::get(q_function_name, context);
    } else {
      // Return type should be void except for mz, which should be int64
      LLVM::LLVMType ret_type = LLVM::LLVMType::getVoidTy(context);
      mlir::Type ret_type = LLVM::LLVMVoidType::get(context);
      if (inst_name == "mz") {
        ret_type = get_quantum_type("Result", context).getPointerTo();
        // LLVM::LLVMType::getInt64Ty(context);
        ret_type =
            LLVM::LLVMPointerType::get(get_quantum_type("Result", context));
      }

      // Create Types for all function arguments, start with
      // double parameters (if instOp has them)
      std::vector<LLVM::LLVMType> tmp_arg_types;
      std::vector<Type> tmp_arg_types;
      if (instOp.params()) {
        auto params = instOp.params().getValue();
        for (int i = 0; i < params.size(); i++) {
          auto param_type = LLVM::LLVMType::getDoubleTy(context);
          auto param_type = FloatType::getF64(context);
          tmp_arg_types.push_back(param_type);
        }
      }
@@ -626,13 +632,13 @@ class InstOpLowering : public ConversionPattern {
      // Now, we need a Int64Type for each qubit argument
      for (std::size_t i = 0; i < operands.size(); i++) {
        auto qubit_index_type =
            get_quantum_type("Qubit", context).getPointerTo();
        // LLVM::LLVMType::getInt64Ty(context).getPointerTo();
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context));
        // IntegerType::get(context, 64).getPointerTo();
        tmp_arg_types.push_back(qubit_index_type);
      }

      // Create the LLVM FunctionType
      auto get_ptr_qbit_ftype = LLVM::LLVMType::getFunctionTy(
      auto get_ptr_qbit_ftype = LLVM::LLVMFunctionType::get(
          ret_type, llvm::makeArrayRef(tmp_arg_types), false);

      // Insert the function since it hasn't been seen yet
@@ -655,8 +661,7 @@ class InstOpLowering : public ConversionPattern {
            mlir::FloatAttr::get(rewriter.getF64Type(), param_double);

        Value const_double_op = rewriter.create<LLVM::ConstantOp>(
            loc, LLVM::LLVMType::getDoubleTy(rewriter.getContext()),
            double_attr);
            loc, FloatType::getF64(rewriter.getContext()), double_attr);

        func_args.push_back(const_double_op);
      }
@@ -668,9 +673,10 @@ class InstOpLowering : public ConversionPattern {
    }

    // once again, return type should be void unless its a measure
    LLVM::LLVMType ret_type = LLVM::LLVMType::getVoidTy(context);
    mlir::Type ret_type = LLVM::LLVMVoidType::get(context);
    if (inst_name == "mz") {
      ret_type = get_quantum_type("Result", context).getPointerTo();
      ret_type =
          LLVM::LLVMPointerType::get(get_quantum_type("Result", context));
    }

    // Create the CallOp for this quantum instruction
@@ -750,16 +756,16 @@ class ExtractQubitOpConversion : public ConversionPattern {
      symbol_ref = SymbolRefAttr::get(qir_get_qubit_from_array, context);
    } else {
      // prototype should be (int64* : qreg, int64 : element) -> int64* : qubit
      auto qubit_array_type = get_quantum_type("Array", context).getPointerTo();
      auto qubit_index_type = LLVM::LLVMType::getInt64Ty(context);
      auto qubit_array_type =
          LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
      auto qubit_index_type = IntegerType::get(context, 64);
      // ret is i8*
      auto qbit_element_ptr_type =
          LLVM::LLVMType::getInt8Ty(context).getPointerTo();
          LLVM::LLVMPointerType::get(IntegerType::get(context, 8));

      auto get_ptr_qbit_ftype = LLVM::LLVMType::getFunctionTy(
      auto get_ptr_qbit_ftype = LLVM::LLVMFunctionType::get(
          qbit_element_ptr_type,
          llvm::ArrayRef<LLVM::LLVMType>{qubit_array_type, qubit_index_type},
          false);
          llvm::ArrayRef<Type>{qubit_array_type, qubit_index_type}, false);

      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(parentModule.getBody());
@@ -770,17 +776,20 @@ class ExtractQubitOpConversion : public ConversionPattern {
    }

    // Create the CallOp for the get element ptr 1d function
    auto array_qbit_type = LLVM::LLVMType::getInt8Ty(context).getPointerTo();
    auto array_qbit_type =
        LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
    auto get_qbit_qir_call = rewriter.create<mlir::CallOp>(
        location, symbol_ref, array_qbit_type,
        ArrayRef<Value>({vars[qreg_name], adaptor.idx()}));

    auto bitcast = rewriter.create<LLVM::BitcastOp>(
        location,
        get_quantum_type("Qubit", context).getPointerTo().getPointerTo(),
        LLVM::LLVMPointerType::get(
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))),
        get_qbit_qir_call.getResult(0));
    auto real_casted_qubit = rewriter.create<LLVM::LoadOp>(
        location, get_quantum_type("Qubit", context).getPointerTo(),
        location,
        LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context)),
        bitcast.res());

    // Remember the variable name for this qubit
+13 −8
Original line number Diff line number Diff line
@@ -255,11 +255,12 @@ class LLVMJIT {
          std::unique_ptr<LLVMContext> ctx = std::make_unique<LLVMContext>())
      : ObjectLayer(ES,
                    []() { return std::make_unique<SectionMemoryManager>(); }),
        CompileLayer(ES, ObjectLayer, ConcurrentIRCompiler(std::move(JTMB))),
        CompileLayer(ES, ObjectLayer,
                     std::make_unique<ConcurrentIRCompiler>(std::move(JTMB))),
        DL(std::move(DL)),
        Mangle(ES, this->DL),
        Ctx(std::move(ctx)),
        MainJD(ES.createJITDylib("<main>")) {
        MainJD(ES.createBareJITDylib("<main>")) {
    MainJD.addGenerator(
        cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
            DL.getGlobalPrefix())));
@@ -294,17 +295,21 @@ class LLVMJIT {
  LLVMContext &getContext() { return *Ctx.getContext(); }

  Error addModule(std::unique_ptr<llvm::Module> M) {
    // FIXME hook up to cmake
    MainJD.addGenerator(cantFail(DynamicLibrarySearchGenerator::Load(
        "@XACC_ROOT@/lib/libxacc@CMAKE_SHARED_LIBRARY_SUFFIX@", DL.getGlobalPrefix())));
        "@XACC_ROOT@/lib/libxacc@CMAKE_SHARED_LIBRARY_SUFFIX@",
        DL.getGlobalPrefix())));
    MainJD.addGenerator(cantFail(DynamicLibrarySearchGenerator::Load(
        "@CMAKE_INSTALL_PREFIX@/lib/libqrt@CMAKE_SHARED_LIBRARY_SUFFIX@", DL.getGlobalPrefix())));
        "@CMAKE_INSTALL_PREFIX@/lib/libqrt@CMAKE_SHARED_LIBRARY_SUFFIX@",
        DL.getGlobalPrefix())));
    MainJD.addGenerator(cantFail(DynamicLibrarySearchGenerator::Load(
        "@CMAKE_INSTALL_PREFIX@/lib/libqcor@CMAKE_SHARED_LIBRARY_SUFFIX@", DL.getGlobalPrefix())));
        "@CMAKE_INSTALL_PREFIX@/lib/libqcor@CMAKE_SHARED_LIBRARY_SUFFIX@",
        DL.getGlobalPrefix())));
    MainJD.addGenerator(cantFail(DynamicLibrarySearchGenerator::Load(
        "@XACC_ROOT@/lib/libCppMicroServices@CMAKE_SHARED_LIBRARY_SUFFIX@", DL.getGlobalPrefix())));
        "@XACC_ROOT@/lib/libCppMicroServices@CMAKE_SHARED_LIBRARY_SUFFIX@",
        DL.getGlobalPrefix())));

    return CompileLayer.add(MainJD, ThreadSafeModule(std::move(M), Ctx));
    auto rt = MainJD.getDefaultResourceTracker();
    return CompileLayer.add(rt, ThreadSafeModule(std::move(M), Ctx));
  }

  Expected<JITEvaluatedSymbol> lookup(StringRef Name) {
+1 −1
Original line number Diff line number Diff line
add_subdirectory(clang-wrapper)
add_subdirectory(qopt)
#add_subdirectory(qopt)
add_subdirectory(driver)
 No newline at end of file