Commit a9e5407b authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Extract merge pass to respect the pseudo-scoping of the start/end region ops



Also, in the instruction handler, adding the start region op **before** doing any qubit operand processing so that any extract ops will be after the start op.

Robust handling of qubit vs. qreg in broadcast syntax.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 38b09dc3
Loading
Loading
Loading
Loading
+30 −0
Original line number Diff line number Diff line
@@ -570,6 +570,36 @@ QCOR_EXPECT_TRUE(c[3] == 1);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__cphase"), 4);
}

TEST(qasm3PassManagerTester, checkModifierRegions) {
  {
    // The modified op must disconnect the SSA chain
    // i.e., these 3 Z gates are completely disconnected, should not be merged.
    const std::string src = R"#(OPENQASM 3;
include "qelib1.inc";

qubit control;
qubit target;

z target;
ctrl @ pow(5) @ z control, target;
z target;
)#";
    auto llvm =
        qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
    std::cout << "LLVM:\n" << llvm << "\n";

    // Get the main kernel section only (there is the oracle LLVM section as
    // well)
    llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
    const auto last = llvm.find_first_of("}");
    llvm = llvm.substr(0, last + 1);
    std::cout << "LLVM:\n" << llvm << "\n";
    // These should be 3 Z gates (1 inside
    // __quantum__rt__start/__quantum__rt__end)
    EXPECT_EQ(countSubstring(llvm, "@__quantum__qis__z"), 3);
  }
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+73 −42
Original line number Diff line number Diff line
@@ -290,6 +290,43 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
    }
  }

  // Handle modifier first so that qubit extraction (if any)
  // is within the scope of the modifier.
  bool has_ctrl = false;
  enum EndAction { EndCtrlU, EndAdjU, EndPowU };
  std::stack<std::pair<EndAction, mlir::Value>> action_and_extrainfo;
  for (auto m : modifiers) {
    if (m->getText().find("pow") != std::string::npos) {
      builder.create<mlir::quantum::StartPowURegion>(location);
      mlir::Value power;
      if (symbol_table.has_symbol(m->expression()->getText())) {
        power = symbol_table.get_symbol(m->expression()->getText());
        if (power.getType().isa<mlir::MemRefType>()) {
          power = builder.create<mlir::LoadOp>(location, power);
          if (power.getType().getIntOrFloatBitWidth() < 64) {
            power = builder.create<mlir::ZeroExtendIOp>(location, power,
                                                        builder.getI64Type());
          }
        }
      } else {
        auto p = symbol_table.evaluate_constant_integer_expression(
            m->expression()->getText());
        auto pow_attr = mlir::IntegerAttr::get(builder.getI64Type(), p);
        power = builder.create<mlir::ConstantOp>(location, pow_attr);
      }
      action_and_extrainfo.emplace(std::make_pair(EndAction::EndPowU, power));
    } else if (m->getText().find("inv") != std::string::npos) {
      builder.create<mlir::quantum::StartAdjointURegion>(location);
      action_and_extrainfo.emplace(
          std::make_pair(EndAction::EndAdjU, mlir::Value()));
    } else if (m->getText().find("ctrl") != std::string::npos) {
      has_ctrl = true;
      builder.create<mlir::quantum::StartCtrlURegion>(location);
      action_and_extrainfo.emplace(
          std::make_pair(EndAction::EndCtrlU, mlir::Value()));
    }
  }

  std::vector<std::string> qreg_names, qubit_symbol_table_keys;
  for (auto idx_identifier :
       context->indexIdentifierList()->indexIdentifier()) {
@@ -365,49 +402,29 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
      qubit_symbol_table_keys.push_back(qubit_symbol_name);

    } else {
      // this is a qubit
      // This is a qubit or whole qubit array (broadcast)
      mlir::Value qbit = [&]() {
        mlir::Value v = symbol_table.get_symbol(qbit_var_name);
        // If this instruction is modified,
        // invalidate previous extracts (all qubits if qreg broadcast or single
        // extract from the internal register for single qubit vars)
        if (!modifiers.empty()) {
          if (v.getType() == array_type) {
            // This is a qreg
            symbol_table.invalidate_qubit_extracts(qbit_var_name);
            return symbol_table.get_symbol(qbit_var_name);
          } else {
            // This is a qubit
            symbol_table.erase_symbol(qbit_var_name);
      }
      auto qbit =
          get_or_extract_qubit(qbit_var_name, location, symbol_table, builder);
      qbit_values.push_back(qbit);
      qubit_symbol_table_keys.push_back(qbit_var_name);
            return get_or_extract_qubit(qbit_var_name, location, symbol_table,
                                        builder);
          }
        }
        return v;
      }();

  bool has_ctrl = false;
  enum EndAction { EndCtrlU, EndAdjU, EndPowU };
  std::stack<std::pair<EndAction, mlir::Value>> action_and_extrainfo;
  for (auto m : modifiers) {
    if (m->getText().find("pow") != std::string::npos) {
      builder.create<mlir::quantum::StartPowURegion>(location);
      mlir::Value power;
      if (symbol_table.has_symbol(m->expression()->getText())) {
        power = symbol_table.get_symbol(m->expression()->getText());
        if (power.getType().isa<mlir::MemRefType>()) {
          power = builder.create<mlir::LoadOp>(location, power);
          if (power.getType().getIntOrFloatBitWidth() < 64) {
            power = builder.create<mlir::ZeroExtendIOp>(location, power,
                                                        builder.getI64Type());
          }
        }
      } else {
        auto p = symbol_table.evaluate_constant_integer_expression(
            m->expression()->getText());
        auto pow_attr = mlir::IntegerAttr::get(builder.getI64Type(), p);
        power = builder.create<mlir::ConstantOp>(location, pow_attr);
      }
      action_and_extrainfo.emplace(std::make_pair(EndAction::EndPowU, power));
    } else if (m->getText().find("inv") != std::string::npos) {
      builder.create<mlir::quantum::StartAdjointURegion>(location);
      action_and_extrainfo.emplace(
          std::make_pair(EndAction::EndAdjU, mlir::Value()));
    } else if (m->getText().find("ctrl") != std::string::npos) {
      has_ctrl = true;
      builder.create<mlir::quantum::StartCtrlURegion>(location);
      action_and_extrainfo.emplace(
          std::make_pair(EndAction::EndCtrlU, mlir::Value()));
      qbit_values.push_back(qbit);
      qubit_symbol_table_keys.push_back(qbit_var_name);
    }
  }

