Commit 196bb3ed authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

adding first pass at kernel support for qasm3 compiler

parent a361900f
Loading
Loading
Loading
Loading
+36 −25
Original line number Diff line number Diff line
@@ -13,6 +13,9 @@
using namespace clang;

namespace qcor {
namespace __internal__developer__flags__ {
bool add_predefines = true;
}

bool qrt = false;
std::string qpu_name = "qpp";
@@ -331,9 +334,10 @@ void QCORSyntaxHandler::GetReplacement(
  OS << "}\n";

  if (add_het_map_ctor) {
    // Remove "&" from type string before getting the Python variables in the HetMap.
    // Note: HetMap can't store references.
    const auto remove_ref_arg_type = [](const std::string &org_arg_type) -> std::string {
    // Remove "&" from type string before getting the Python variables in the
    // HetMap. Note: HetMap can't store references.
    const auto remove_ref_arg_type =
        [](const std::string &org_arg_type) -> std::string {
      // We intentially only support a very limited set of pass-by-ref types
      // from the HetMap.
      // Only do: double& and int&
@@ -369,19 +373,23 @@ void QCORSyntaxHandler::GetReplacement(
      // If this is a *supported* ref types: double&, int&, etc.
      if (remove_ref_arg_type(program_arg_types[i]) != program_arg_types[i]) {
        // Generate a temp var
        const std::string new_var_name = "__temp_var__" + std::to_string(var_counter++);
        const std::string new_var_name =
            "__temp_var__" + std::to_string(var_counter++);
        // Copy the var from HetMap to the temp var
        ref_type_copy_decl_ss << remove_ref_arg_type(program_arg_types[i]) << " "<< new_var_name << " = " << "args.get<" << remove_ref_arg_type(program_arg_types[i]) << ">(\""
         << program_parameters[i] << "\");\n";
        ref_type_copy_decl_ss << remove_ref_arg_type(program_arg_types[i])
                              << " " << new_var_name << " = "
                              << "args.get<"
                              << remove_ref_arg_type(program_arg_types[i])
                              << ">(\"" << program_parameters[i] << "\");\n";

        // We just pass this copied var to the ctor
        // where it expects a reference type.
        arg_ctor_list.emplace_back(new_var_name);
      }
      else {
      } else {
        // Otherwise, just unpack the arg inline in the ctor call.
        std::stringstream ss;
        ss << "args.get<" << program_arg_types[i] << ">(\""<< program_parameters[i] << "\")";
        ss << "args.get<" << program_arg_types[i] << ">(\""
           << program_parameters[i] << "\")";
        arg_ctor_list.emplace_back(ss.str());
      }
    }
@@ -394,8 +402,8 @@ void QCORSyntaxHandler::GetReplacement(
    // CTor call
    OS << "class " << kernel_name << " __ker__temp__(";
    // First arg: qreg
    OS << "args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    OS << "args.get<" << program_arg_types[0] << ">(\"" << program_parameters[0]
       << "\")";
    // The rest: either inline unpacking or temp var names (ref type)
    for (const auto &arg_str : arg_ctor_list) {
      OS << ", " << arg_str;
@@ -404,14 +412,15 @@ void QCORSyntaxHandler::GetReplacement(
    OS << "}\n";

    OS << "void " << kernel_name
       << "__with_parent_and_hetmap_args(std::shared_ptr<CompositeInstruction> parent, "
       << "__with_parent_and_hetmap_args(std::shared_ptr<CompositeInstruction> "
          "parent, "
          "HeterogeneousMap& args) {\n";
    OS << ref_type_copy_decl_ss.str();
    // CTor call with parent kernel
    OS << "class " << kernel_name << " __ker__temp__(parent, ";
    // Second arg: qreg
    OS << "args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    OS << "args.get<" << program_arg_types[0] << ">(\"" << program_parameters[0]
       << "\")";
    // The rest: either inline unpacking or temp var names (ref type)
    for (const auto &arg_str : arg_ctor_list) {
      OS << ", " << arg_str;
@@ -425,10 +434,12 @@ void QCORSyntaxHandler::GetReplacement(
}

void QCORSyntaxHandler::AddToPredefines(llvm::raw_string_ostream &OS) {
  if (__internal__developer__flags__::add_predefines) {
    OS << "#include \"qcor.hpp\"\n";
    OS << "using namespace qcor;\n";
    OS << "using namespace xacc::internal_compiler;\n";
  }
}

class DoNothingConsumer : public ASTConsumer {
 public:
+6 −0
Original line number Diff line number Diff line
@@ -8,6 +8,12 @@ namespace qcor {
extern std::string qpu_name;
extern int shots;

// Add this for internal development, specifically JIT tests
// where I don't want AddPredefines to add qcor.hpp. For example
// where I want to compile a simple c++ code with no dependencies, 
// I don't want to include qcor.hpp bc it makes it much slower.
namespace __internal__developer__flags__ { extern bool add_predefines;}

class QCORSyntaxHandler : public SyntaxHandler {
public:
  QCORSyntaxHandler() : SyntaxHandler("qcor") {}
+1 −0
Original line number Diff line number Diff line
@@ -47,6 +47,7 @@ set(LIBS
        MLIRExecutionEngine
        MLIRStandard
        MLIRAffine
        LLVMLinker
        openqasm-mlir-generator
        openqasmv3-mlir-generator
        quantum-to-llvm-lowering
+0 −1
Original line number Diff line number Diff line
@@ -45,7 +45,6 @@ h counting;
// Loop over and create ctrl-U**2k
int repetitions = 1;
for i in [0:n_counting] {
    print("i is ", i, repetitions);
    ctrl @ pow(repetitions) @ oracle counting[i], state;
    repetitions *= 2;
}
+15 −11
Original line number Diff line number Diff line
@@ -29,8 +29,8 @@ void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point,

  if (add_main) {
    std::vector<mlir::Type> arg_types_vec2{};
    auto func_type2 =
        builder.getFunctionType(llvm::makeArrayRef(arg_types_vec2), builder.getI32Type());
    auto func_type2 = builder.getFunctionType(
        llvm::makeArrayRef(arg_types_vec2), builder.getI32Type());
    auto proto2 = mlir::FuncOp::create(
        builder.getUnknownLoc(), "__internal_mlir_" + file_name, func_type2);
    mlir::FuncOp function2(proto2);
@@ -52,20 +52,22 @@ void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point,
                                               main_args[0], main_args[1]);

      // call the function from main, run finalize, and return 0
      auto call_internal = builder.create<mlir::CallOp>(builder.getUnknownLoc(), function2);
      auto call_internal =
          builder.create<mlir::CallOp>(builder.getUnknownLoc(), function2);
      builder.create<mlir::quantum::QRTFinalizeOp>(builder.getUnknownLoc());

      // auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 0);
      // mlir::Value ret_zero = builder.create<mlir::ConstantOp>(
      //     builder.getUnknownLoc(), integer_attr);
      builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), call_internal.getResult(0));
      builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                     call_internal.getResult(0));
      m_module.push_back(function);
      function_names.push_back("main");
    }

    std::vector<mlir::Type> arg_types_vec3{qreg_type};
    auto func_type3 =
        builder.getFunctionType(llvm::makeArrayRef(arg_types_vec3), builder.getI32Type());
    auto func_type3 = builder.getFunctionType(
        llvm::makeArrayRef(arg_types_vec3), builder.getI32Type());
    auto proto3 =
        mlir::FuncOp::create(builder.getUnknownLoc(), file_name, func_type3);
    mlir::FuncOp function3(proto3);
@@ -74,9 +76,11 @@ void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point,
    builder.setInsertionPointToStart(tmp);
    builder.create<mlir::quantum::SetQregOp>(builder.getUnknownLoc(),
                                             tmp->getArguments()[0]);
    auto call_internal = builder.create<mlir::CallOp>(builder.getUnknownLoc(), function2);
    auto call_internal =
        builder.create<mlir::CallOp>(builder.getUnknownLoc(), function2);
    builder.create<mlir::quantum::QRTFinalizeOp>(builder.getUnknownLoc());
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
    builder.create<mlir::ReturnOp>(
        builder.getUnknownLoc(),
        llvm::ArrayRef<mlir::Value>(call_internal.getResult(0)));
    builder.setInsertionPointToStart(save_main_entry_block);

@@ -152,14 +156,14 @@ void OpenQasmV3MLIRGenerator::finalize_mlirgen() {

  if (add_main) {
    if (auto b = scoped_symbol_table.get_last_created_block()) {

      builder.setInsertionPointToEnd(b);
    } else {
      builder.setInsertionPointToEnd(main_entry_block);
    }

    auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 0);
    auto ret = builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr);
    auto ret =
        builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr);
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                   llvm::ArrayRef<mlir::Value>(ret));
  }
Loading