Unverified Commit cec804e2 authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #45 from tnguyen-ornl/tnguyen/pyxasm

Work on PyXASM syntax handler
parents 0083c54d c0d98372
Loading
Loading
Loading
Loading
Loading
+29 −14
Original line number Diff line number Diff line
@@ -55,9 +55,8 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
  std::vector<std::pair<std::string, int>> lines;
  std::string line = "";
  auto current_line_number = sm.getSpellingLineNumber(Toks[0].getLocation());
  line += PP.getSpelling(Toks[0]);
  int last_col_number = 0;
  for (int i = 1; i < Toks.size(); i++) {
  for (int i = 0; i < Toks.size(); i++) {
    // std::cout << PP.getSpelling(Toks[i]) << "\n";
    auto location = Toks[i].getLocation();
    auto col_number = sm.getSpellingColumnNumber(location);
@@ -90,17 +89,29 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
  using namespace antlr4;

  int previous_col = lines[0].second;
  bool is_in_for_loop = false;
  int line_counter = 0;
  // Tracking the scope of for loops by their indent
  std::stack<int> for_loop_indent;
  for (const auto &line : lines) {
    // std::cout << "processing line " << line_counter << " of " << lines.size()
    //           << ": " << line.first << ", " << line.second << std::boolalpha
    //           << ", " << is_in_for_loop << "\n";

    pyxasm_visitor visitor;
    //           << ", " << !for_loop_indent.empty() << "\n";

    pyxasm_visitor visitor(bufferNames);
    // Should we close a 'for' scope after this statement
    // If > 0, indicate the number of for blocks to be closed.
    int close_for_scopes = 0;
    // If the stack is not empty and this line changed column to an outside
    // scope:
    while (!for_loop_indent.empty() && line.second < for_loop_indent.top()) {
      // Pop the stack and flag to close the scope afterward
      for_loop_indent.pop();
      close_for_scopes++;
    }

    // Enter a new for loop -> push to the stack
    if (line.first.find("for ") != std::string::npos) {
      is_in_for_loop = true;
      for_loop_indent.push(line.second);
    }

    // is_in_for_loop = line.first.find("for ") != std::string::npos &&
@@ -128,17 +139,21 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
      ss << visitor.result.first;
    }

    if ((is_in_for_loop && line.second < previous_col) ||
        (is_in_for_loop && line_counter == lines.size() - 1)) {
      // we are now not in a for loop...
      is_in_for_loop = false;
    if (close_for_scopes > 0) {
      // std::cout << "Close " << close_for_scopes << " for scopes.\n";
      // need to close out the c++ or loop
      for (int i = 0; i < close_for_scopes; ++i) {
        ss << "}\n";
      }

    }
    previous_col = line.second;
    line_counter++;
  }

  // If there are open for scope blocks here,
  // i.e. for loops at the end of the function body.
  while (!for_loop_indent.empty()) {
    for_loop_indent.pop();
    ss << "}\n";
  }
}
}  // namespace qcor
 No newline at end of file
+106 −31
Original line number Diff line number Diff line
@@ -17,16 +17,18 @@ using pyxasm_result_type =
class pyxasm_visitor : public pyxasmBaseVisitor {
protected:
  std::shared_ptr<xacc::IRProvider> provider;
  // List of buffers in the *context* of this XASM visitor
  std::vector<std::string> bufferNames;

public:
  pyxasm_visitor()
      : provider(xacc::getIRProvider("quantum")) {}
  pyxasm_visitor(const std::vector<std::string> &buffers = {})
      : provider(xacc::getIRProvider("quantum")), bufferNames(buffers) {}
  pyxasm_result_type result;

