Commit 870ff220 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

robust qubit extract and ssa fixing for loop unroll



The Affine loop unroll will result in:

- Duplicate q.extract in the main loop -> disconnect SSA chain and hence no optimization possible.

- Incorrect SSA chaining if the extract is outside the loop, e.g. qubit type (not qreg)

Hence, we need to fix up all of those issues.

Add test for optimizing trotter loop

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 8b4841c8
Loading
Loading
Loading
Loading
+34 −2
Original line number Diff line number Diff line
@@ -167,8 +167,8 @@ for i in [0:10] {
      qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";
  
  // Get the main kernel section only (there is the oracle LLVM section as well)
  llvm = llvm.substr(llvm.find("@test_kernel"));
  // Get the main kernel section only 
  llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
@@ -177,6 +177,38 @@ for i in [0:10] {
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__rx"), 1);
}

TEST(qasm3PassManagerTester, checkLoopUnrollTrotter) {
  // Unroll the loop:
  // Trotter decompose
  const std::string src = R"#(OPENQASM 3;
include "stdgates.inc";
qubit qq[2];
for i in [0:100] {
    h qq;
    cx qq[0], qq[1];
    rx(0.0123) qq[1];
    cx qq[0], qq[1];
    h qq;
}
)#";
  auto llvm =
      qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";
  
  // Get the main kernel section only 
  llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
  // Only a single Rx remains (combine all angles)
  // 2 Hadamard before + 1 CX before
  // 2 Hadamard after + 1 CX after
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 7);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__rx"), 1);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__h"), 4);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__cnot"), 2);
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+4 −2
Original line number Diff line number Diff line
@@ -404,8 +404,10 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
        }

        createOp<mlir::AddFOp>(location, lhs, rhs);
      } else if (lhs.getType().isa<mlir::IntegerType>() &&
                 rhs.getType().isa<mlir::IntegerType>()) {
      } else if ((lhs.getType().isa<mlir::IntegerType>() ||
                  lhs.getType().isa<mlir::IndexType>()) &&
                 (rhs.getType().isa<mlir::IntegerType>() ||
                  rhs.getType().isa<mlir::IndexType>())) {
        if (lhs.getType().getIntOrFloatBitWidth() <
            rhs.getType().getIntOrFloatBitWidth()) {
          lhs =
+54 −41
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ void SimplifyQubitExtractPass::runOnOperation() {
                     std::unordered_map<int64_t, mlir::Value>>
      extract_qubit_map;

  // Map const qubit extract to its first extract
  getOperation().walk([&](mlir::quantum::ExtractQubitOp op) {
    mlir::Value idx_val = op.idx();
    mlir::Value qreg = op.qreg();
@@ -46,50 +47,62 @@ void SimplifyQubitExtractPass::runOnOperation() {
            previous_qreg_extract[index_const] = op.qbit();
          } else {
            mlir::Value previous_extract = previous_qreg_extract[index_const];
            previous_extract.dump();
            const std::function<mlir::Value(mlir::Value)> get_last_use =
                [&get_last_use](mlir::Value var) -> mlir::Value {
              if (var.hasOneUse()) {
                auto use = *var.user_begin();
                auto next_inst =
                    dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(use);
                if (next_inst) {
                  if (next_inst.qubits().size() == 1) {
                    return get_last_use(*next_inst.result_begin());
                  } else {
                    assert(next_inst.qubits().size() == 2);
                    // std::cout << "Two qubit gate use\n";
                    // Need to determine which operand this value is used
                    // i.e. map to the corresponding output
                    for (size_t i = 0; i < next_inst.qubits().size(); ++i) {
                      mlir::Value operand = next_inst.qubits()[i];
                      if (operand == var) {
                        // std::cout << "Find operand: " << i << "\n";
                        return get_last_use(next_inst.result()[i]);
            op.qbit().replaceAllUsesWith(previous_extract);
          }
        }
                    // Something wrong, cannot match the operand of 2-q
                    // ValueSemanticsInstOp
                    __builtin_unreachable();
                    assert(false);
                    return var;
      }
                } else {
                  return var;
    }
  });

  // Fix up the SSA chain
  // Mini symbol table to track all SSA values.
  std::unordered_map<void *, void *> ssa_var_to_root;
  std::unordered_map<void *, mlir::Value> root_ssa_var_to_last_use;
  getOperation().walk([&](mlir::quantum::ValueSemanticsInstOp op) {
    if (op.qubits().size() == 1) {
      mlir::Value operand = op.qubits()[0];
      void *operand_ptr = operand.getAsOpaquePointer();
      if (ssa_var_to_root.find(operand_ptr) == ssa_var_to_root.end()) {
        ssa_var_to_root[operand_ptr] = operand_ptr;
        assert(root_ssa_var_to_last_use.find(operand_ptr) ==
               root_ssa_var_to_last_use.end());
        root_ssa_var_to_last_use[operand_ptr] = op.result()[0];
        ssa_var_to_root[op.result()[0].getAsOpaquePointer()] = operand_ptr;
      } else {
                // No other use (last)
                // std::cout << "Last use\n";
                // var.dump();
                return var;
              }
            };
        // Match SSA operand:
        void *root_value_ptr = ssa_var_to_root[operand_ptr];
        assert(root_ssa_var_to_last_use.find(root_value_ptr) !=
               root_ssa_var_to_last_use.end());

            mlir::Value last_use = get_last_use(previous_extract);
            op.qbit().replaceAllUsesWith(last_use);
        // Fix up the input operand and update the last output
        op.qubitsMutable().assign(root_ssa_var_to_last_use[root_value_ptr]);
        ssa_var_to_root[op.result()[0].getAsOpaquePointer()] = root_value_ptr;
        root_ssa_var_to_last_use[root_value_ptr] = op.result()[0];
      }
    } else {
      assert(op.qubits().size() == 2);
      std::vector<mlir::Value> new_operands{op.qubits()[0], op.qubits()[1]};
      for (int i = 0; i < 2; ++i) {
        mlir::Value operand = op.qubits()[i];
        void *operand_ptr = operand.getAsOpaquePointer();
        if (ssa_var_to_root.find(operand_ptr) == ssa_var_to_root.end()) {
          ssa_var_to_root[operand_ptr] = operand_ptr;
          assert(root_ssa_var_to_last_use.find(operand_ptr) ==
                 root_ssa_var_to_last_use.end());
          root_ssa_var_to_last_use[operand_ptr] = op.result()[i];
          ssa_var_to_root[op.result()[i].getAsOpaquePointer()] = operand_ptr;
        } else {
          // Match SSA operand:
          // Fix up the input operand and update the last output
          void *root_value_ptr = ssa_var_to_root[operand_ptr];
          assert(root_ssa_var_to_last_use.find(root_value_ptr) !=
                 root_ssa_var_to_last_use.end());
          new_operands[i] = root_ssa_var_to_last_use[root_value_ptr];
          ssa_var_to_root[op.result()[i].getAsOpaquePointer()] = root_value_ptr;
          root_ssa_var_to_last_use[root_value_ptr] = op.result()[i];
        }
      }
      op.qubitsMutable().assign(llvm::makeArrayRef(new_operands));
    }
  });
}