Commit c957b3c3 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Work on PyXASM improvements

Handle Python list/array (testlist_comp) node: map [] to {}

Related to https://github.com/ORNL-QCI/qcor/issues/112



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 02e27bd8
Loading
Loading
Loading
Loading
+111 −21
Original line number Diff line number Diff line
@@ -33,9 +33,63 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
  // New var declared (auto type) after visiting this node.
  std::string new_var;
  bool in_for_loop = false;
  // Var to keep track of sub-node rewrite:
  // e.g., traverse down the AST recursively.
  std::stringstream sub_node_translation;

  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
    std::cout << "Atom_exprContext: " << context->getText() << "\n";
    // Strategy:
    // At the top level, we analyze the trailer to determine the 
    // list of function call arguments.
    // Then, traverse down the arg. node to see if there is a potential rewrite rules
    // e.g. for arrays (as testlist_comp nodes)
    // Otherwise, just get the argument text as is.
    /*
    atom_expr: (AWAIT)? atom trailer*;
    atom: ('(' (yield_expr|testlist_comp)? ')' |
       '[' (testlist_comp)? ']' |
       '{' (dictorsetmaker)? '}' |
       NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False');
    */

    if (context->atom() && context->atom()->testlist_comp()) {
      // Array type expression:
      std::cout << "Array atom expression: "
                << context->atom()->testlist_comp()->getText() << "\n";
      // Use braces
      sub_node_translation << "{";
      bool firstElProcessed = false;
      for (auto &testNode : context->atom()->testlist_comp()->test()) {
        std::cout << "Array elem: " << testNode->getText() << "\n";
        // Add comma if needed (there is a previous element)
        if (firstElProcessed) {
          sub_node_translation << ", ";
        }
        sub_node_translation << testNode->getText();
        firstElProcessed = true;
      }
      sub_node_translation << "}";
      return 0;
    }

    if (context->atom() && context->atom()->dictorsetmaker()) {
      // Dict:
      std::cout << "Dict atom expression: "
                << context->atom()->dictorsetmaker()->getText() << "\n";
      // TODO:
      return 0;
    }

    if (context->atom() && !context->atom()->STRING().empty()) {
      // Strings:
      for (auto &strNode : context->atom()->STRING()) {
        std::cout << "String expression: " << strNode->getText() << "\n";
      }
      return 0;
    }

    // Handle kernel::ctrl(...), kernel::adjoint(...)
    if (!context->trailer().empty() &&
        (context->trailer()[0]->getText() == ".ctrl" ||
@@ -54,8 +108,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      ss << context->atom()->getText() << "::" << methodName
         << "(parent_kernel";
      for (int i = 0; i < arg_list->argument().size(); i++) {
        // Strategy:
        // Traverse down the tree to see if the there is a potential translation:
        // i.e. it will populate sub_node_translation stream.
        // Otherwise, output the argument *as-is*

        // clear the sub_node_translation  
        sub_node_translation.str(std::string());

        // visit arg sub-node:
        visitChildren(arg_list->argument(i));

        // Check if there is a rewrite:
        if (!sub_node_translation.str().empty()) {
          const auto arg_new_str = sub_node_translation.str();
          std::cout << arg_list->argument(i)->getText() << " --> " << arg_new_str << "\n";
          sub_node_translation.str(std::string());
          ss << ", " << arg_new_str;
        }
        else {
          // Use the arg as is:
          ss << ", " << arg_list->argument(i)->getText();
        }
      }
      ss << ");\n";

      // std::cout << "HELLO SS: " << ss.str() << "\n";
@@ -250,7 +325,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      // Handle simple assignment: a = expr
      std::stringstream ss;
      const std::string lhs = ctx->testlist_star_expr(0)->getText();
      const std::string rhs = replacePythonConstants(
      std::string rhs = replacePythonConstants(
          replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText()));

      if (lhs.find(",") != std::string::npos) {
@@ -269,32 +344,47 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
          }
        }
      } else {
        // Handle Python list assignment:
        // i.e. lhs = [a, b, c] => { a, b, c }
        // NOTE: we only support simple lists, i.e. no nested.
        const auto transformListAssignmentIfAny =
            [](const std::string &in_expr) -> std::string {
          const auto whitespace = " ";
          // Trim leading and trailing spaces:
          const auto strBegin = in_expr.find_first_not_of(whitespace);
          const auto strEnd = in_expr.find_last_not_of(whitespace);
          const auto strRange = strEnd - strBegin + 1;
          const auto trim_expr = in_expr.substr(strBegin, strRange);
        // Strategy: try to traverse the rhs to see if there is a possible rewrite;
        // Otherwise, use the text as is.
        
        // clear the sub_node_translation  
        sub_node_translation.str(std::string());

        // visit arg sub-node:
        visitChildren(ctx->testlist_star_expr(1));

          if (trim_expr.front() == '[' && trim_expr.back() == ']') {
            return "{" + trim_expr.substr(1, trim_expr.size() - 2) + "}";
        // Check if there is a rewrite:
        if (!sub_node_translation.str().empty()) {
          // Update RHS
          rhs = replacePythonConstants(
              replaceMeasureAssignment(sub_node_translation.str()));
        }

          // Returns the original expression:
          return in_expr;
        };
        // Handle Python list assignment:
        // i.e. lhs = [a, b, c] => { a, b, c }
        // NOTE: we only support simple lists, i.e. no nested.
        // const auto transformListAssignmentIfAny =
        //     [](const std::string &in_expr) -> std::string {
        //   const auto whitespace = " ";
        //   // Trim leading and trailing spaces:
        //   const auto strBegin = in_expr.find_first_not_of(whitespace);
        //   const auto strEnd = in_expr.find_last_not_of(whitespace);
        //   const auto strRange = strEnd - strBegin + 1;
        //   const auto trim_expr = in_expr.substr(strBegin, strRange);

        //   if (trim_expr.front() == '[' && trim_expr.back() == ']') {
        //     return "{" + trim_expr.substr(1, trim_expr.size() - 2) + "}";
        //   }

        //   // Returns the original expression:
        //   return in_expr;
        // };

        if (xacc::container::contains(declared_var_names, lhs)) {
          ss << lhs << " = " << transformListAssignmentIfAny(rhs) << "; \n";
          ss << lhs << " = " << rhs << "; \n";
        } else {
          // New variable: need to add *auto*
          ss << "auto " << lhs << " = " << transformListAssignmentIfAny(rhs)
             << "; \n";
          ss << "auto " << lhs << " = " << rhs << "; \n";
          new_var = lhs;
        }
      }
+27 −0
Original line number Diff line number Diff line
@@ -75,6 +75,33 @@ quantum::x(qb[i]);
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkPythonList) {
  LexerHelper helper;

  auto [tokens, PP] = helper.Lex(R"(
    # inline initializer list
    apply_X_at_idx.ctrl([q[1], q[2]], q[0])
    # array var assignement
    array_val = [q[1], q[2]]
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"qb"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(apply_X_at_idx::ctrl(parent_kernel, {q[1], q[2]}, q[0]);
auto array_val = {q[1], q[2]}; 
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

int main(int argc, char **argv) {
  std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR);
  std::string qcor_root = std::string(QCOR_INSTALL_DIR);