  bool in_for_loop = false;

  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
  antlrcpp::Any
  visitAtom_expr(pyxasmParser::Atom_exprContext *context) override {
    if (context->atom()->NAME() != nullptr) {
      auto inst_name = context->atom()->NAME()->getText();

@@ -67,9 +69,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
              auto found_bracket = bit_expr_str.find_first_of("[");
              if (found_bracket != std::string::npos) {
                auto buffer_name = bit_expr_str.substr(0, found_bracket);
                auto bit_idx_expr = bit_expr_str.substr(
                    found_bracket + 1,
                    bit_expr_str.length() - found_bracket - 2);
                auto bit_idx_expr = bit_expr_str.substr(found_bracket + 1,
                                                        bit_expr_str.length() -
                                                            found_bracket - 2);
                buffer_names.push_back(buffer_name);
                inst->setBitExpression(i, bit_idx_expr);
              } else {
@@ -81,14 +83,68 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
            // Get the parameter expressions
            int counter = 0;
            for (int i = required_bits; i < atom_n_args; i++) {
              inst->setParameter(
                  counter,
                  context->trailer()[0]->arglist()->argument()[i]->getText());
              inst->setParameter(counter,
                                 replacePythonConstants(context->trailer()[0]
                                                            ->arglist()
                                                            ->argument()[i]
                                                            ->getText()));
              counter++;
            }
          }
        }
          result.second = inst;
        } else {
          // Composite instructions, e.g. exp_i_theta
          if (inst_name == "exp_i_theta") {
            // Expected 3 params:
            if (context->trailer()[0]->arglist()->argument().size() != 3) {
              xacc::error(
                  "Invalid number of arguments for the 'exp_i_theta' "
                  "instruction. Expected 3, got " +
                  std::to_string(
                      context->trailer()[0]->arglist()->argument().size()) +
                  ". Please check your input.");
            }

            std::stringstream ss;
            // Delegate to the QRT call directly.
            ss << "quantum::exp("
               << context->trailer()[0]->arglist()->argument(0)->getText()
               << ", "
               << context->trailer()[0]->arglist()->argument(1)->getText()
               << ", "
               << context->trailer()[0]->arglist()->argument(2)->getText()
               << ");\n";
            result.first = ss.str();
          } else {
            xacc::error("Composite instruction '" + inst_name +
                        "' is not currently supported.");
          }
        }
      } else {
        // This kernel *callable* is not an intrinsic instruction, just
        // reassemble the call:
        // Check that the *first* argument is a *qreg* in the current context of
        // *this* kernel.
        if (!context->trailer().empty() &&
            !context->trailer()[0]->arglist()->argument().empty() &&
            xacc::container::contains(
                bufferNames,
                context->trailer()[0]->arglist()->argument(0)->getText())) {
          std::stringstream ss;
          // Use the kernel call with a parent kernel arg.
          ss << inst_name << "(parent_kernel, ";
          // TODO: We potentially need to handle *inline* expressions in the
          // function call.
          const auto &argList = context->trailer()[0]->arglist()->argument();
          for (size_t i = 0; i < argList.size(); ++i) {
            ss << argList[i]->getText();
            if (i != argList.size() - 1) {
              ss << ", ";
            }
          }
          ss << ");\n";
          result.first = ss.str();
        }
      }
    }
    return 0;
