Commit d62b8e9b authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Control if-else codegen based on target information



Only enable when targeting nisq runtime for a set of QPU

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent b6c4994f
Loading
Loading
Loading
Loading
+24 −1
Original line number Diff line number Diff line
@@ -51,6 +51,28 @@ void OpenQasmV3MLIRGenerator::initialize_mlirgen(
  file_name = function;
  add_entry_point = _add_entry_point;

  // Only enable the rewrite to NISQ If-statements when the compilation 
  // targets NISQ qrt for some specific QPUs:
  static const std::vector<std::string> IF_STMT_CAPABLE_QPUS{"qpp", "aer",
                                                             "honeywell"};
  if (extra_quantum_args.find("qrt") != extra_quantum_args.end() &&
      extra_quantum_args["qrt"] == "nisq") {
    // Default is qpp (i.e., not provided)
    if (extra_quantum_args.find("qpu") == extra_quantum_args.end()) {
      enable_qir_apply_ifelse = true;
    } else {
      for (const auto &name_to_check : IF_STMT_CAPABLE_QPUS) {
        const auto qpu_name = extra_quantum_args["qpu"];
        if (name_to_check.rfind(qpu_name, 0) == 0) {
          // QPU start with aer, honeywell, etc.
          // (it could have backend name customization after ':')
          enable_qir_apply_ifelse = true;
          break;
        }
      }
    }
  }

  // Useful opaque type defs
  llvm::StringRef qubit_type_name("Qubit"), array_type_name("Array"),
      result_type_name("Result");
@@ -146,7 +168,8 @@ void OpenQasmV3MLIRGenerator::mlirgen(const std::string &src) {
  using namespace qasm3;

  if (!visitor) {
    visitor = std::make_shared<qasm3_visitor>(builder, m_module, file_name);
    visitor = std::make_shared<qasm3_visitor>(builder, m_module, file_name,
                                              enable_qir_apply_ifelse);
  }

  ANTLRInputStream input(src);
+3 −1
Original line number Diff line number Diff line
@@ -11,7 +11,9 @@ class OpenQasmV3MLIRGenerator : public qcor::QuantumMLIRGenerator {
  std::string file_name = "main";
  bool add_entry_point = true;
  bool add_custom_return = false;

  // Enable special code-gen mode for specific targets that
  // support NISQ-like conditional statements.
  bool enable_qir_apply_ifelse = false;
  mlir::Type return_type;

  mlir::Type qubit_type;
+5 −3
Original line number Diff line number Diff line
@@ -30,8 +30,10 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  ScopedSymbolTable* getScopedSymbolTable() { return &symbol_table; }

  // The constructor, instantiates commonly used opaque types
  qasm3_visitor(mlir::OpBuilder b, mlir::ModuleOp m, std::string& fname)
      : builder(b), file_name(fname), m_module(m) {
  qasm3_visitor(mlir::OpBuilder b, mlir::ModuleOp m, std::string &fname,
                bool enable_nisq_conditional = false)
      : builder(b), file_name(fname), m_module(m),
        enable_nisq_ifelse(enable_nisq_conditional) {
    auto context = b.getContext();
    llvm::StringRef qubit_type_name("Qubit"), array_type_name("Array"),
        result_type_name("Result");
@@ -189,7 +191,7 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  mlir::OpBuilder builder;
  mlir::ModuleOp m_module;
  std::string file_name = "";

  bool enable_nisq_ifelse = false;  
  // We keep reference to these blocks so that
  // we can handle break/continue correctly
  mlir::Block* current_loop_exit_block;
+69 −66
Original line number Diff line number Diff line
@@ -126,7 +126,9 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(

  // Get the conditional expression
  auto conditional_expr = context->booleanExpression();

  // Only consider this codegen strategy if requested (for specific qrt/qpu
  // target)
  if (enable_nisq_ifelse) {
    auto bit_check_conditional =
        tryParseSimpleBooleanExpression(*conditional_expr);
    // Currently, we're only support If (not else yet)
@@ -167,8 +169,8 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
      const std::string tmp_func_name =
          "if_body_" + toString(context->programBlock(0));
      auto func_type = builder.getFunctionType(argument_types, llvm::None);
    auto proto =
        mlir::FuncOp::create(builder.getUnknownLoc(), tmp_func_name, func_type);
      auto proto = mlir::FuncOp::create(builder.getUnknownLoc(), tmp_func_name,
                                        func_type);
      mlir::FuncOp function(proto);
      function.setVisibility(mlir::SymbolTable::Visibility::Private);
      auto &entryBlock = *function.addEntryBlock();
@@ -197,6 +199,7 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
      // Done
      return 0;
    }
  }

  // TODO: The below code could be rewritten to an AffineIfOp/SCF::IfOp:
  // Map it to a Value
+11 −7
Original line number Diff line number Diff line
@@ -232,8 +232,10 @@ antlrcpp::Any qasm3_visitor::visitQuantumMeasurementAssignment(

      if (bit_value.getType().isa<mlir::MemRefType>() &&
          bit_value.getType().cast<mlir::MemRefType>().getShape().empty()) {
        if (enable_nisq_ifelse) {
          // Track the Result* associated with the bit in the Symbol Table
          symbol_table.add_measure_bit_assignment(bit_value, instop.bit());
        }
        // If the array is a **zero-dimemsion** Memref *without* shape
        // we don't send on the index (probably v = 0).
        // This will fail to validate at the MLIR level (Memref dimension mismatches)
@@ -241,11 +243,13 @@ antlrcpp::Any qasm3_visitor::visitQuantumMeasurementAssignment(
        builder.create<mlir::StoreOp>(location, cast_bit_op.bit_result(),
                                      bit_value);
      } else {
        if (enable_nisq_ifelse) {
          if (!symbol_table.has_symbol(indexIdentifierList->getText())) {
            // Added a measure Result* tracking to the bit array element:
            // e.g. track var name 'c[1]' -> Result*
            symbol_table.add_symbol(indexIdentifierList->getText(),
                                    cast_bit_op.bit_result());
          }
          symbol_table.add_measure_bit_assignment(cast_bit_op.bit_result(),
                                                  instop.bit());
        }