Commit 730ba97c authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

setup to expand controlled gates from stdlib

parent bb1c6542
Loading
Loading
Loading
Loading
+130 −37
Original line number Diff line number Diff line
@@ -6,13 +6,50 @@

namespace qcor {

std::set<std::string_view> default_inline_overrides{
    "x",  "y",  "z",  "h",  "s",  "sdg",  "t", "tdg",
    "rx", "ry", "rz", "cz", "cy", "swap", "cx"};

static std::vector<std::string> builtins{
    "u3", "u2",   "u1",  "cx",  "id",  "u0",  "x",   "y",  "z",
    "h",  "s",    "sdg", "t",   "tdg", "rx",  "ry",  "rz", "cz",
    "cy", "swap", "ch",  "ccx", "crz", "cu1", "cu2", "cu3"};

static std::vector<std::string> search_for_inliner{
    "u3", "u2",   "u1",  "cx",  "id",  "u0",  "x",   "y",  "z",
    "h",  "s",    "sdg", "t",   "tdg", "rx",  "ry",  "rz", "cz",
    "cy", "swap"};

void CountGateDecls::visit(GateDecl &g) { 
  auto name = g.id();
  if (std::find(builtins.begin(), builtins.end(), name) == builtins.end()) {
    gates_to_inline.push_back(name);
  }
  count++; 
}

void OpenQasmMLIRGenerator::visit(Program &prog) {
  prog.foreach_stmt([this](auto &stmt) { stmt.accept(*this); });
  // How many statements are there (starts with 25)
  auto n_stmts = prog.body().size();
  // How many gatedecls are there?
  std::size_t n_gate_decls = 0;
  CountGateDecls count_gate_decls(n_gate_decls);
  prog.foreach_stmt([&](auto &stmt) { stmt.accept(count_gate_decls); });

  for(auto& g : count_gate_decls.gates_to_inline) {
    default_inline_overrides.insert(g);
  }

void OpenQasmMLIRGenerator::initialize_mlirgen() {
  // INLINE any complex controlled gates from stdlib
  staq::transformations::Inliner::config config;
  config.overrides = default_inline_overrides;
  staq::transformations::inline_ast(prog);

  // If n_stmts > n_gate_decls, then we need a main function
  add_main = (n_stmts > n_gate_decls);
  m_module = mlir::ModuleOp::create(builder.getUnknownLoc());

  // Useful opaque type defs
  llvm::StringRef qubit_type_name("Qubit"), array_type_name("Array"),
      result_type_name("Result");
  mlir::Identifier dialect = mlir::Identifier::get("quantum", &context);
@@ -23,11 +60,12 @@ void OpenQasmMLIRGenerator::initialize_mlirgen() {
  auto argv_type =
      mlir::OpaqueType::get(dialect, llvm::StringRef("ArgvType"), &context);

  if (add_main) {
    std::vector<mlir::Type> arg_types_vec{int_type, argv_type};
  // llvm::SmallVector<mlir::Type, 4> arg_types;
    auto func_type =
        builder.getFunctionType(llvm::makeArrayRef(arg_types_vec), int_type);
  auto proto = mlir::FuncOp::create(builder.getUnknownLoc(), "main", func_type);
    auto proto =
        mlir::FuncOp::create(builder.getUnknownLoc(), "main", func_type);
    mlir::FuncOp function(proto);
    main_entry_block = function.addEntryBlock();
    auto &entryBlock = *main_entry_block;
@@ -35,6 +73,10 @@ void OpenQasmMLIRGenerator::initialize_mlirgen() {
    m_module.push_back(function);
    function_names.push_back("main");
  }
  prog.foreach_stmt([this](auto &stmt) { stmt.accept(*this); });
}

void OpenQasmMLIRGenerator::initialize_mlirgen() {}

void OpenQasmMLIRGenerator::mlirgen(const std::string &src) {
  using namespace staq;
@@ -48,7 +90,59 @@ void OpenQasmMLIRGenerator::mlirgen(const std::string &src) {
    std::cout << e.what() << "\n";
  }

  // Replace standard controlled gates with expanded versions
  // First get mapping of gate name to composite gates
  // auto tmp_prog = parser::parse_string(R"(OPENQASM 2.0;
  // include "qelib1.inc";
  // )");

  // class CollectGateDecomps : public Traverse {
  //  public:
  //   std::map<std::string, std::list<std::unique_ptr<Gate>>> gate_decomps;

  //       void
  //       visit(GateDecl &gate) override {
  //     if (gate.id() == "ccx") {
  //       std::cout << "Found CCX\n";
  //       gate.body();
  //     }
  //     gate_decomps.insert({gate.id(), gate.body()});
  //   }
  // };
  // CollectGateDecomps collect;
  // tmp_prog->foreach_stmt([&](auto &stmt) { stmt.accept(collect); });

  // class BuildReplacerMap : public Traverse {
  //  protected:
  //   std::map<std::string, std::list<std::unique_ptr<Gate>>> &gate_decomps;

  //  public:
  //   std::unordered_map<int, std::list<std::unique_ptr<Gate>>> replacer_map;
  //   BuildReplacerMap(
  //       std::map<std::string, std::list<std::unique_ptr<Gate>>> &gd)
  //       : gate_decomps(gd) {}

  //   void visit(DeclaredGate &gate) {
  //     auto name = gate.name();

  //     if (name == "ccx") {
  //       std::cout << "adding to replacer map for ccx\n";
  //       auto uid = gate.uid();
  //       // auto gates = ;
  //       if (!replacer_map.count(uid)) {
  //       replacer_map.insert({uid, std::move(gate_decomps[name])});
  //       }
  //     }
  //   }
  // };

  // BuildReplacerMap replacer_builder(collect.gate_decomps);
  // prog->foreach_stmt([&](auto &stmt) { stmt.accept(replacer_builder); });

  // First, get uid of declared gate to replace

  visit(*prog);

  return;
}

@@ -59,12 +153,12 @@ void OpenQasmMLIRGenerator::finalize_mlirgen() {
                                             qalloc_op);
  }

  if (add_main) {
    builder.create<mlir::quantum::QRTFinalizeOp>(builder.getUnknownLoc());

    auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 0);
    mlir::Value ret_zero =
      builder.create<mlir::ConstantOp>(builder.getUnknownLoc(),
      integer_attr);
        builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr);

    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), ret_zero);