@@ -96,28 +152,47 @@ class pyxasm_visitor : public pyxasmBaseVisitor {

  antlrcpp::Any visitFor_stmt(pyxasmParser::For_stmtContext *context) override {
    auto counter_expr = context->exprlist()->expr()[0];

    if (context->testlist()->test()[0]->getText().find("range") !=
        std::string::npos) {
      auto range_str = context->testlist()->test()[0]->getText();
      auto found_paren = range_str.find_first_of("(");
      auto range_contents = range_str.substr(
          found_paren + 1, range_str.length() - found_paren - 2);

    auto iter_container = context->testlist()->test()[0]->getText();
    // Rewrite:
    // Python: "for <var> in <expr>:"
    // C++: for (auto& var: <expr>) {}
    // Note: we add range(int) as a C++ function to support this common pattern.
    std::stringstream ss;
      ss << "for (int " << counter_expr->getText() << " = 0; "
         << counter_expr->getText() << " < " << range_contents << "; ++"
         << counter_expr->getText() << " ) {\n";

    ss << "for (auto &" << counter_expr->getText() << " : " << iter_container
       << ") {\n";
    result.first = ss.str();
    in_for_loop = true;
    return 0;
  }

  antlrcpp::Any visitExpr_stmt(pyxasmParser::Expr_stmtContext *ctx) override {
    if (ctx->ASSIGN().size() == 1 && ctx->testlist_star_expr().size() == 2) {
      // Handle simple assignment: a = expr
      std::stringstream ss;
      const std::string lhs = ctx->testlist_star_expr(0)->getText();
      const std::string rhs = ctx->testlist_star_expr(1)->getText();
      ss << "auto " << lhs << " = " << rhs << "; \n";
      result.first = ss.str();
      return 0;
    } else {
      xacc::error(
          "QCOR PyXasm can only handle 'for VAR in range(QREG.size())' at the "
          "moment.");
      return visitChildren(ctx);
    }
  }

    return 0;
private:
  // Replaces common Python constants, e.g. 'math.pi' or 'numpy.pi'.
  // Note: the library names have been resolved to their original names.
  std::string replacePythonConstants(const std::string &in_pyExpr) const {
    // List of all keywords to be replaced
    const std::map<std::string, std::string> REPLACE_MAP{{"math.pi", "M_PI"},
                                                         {"numpy.pi", "M_PI"}};
    std::string newSrc = in_pyExpr;
    for (const auto &[key, value] : REPLACE_MAP) {
      const auto pos = newSrc.find(key);
      if (pos != std::string::npos) {
        newSrc.replace(pos, key.length(), value);
      }
    }
    return newSrc;
  }
};
 No newline at end of file
+1 −1
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@ TEST(PyXASMTokenCollectorTester, checkSimple) {

  EXPECT_EQ(R"#(quantum::h(qb[0]);
quantum::cnot(qb[0], qb[1]);
for (int i = 0; i < qb.size(); ++i ) {
for (auto &i : range(qb.size())) {
quantum::x(qb[i]);
quantum::x(qb[i]);
quantum::mz(qb[i]);
+2 −2
Original line number Diff line number Diff line
@@ -156,7 +156,7 @@ TEST(TokenCollectorTester, checkPyXasm) {
  std::cout << results << "\n";
  EXPECT_EQ(R"#(quantum::h(qb[0]);
quantum::cnot(qb[0], qb[1]);
for (int i = 0; i < qb.size(); ++i ) {
for (auto &i : range(qb.size())) {
quantum::x(qb[i]);
quantum::x(qb[i]);
quantum::mz(qb[i]);
+40 −0
Original line number Diff line number Diff line
# Run this from the command line like this
#
# python3 exp_i_theta.py -shots 100

from qcor import qjit, qalloc, qreg

# To create QCOR quantum kernels in Python one 
# simply creates a Python function, writes Pythonic, 
# XASM-like quantum code, and annotates the kernel 
# to indicate it is meant for QCOR just in time compilation

# NOTE Programmers must type annotate their function arguments

# Define a XASM kernel
@qjit
def exp_circuit(q : qreg, t0: float, t1: float):
    exponent_op1 = X(0) * Y(1) - Y(0) * X(1)
    exponent_op2 = X(0) * Z(1) * Y(2) -  X(2) * Z(1) * Y(0)
    X(q[0])
    exp_i_theta(q, t0, exponent_op1)
    exp_i_theta(q, t1, exponent_op2)
    
    for i in range(q.size()):
        Measure(q[i])

# Allocate 3 qubits
q = qalloc(3)

# Run the  experiment with some random angles
theta1 = 1.234
theta2 = 2.345

# Examine the circuit QASM
comp = exp_circuit.extract_composite(q, theta1, theta2)
print(comp.toString())

# Execute
exp_circuit(q, theta1, theta2)
# Print the results
q.print()
 No newline at end of file
Loading