Commit 888041d5 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

PyXASM to handle qreg slicing



Also, support rewriting Python range-based slicing => extract_range()

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent b3fbaf96
Loading
Loading
Loading
Loading
+52 −20
Original line number Diff line number Diff line
@@ -97,6 +97,58 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      return 0;
    }

    const auto isSliceOp =
        [](pyxasmParser::Atom_exprContext *atom_expr_context) -> bool {
      if (atom_expr_context->trailer().size() == 1) {
        auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist();
        if (subscriptlist && subscriptlist->subscript().size() == 1) {
          auto subscript = subscriptlist->subscript(0);
          const auto nbTestTerms = subscript->test().size();
          // Multiple test terms (separated by ':')
          return (nbTestTerms > 1);
        }
      }

      return false;
    };

    // Handle slicing operations (multiple array subscriptions separated by ':')
    // on a qreg.
    if (context->atom() &&
        xacc::container::contains(bufferNames, context->atom()->getText()) &&
        isSliceOp(context)) {
      std::cout << "Slice op: " << context->getText() << "\n";
      sub_node_translation << context->atom()->getText() << ".extract_range({";
      auto subscripts =
          context->trailer(0)->subscriptlist()->subscript(0)->test();
      assert(subscripts.size() > 1);
      std::vector<std::string> subscriptTerms;
      for (auto &test : subscripts) {
        subscriptTerms.emplace_back(test->getText());
      }

      auto sliceOp =
          context->trailer(0)->subscriptlist()->subscript(0)->sliceop();
      if (sliceOp && sliceOp->test()) {
        subscriptTerms.emplace_back(sliceOp->test()->getText());
      }
      assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3);

      for (int i = 0; i < subscriptTerms.size(); ++i) {
        sub_node_translation << subscriptTerms[i];
        if (i != subscriptTerms.size() - 1) {
          sub_node_translation << ", ";
        }
      }

      sub_node_translation << "})";

      // convert the slice op to initializer list:
      std::cout << "Slice Convert: " << context->getText() << " --> "
                << sub_node_translation.str() << "\n";
      return 0;
    }

    // Handle kernel::ctrl(...), kernel::adjoint(...)
    if (!context->trailer().empty() &&
        (context->trailer()[0]->getText() == ".ctrl" ||
@@ -398,26 +450,6 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
              replaceMeasureAssignment(sub_node_translation.str()));
        }

        // Handle Python list assignment:
        // i.e. lhs = [a, b, c] => { a, b, c }
        // NOTE: we only support simple lists, i.e. no nested.
        // const auto transformListAssignmentIfAny =
        //     [](const std::string &in_expr) -> std::string {
        //   const auto whitespace = " ";
        //   // Trim leading and trailing spaces:
        //   const auto strBegin = in_expr.find_first_not_of(whitespace);
        //   const auto strEnd = in_expr.find_last_not_of(whitespace);
        //   const auto strRange = strEnd - strBegin + 1;
        //   const auto trim_expr = in_expr.substr(strBegin, strRange);

        //   if (trim_expr.front() == '[' && trim_expr.back() == ']') {
        //     return "{" + trim_expr.substr(1, trim_expr.size() - 2) + "}";
        //   }

        //   // Returns the original expression:
        //   return in_expr;
        // };

        if (xacc::container::contains(declared_var_names, lhs)) {
          ss << lhs << " = " << rhs << "; \n";
        } else {
+43 −2
Original line number Diff line number Diff line
@@ -92,7 +92,7 @@ TEST(PyXASMTokenCollectorTester, checkPythonList) {

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"qb"}, ss);
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
@@ -119,7 +119,7 @@ TEST(PyXASMTokenCollectorTester, checkStringLiteral) {

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"qb"}, ss);
  xasm_tc->collect(*PP.get(), cached, {}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
@@ -129,6 +129,47 @@ print("howdy", 1, "abc");
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkQregMethods) {
  LexerHelper helper;
  auto [tokens, PP] = helper.Lex(R"(
    ctrl_qubits = q.head(q.size()-1)
    last_qubit = q.tail()
    Z.ctrl(ctrl_qubits, last_qubit)
    
    # inline
    X.ctrl(q.head(q.size()-1), q.tail())

    # range:
    # API
    r = q.extract_range(0, bitPrecision)
    # Python style
    slice1 = q[0:3]
    # step size
    slice2 = q[0:5:2]
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(auto ctrl_qubits = q.head(q.size()-1); 
auto last_qubit = q.tail(); 
Z::ctrl(parent_kernel, ctrl_qubits, last_qubit);
X::ctrl(parent_kernel, q.head(q.size()-1), q.tail());
auto r = q.extract_range(0,bitPrecision); 
auto slice1 = q.extract_range({0, 3}); 
auto slice2 = q.extract_range({0, 5, 2}); 
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

int main(int argc, char **argv) {
  std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR);
  std::string qcor_root = std::string(QCOR_INSTALL_DIR);