@@ -72,24 +166,23 @@ void OpenQasmMLIRGenerator::finalize_mlirgen() {
                                     function_names.end());

    auto function_names_datatype = mlir::VectorType::get(
      {static_cast<std::int64_t>(function_names.size())}, builder.getI64Type());
        {static_cast<std::int64_t>(function_names.size())},
        builder.getI64Type());
    auto function_names_ref = llvm::makeArrayRef(tmp);
    auto attrs = mlir::DenseStringElementsAttr::get(function_names_datatype,
                                                    function_names_ref);

  mlir::Identifier id =
      mlir::Identifier::get("quantum.internal_functions", builder.getContext());
    mlir::Identifier id = mlir::Identifier::get("quantum.internal_functions",
                                                builder.getContext());

    m_module.setAttrs(
        llvm::makeArrayRef({mlir::NamedAttribute(std::make_pair(id, attrs))}));
  }
}

void OpenQasmMLIRGenerator::visit(GateDecl &gate_function) {
  auto name = gate_function.id();
  static std::vector<std::string> builtins{
      "u3", "u2",   "u1",  "cx",  "id",  "u0",  "x",   "y",  "z",
      "h",  "s",    "sdg", "t",   "tdg", "rx",  "ry",  "rz", "cz",
      "cy", "swap", "ch",  "ccx", "crz", "cu1", "cu2", "cu3"};

  if (std::find(builtins.begin(), builtins.end(), name) == builtins.end()) {
    std::vector<mlir::Type> arg_types;

@@ -122,7 +215,7 @@ void OpenQasmMLIRGenerator::visit(GateDecl &gate_function) {
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    m_module.push_back(function);

    builder.setInsertionPointToStart(main_entry_block);
    if (add_main) builder.setInsertionPointToStart(main_entry_block);
  }
}

+29 −0
Original line number Diff line number Diff line
@@ -27,6 +27,7 @@ class OpenQasmMLIRGenerator : public qcor::QuantumMLIRGenerator,
  std::map<std::string, mlir::Value> temporary_sub_kernel_args;
  std::vector<std::string> function_names;
  bool is_first_inst = true;
  bool add_main = true;

  mlir::Type qubit_type;
  mlir::Type array_type;
@@ -61,4 +62,32 @@ class OpenQasmMLIRGenerator : public qcor::QuantumMLIRGenerator,
  void visit(DeclaredGate &g) override;
  void addReturn();
};
class CountGateDecls : public staq::ast::Visitor {
private: 
  std::size_t& count;
public:
  std::vector<std::string> gates_to_inline;