@@ -529,7 +546,21 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
  // block.
  if (!modifiers.empty()) {
    for (const auto &key : cached_qubit_symbol_table_keys) {
      mlir::Value qubit_or_qreg = symbol_table.get_symbol(key);
      if (qubit_or_qreg.getType() == array_type) {
        // For a broadcast inside a modified region,
        // we disconnect **ALL** QVS use-def chain involved any of its element
        // qubits.
        symbol_table.invalidate_qubit_extracts(key);
      } else {
        symbol_table.erase_symbol(key);
        if (key.find("%") == std::string::npos) {
          // For single qubit reg, graciously add the reextract from the
          // internal size-1 register back to the symbol table for later use
          // after this modified region.
          get_or_extract_qubit(key, location, symbol_table, builder);
        }
      }
    }
  }
  return 0;
+47 −2
Original line number Diff line number Diff line
@@ -56,9 +56,54 @@ void SimplifyQubitExtractPass::runOnOperation() {
            // std::cout << "First use\n";
            previous_qreg_extract[index_const] = op.qbit();
          } else {
            // Check if the Start/End modified regions are balanced:
            // Notes: currently, these pseudo regions are inside
            // the same block as other ops.
            // i.e., essentially just a linear sequence of Ops.
            // We count the open/close ops to determine if 
            // some extract ops are **within** these start/end regions.
            // These extract ops are **NOT** mergeable since 
            // the users of them are modified QVS ops (handled by runtime).
            auto &ops_in_blocks = op->getBlock()->getOperations();
            // We assume these braces are balanced, just to a simple open/close
            // count.
            int modifier_scope_braces_count = 0;
            for (auto iter = ops_in_blocks.rbegin();
                 iter != ops_in_blocks.rend(); ++iter) {
              auto &iter_op = *iter;
              if (mlir::dyn_cast_or_null<mlir::quantum::StartCtrlURegion>(
                      &iter_op) ||
                  mlir::dyn_cast_or_null<mlir::quantum::StartAdjointURegion>(
                      &iter_op) ||
                  mlir::dyn_cast_or_null<mlir::quantum::StartPowURegion>(
                      &iter_op)) {
                if (iter_op.isBeforeInBlock(op)) {
                  modifier_scope_braces_count++;
                }
              }
              if (mlir::dyn_cast_or_null<mlir::quantum::EndCtrlURegion>(
                      &iter_op) ||
                  mlir::dyn_cast_or_null<mlir::quantum::EndAdjointURegion>(
                      &iter_op) ||
                  mlir::dyn_cast_or_null<mlir::quantum::EndPowURegion>(
                      &iter_op)) {
                if (iter_op.isBeforeInBlock(op)) {
                  modifier_scope_braces_count--;
                }
              }
            }
            if (modifier_scope_braces_count > 0) {
              // This extract is inside a modified region.
              // Don't merge but remove the tracking for these qubits in this region
              // following extract ops will start new use-def chains:
              previous_qreg_extract.erase(index_const);
            } else {
              // Not inside any modified regions
              // Merge to the prior extract to extend the use-def chain.
              mlir::Value previous_extract = previous_qreg_extract[index_const];
              op.qbit().replaceAllUsesWith(previous_extract);
            }
          }

          // Erase the extract cache in the parent scope as well:
          // i.e., when the child scope (e.g., if block) is accessing this