  CountGateDecls(std::size_t& c) :count(c){}
  void visit(VarAccess &) override {}
  void visit(BExpr &) override {}
  void visit(UExpr &) override {}
  void visit(PiExpr &) override {}
  void visit(IntExpr &) override {}
  void visit(RealExpr &r) override {}
  void visit(VarExpr &v) override {}
  void visit(ResetStmt &) override {}
  void visit(IfStmt &) override {}
  void visit(BarrierGate &) override {}
  void visit(GateDecl &g) override;
  void visit(OracleDecl &) override {}
  void visit(RegisterDecl &) override{}
  void visit(AncillaDecl &) override {}
  void visit(Program &prog) override{}
  void visit(MeasureStmt &m) override{}
  void visit(UGate &u) override{}
  void visit(CNOTGate &cx) override{}
  void visit(DeclaredGate &g) override{}
  
};
}  // namespace qcor
 No newline at end of file
+103 −49
Original line number Diff line number Diff line
@@ -13,18 +13,31 @@ unsigned long allocated_qbits = 0;
std::shared_ptr<xacc::AcceleratorBuffer> qbits;
std::shared_ptr<xacc::Accelerator> qpu;
std::string qpu_name = "qpp";

enum QRT_MODE { FTQC, NISQ };
QRT_MODE mode;
std::vector<std::unique_ptr<Array>> allocated_arrays;
int shots = 1024;
bool verbose = false;

bool initialized = false;
void __quantum__rt__initialize(int argc, int8_t** argv) {
  
  char** casted = reinterpret_cast<char**>(argv);
  std::vector<std::string> args(casted, casted + argc);

  mode = QRT_MODE::FTQC;
  for (auto [i, arg] : qcor::enumerate(args)) {
    if (arg == "-qpu") {
      qpu_name = args[i + 1];
    } else if (arg == "-qrt") {
      mode = args[i + 1] == "nisq" ? QRT_MODE::NISQ : QRT_MODE::FTQC;
    } else if (arg == "-shots") {
      shots = std::stoi(args[i+1]);
    } else if (arg == "-v") {
      verbose = true;
    } else if (arg == "-verbose") {
      verbose = true;
    } else if (arg == "--verbose") {
      verbose = true;
    }
  }

@@ -33,12 +46,20 @@ void __quantum__rt__initialize(int argc, int8_t** argv) {

void initialize() {
  if (!initialized) {
    printf("[qir-qrt] Initializing FTQC runtime...\n");
    if(verbose) printf("[qir-qrt] Initializing FTQC runtime...\n");
    // qcor::set_verbose(true);
    xacc::internal_compiler::__qrt_env = "ftqc";
    xacc::Initialize();
    std::cout << "[qir-qrt] Running on " << qpu_name << " backend.\n";
    auto qpu = xacc::getAccelerator(qpu_name);
    if (verbose) std::cout << "[qir-qrt] Running on " << qpu_name << " backend.\n";
    std::shared_ptr<xacc::Accelerator> qpu;

    if (mode == QRT_MODE::NISQ) {
      xacc::internal_compiler::__qrt_env = "nisq";
      qpu = xacc::getAccelerator(qpu_name, {{"shots", shots}});
    } else {
      qpu = xacc::getAccelerator(qpu_name);
    }

    xacc::internal_compiler::qpu = qpu;
    ::quantum::qrt_impl = xacc::getService<::quantum::QuantumRuntime>(
        xacc::internal_compiler::__qrt_env);
@@ -50,48 +71,74 @@ void initialize() {
void __quantum__qis__cnot(Qubit* src, Qubit* tgt) {
  std::size_t src_copy = reinterpret_cast<std::size_t>(src);
  std::size_t tgt_copy = reinterpret_cast<std::size_t>(tgt);
  printf("[qir-qrt] Applying CX %lu, %lu\n", src_copy, tgt_copy);
  if(verbose) printf("[qir-qrt] Applying CX %lu, %lu\n", src_copy, tgt_copy);
  ::quantum::cnot({"q", src_copy}, {"q", tgt_copy});
}

void __quantum__qis__h(Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  printf("[qir-qrt] Applying H %lu\n", qcopy);
  if(verbose) printf("[qir-qrt] Applying H %lu\n", qcopy);
  ::quantum::h({"q", qcopy});
}

// void __quantum__qis__s(Qubit* q) {
//   initialize();
//   printf("[qir-qrt] Applying S %lu\n", q);
//   ::quantum::s({"q", q});
// }

// void __quantum__qis__x(Qubit* q) {
//   initialize();
//   printf("[qir-qrt] Applying X %lu\n", q);
//   ::quantum::x({"q", q});
// }
// void __quantum__qis__z(Qubit* q) {
//   initialize();
//   printf("[qir-qrt] Applying Z %lu\n", q);
//   ::quantum::z({"q", q});
// }

// void __quantum__qis__rx(double x, Qubit* q) {
//   initialize();
//   printf("[qir-qrt] Applying Rx(%f) %lu\n", x, q);
//   ::quantum::rx({"q", q}, x);
// }

// void __quantum__qis__rz(double x, Qubit* q) {
//   initialize();
//   printf("[qir-qrt] Applying Rz(%f) %lu\n", x, q);
//   ::quantum::rz({"q", q}, x);
// }
void __quantum__qis__s(Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying S %lu\n", qcopy);
  ::quantum::s({"q", qcopy});
}

void __quantum__qis__sdg(Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying Sdg %lu\n", qcopy);
  ::quantum::sdg({"q", qcopy});
}
void __quantum__qis__t(Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying T %lu\n", qcopy);
  ::quantum::t({"q", qcopy});
}
void __quantum__qis__tdg(Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying Tdg %lu\n", qcopy);
  ::quantum::tdg({"q", qcopy});
}

void __quantum__qis__x(Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying X %lu\n", qcopy);
  ::quantum::x({"q", qcopy});
}
void __quantum__qis__y(Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying Y %lu\n", qcopy);
  ::quantum::y({"q", qcopy});
}
void __quantum__qis__z(Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying Z %lu\n", qcopy);
  ::quantum::z({"q", qcopy});
}

void __quantum__qis__rx(double x, Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying Rx(%f) %lu\n", x, qcopy);
  ::quantum::rx({"q", qcopy}, x);
}

void __quantum__qis__ry(double x, Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying Ry(%f) %lu\n", x, qcopy);
  ::quantum::ry({"q", qcopy}, x);
}

void __quantum__qis__rz(double x, Qubit* q) {
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);
  if(verbose) printf("[qir-qrt] Applying Rz(%f) %lu\n", x, qcopy);
  ::quantum::rz({"q", qcopy}, x);
}

Result* __quantum__qis__mz(Qubit* q) {
  initialize();
  printf("[qir-qrt] Measuring qubit %lu\n", reinterpret_cast<std::size_t>(q));
  if(verbose) printf("[qir-qrt] Measuring qubit %lu\n", reinterpret_cast<std::size_t>(q));
  std::size_t qcopy = reinterpret_cast<std::size_t>(q);

  if (!qbits) {
@@ -100,13 +147,12 @@ Result* __quantum__qis__mz(Qubit* q) {

  ::quantum::set_current_buffer(qbits.get());
  auto bit = ::quantum::mz({"q", qcopy});
  printf("[qir-qrt] Result was %d.\n", bit);
  if (mode == QRT_MODE::FTQC) if(verbose) printf("[qir-qrt] Result was %d.\n", bit);
  return bit ? &ResultOne : &ResultZero;
}

Array* __quantum__rt__qubit_allocate_array(uint64_t size) {
  initialize();
  printf("[qir-qrt] Allocating qubit array of size %lu.\n", size);
  if(verbose) printf("[qir-qrt] Allocating qubit array of size %lu.\n", size);

  auto new_array = std::make_unique<Array>(size);
  for (uint64_t i = 0; i < size; i++) {
@@ -117,7 +163,7 @@ Array* __quantum__rt__qubit_allocate_array(uint64_t size) {

  allocated_qbits = size;
  if (!qbits) {
    qbits = std::make_shared<xacc::AcceleratorBuffer>(allocated_qbits);
    qbits = std::make_shared<xacc::AcceleratorBuffer>(size);
    ::quantum::set_current_buffer(qbits.get());
  }

@@ -131,8 +177,7 @@ int8_t* __quantum__rt__array_get_element_ptr_1d(Array* q, uint64_t idx) {
  int8_t* ptr = arr[idx];
  Qubit* qq = reinterpret_cast<Qubit*>(ptr);

  printf("[qir-qrt] Returning qubit array element %lu, idx=%lu.\n",
         *qq, idx);
  if(verbose) printf("[qir-qrt] Returning qubit array element %lu, idx=%lu.\n", *qq, idx);
  return ptr;
}

@@ -141,7 +186,8 @@ void __quantum__rt__qubit_release_array(Array* q) {
    if (allocated_arrays[i].get() == q) {
      auto& array_ptr = allocated_arrays[i];
      auto array_size = array_ptr->size();
      printf("[qir-qrt] deallocating the qubit array of size %lu\n", array_size);
      if(verbose) printf("[qir-qrt] deallocating the qubit array of size %lu\n",
             array_size);
      for (int k = 0; k < array_size; k++) {
        delete (*array_ptr)[k];
      }
@@ -151,5 +197,13 @@ void __quantum__rt__qubit_release_array(Array* q) {
}

void __quantum__rt__finalize() {
  std::cout << "[qir-qrt] Running finalization routine.\n";
  if (verbose) std::cout << "[qir-qrt] Running finalization routine.\n";
  if (mode == QRT_MODE::NISQ) {
    ::quantum::submit(qbits.get());
    auto counts = qbits->getMeasurementCounts();
    std::cout << "Observed Counts:\n";
    for (auto [bits, count] : counts) {
      qcor::print(bits, ":", count);
    }
  }
}
 No newline at end of file
+12 −6
Original line number Diff line number Diff line
@@ -30,12 +30,18 @@ void __quantum__rt__finalize();

void __quantum__qis__cnot(Qubit* src, Qubit* tgt);
void __quantum__qis__h(Qubit* q);
// void __quantum__qis__s(Qubit* q);
// void __quantum__qis__x(Qubit* q);
// void __quantum__qis__z(Qubit* q);

// void __quantum__qis__rx(double x, Qubit* q);
// void __quantum__qis__rz(double x, Qubit* q);
void __quantum__qis__s(Qubit* q);
void __quantum__qis__sdg(Qubit* q);
void __quantum__qis__t(Qubit* q);
void __quantum__qis__tdg(Qubit* q);

void __quantum__qis__x(Qubit* q);
void __quantum__qis__y(Qubit* q);
void __quantum__qis__z(Qubit* q);

void __quantum__qis__rx(double x, Qubit* q);
void __quantum__qis__ry(double x, Qubit* q);
void __quantum__qis__rz(double x, Qubit* q);

Result* __quantum__qis__mz(Qubit* q);

+7 −7
Original line number Diff line number Diff line
add_llvm_executable(qcor-qasm qcor-qasm.cpp)
add_llvm_executable(qcor-mlir-tool qcor-mlir-tool.cpp)

target_compile_options(qcor-qasm PUBLIC "-fexceptions")
target_compile_options(qcor-mlir-tool PUBLIC "-fexceptions")

target_compile_features(qcor-qasm
target_compile_features(qcor-mlir-tool
                        PUBLIC
                        cxx_std_17)
llvm_update_compile_flags(qcor-qasm)
target_link_libraries(qcor-qasm PUBLIC quantum-to-llvm-lowering openqasm-mlir-generator )
llvm_update_compile_flags(qcor-mlir-tool)
target_link_libraries(qcor-mlir-tool PUBLIC quantum-to-llvm-lowering openqasm-mlir-generator )

set_target_properties(qcor-qasm
set_target_properties(qcor-mlir-tool
                        PROPERTIES INSTALL_RPATH "${MLIR_INSTALL_DIR}/lib:${CMAKE_BINARY_DIR}/mlir/parsers/openqasm:${CMAKE_BINARY_DIR}/lib")
install(PROGRAMS ${CMAKE_BINARY_DIR}/bin/qcor-qasm DESTINATION bin)
install(PROGRAMS ${CMAKE_BINARY_DIR}/bin/qcor-mlir-tool DESTINATION bin)
Loading