Unverified Commit 2a4ae844 authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #55 from tnguyen-ornl/tnguyen/pyxasm-ftqc

Work on support ftqc in python qjit
parents ce893a98 5e3a1671
Loading
Loading
Loading
Loading
Loading
+118 −41
Original line number Diff line number Diff line
@@ -80,7 +80,7 @@ void QCORSyntaxHandler::GetReplacement(

  // Get Tokens as a string, rewrite code
  // with XACC api calls
  qcor::append_kernel(kernel_name);
  qcor::append_kernel(kernel_name, program_arg_types, program_parameters);

  auto new_src = qcor::run_token_collector(PP, Toks, bufferNames);

@@ -200,44 +200,44 @@ void QCORSyntaxHandler::GetReplacement(
  }
  OS << ") {}\n";

  if (add_het_map_ctor) {
    // Third constructor, give us a way to provide a HeterogeneousMap of
    // arguments, this is used for Pythonic QJIT...
    // KERNEL_NAME(HeterogeneousMap args);
    OS << kernel_name << "(HeterogeneousMap& args): QuantumKernel<"
       << kernel_name << ", " << program_arg_types[0];
    for (int i = 1; i < program_arg_types.size(); i++) {
      OS << ", " << program_arg_types[i];
    }
    OS << "> (args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    for (int i = 1; i < program_parameters.size(); i++) {
      OS << ", "
         << "args.get<" << program_arg_types[i] << ">(\""
         << program_parameters[i] << "\")";
    }
    OS << ") {}\n";

    // Forth constructor, give us a way to provide a HeterogeneousMap of
    // arguments, and set a parent kernel - this is also used for Pythonic
    // QJIT... KERNEL_NAME(std::shared_ptr<CompositeInstruction> parent,
    // HeterogeneousMap args);
    OS << kernel_name
       << "(std::shared_ptr<CompositeInstruction> parent, HeterogeneousMap& "
          "args): QuantumKernel<"
       << kernel_name << ", " << program_arg_types[0];
    for (int i = 1; i < program_arg_types.size(); i++) {
      OS << ", " << program_arg_types[i];
    }
    OS << "> (parent, args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    for (int i = 1; i < program_parameters.size(); i++) {
      OS << ", "
         << "args.get<" << program_arg_types[i] << ">(\""
         << program_parameters[i] << "\")";
    }
    OS << ") {}\n";
  }
  // if (add_het_map_ctor) {
  //   // Third constructor, give us a way to provide a HeterogeneousMap of
  //   // arguments, this is used for Pythonic QJIT...
  //   // KERNEL_NAME(HeterogeneousMap args);
  //   OS << kernel_name << "(HeterogeneousMap& args): QuantumKernel<"
  //      << kernel_name << ", " << program_arg_types[0];
  //   for (int i = 1; i < program_arg_types.size(); i++) {
  //     OS << ", " << program_arg_types[i];
  //   }
  //   OS << "> (args.get<" << program_arg_types[0] << ">(\""
  //      << program_parameters[0] << "\")";
  //   for (int i = 1; i < program_parameters.size(); i++) {
  //     OS << ", "
  //        << "args.get<" << program_arg_types[i] << ">(\""
  //        << program_parameters[i] << "\")";
  //   }
  //   OS << ") {}\n";

  //   // Forth constructor, give us a way to provide a HeterogeneousMap of
  //   // arguments, and set a parent kernel - this is also used for Pythonic
  //   // QJIT... KERNEL_NAME(std::shared_ptr<CompositeInstruction> parent,
  //   // HeterogeneousMap args);
  //   OS << kernel_name
  //      << "(std::shared_ptr<CompositeInstruction> parent, HeterogeneousMap& "
  //         "args): QuantumKernel<"
  //      << kernel_name << ", " << program_arg_types[0];
  //   for (int i = 1; i < program_arg_types.size(); i++) {
  //     OS << ", " << program_arg_types[i];
  //   }
  //   OS << "> (parent, args.get<" << program_arg_types[0] << ">(\""
  //      << program_parameters[0] << "\")";
  //   for (int i = 1; i < program_parameters.size(); i++) {
  //     OS << ", "
  //        << "args.get<" << program_arg_types[i] << ">(\""
  //        << program_parameters[i] << "\")";
  //   }
  //   OS << ") {}\n";
  // }

  // Destructor definition
  OS << "virtual ~" << kernel_name << "() {\n";
@@ -331,16 +331,93 @@ void QCORSyntaxHandler::GetReplacement(
  OS << "}\n";

  if (add_het_map_ctor) {
    // Remove "&" from type string before getting the Python variables in the HetMap.
    // Note: HetMap can't store references.
    const auto remove_ref_arg_type = [](const std::string &org_arg_type) -> std::string {
      // We intentially only support a very limited set of pass-by-ref types
      // from the HetMap.
      // Only do: double& and int&
      if (org_arg_type == "double&") {
        return "double";
      }
      if (org_arg_type == "int&") {
        return "int";
      }
      // Keep the type string.
      return org_arg_type;
    };

    // Strategy: we unpack the args in the HetMap and call
    // the appropriate ctor overload.
    
    // For reference ctor params (e.g. double& and int&),
    // we create a local variable to copy the arg from the HetMap
    // before passing to the ctor.
    // We have a special machanism to handle *pass-by-reference*
    // in the Python side.
    // Non-reference types will just use inline `args.get<T>(key)` to unpack
    // the arguments.

    // List of resolved argument strings for ctor calls.
    std::vector<std::string> arg_ctor_list;
    // Code to copy *ref* type arguments from the HetMap.
    // This *must* be injected before the ctor call.
    std::stringstream ref_type_copy_decl_ss;
    int var_counter = 0;
    // Only handle non-qreg args
    for (int i = 1; i < program_parameters.size(); i++) {
      // If this is a *supported* ref types: double&, int&, etc. 
      if (remove_ref_arg_type(program_arg_types[i]) != program_arg_types[i]) {
        // Generate a temp var
        const std::string new_var_name = "__temp_var__" + std::to_string(var_counter++);
        // Copy the var from HetMap to the temp var
        ref_type_copy_decl_ss << remove_ref_arg_type(program_arg_types[i]) << " "<< new_var_name << " = " << "args.get<" << remove_ref_arg_type(program_arg_types[i]) << ">(\""
         << program_parameters[i] << "\");\n";
        
        // We just pass this copied var to the ctor 
        // where it expects a reference type.
        arg_ctor_list.emplace_back(new_var_name); 
      }
      else {
        // Otherwise, just unpack the arg inline in the ctor call.
        std::stringstream ss;
        ss << "args.get<" << program_arg_types[i] << ">(\""<< program_parameters[i] << "\")";
        arg_ctor_list.emplace_back(ss.str());
      }
    }

    // Add the HeterogeneousMap args function overload
    OS << "void " << kernel_name
       << "__with_hetmap_args(HeterogeneousMap& args) {\n";
    OS << "class " << kernel_name << " __ker__temp__(args);\n";
    // First, inject any copying statements required to unpack *ref* types.
    OS << ref_type_copy_decl_ss.str();
    // CTor call
    OS << "class " << kernel_name << " __ker__temp__(";
    // First arg: qreg
    OS << "args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    // The rest: either inline unpacking or temp var names (ref type)
    for (const auto &arg_str: arg_ctor_list) {
      OS << ", " << arg_str;
    }
    OS << ");\n";
    OS << "}\n";

    OS << "void " << kernel_name
       << "__with_parent_and_hetmap_args(std::shared_ptr<CompositeInstruction> parent, "
          "HeterogeneousMap& args) {\n";
    OS << "class " << kernel_name << " __ker__temp__(parent, args);\n";
    OS << ref_type_copy_decl_ss.str();
    // CTor call with parent kernel
    OS << "class " << kernel_name << " __ker__temp__(parent, ";
    // Second arg: qreg
    OS << "args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    // The rest: either inline unpacking or temp var names (ref type)
    for (const auto &arg_str: arg_ctor_list) {
      OS << ", " << arg_str;
    }
    OS << ");\n";   
    // The rest: either inline unpacking or temp var names (ref type)
    OS << "}\n";
  }
  auto s = OS.str();
+84 −21
Original line number Diff line number Diff line
@@ -80,6 +80,16 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
      line += PP.getSpelling(Toks[i]);
      line += " ";
    }

    // If statement:
    // Note: Python has an "elif" token, which doesn't have a C++ equiv.
    if (Toks[i].is(clang::tok::TokenKind::kw_if) ||
        PP.getSpelling(Toks[i]) == "elif") {
      line += " ";
      i += 1;
      line += PP.getSpelling(Toks[i]);
    }

    last_col_number = col_number;
  }

@@ -90,34 +100,83 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,

  int previous_col = lines[0].second;
  int line_counter = 0;
  // Tracking the scope of for loops by their indent
  std::stack<int> for_loop_indent;
  // Add all the kernel args to the list of *known* arguments.
  // i.e. when we see an assignment expression where this arg. is the LHS,
  // we don't add *auto * to the codegen.
  std::vector<std::string> local_vars = [&]() -> std::vector<std::string> {
    if (::quantum::kernels_in_translation_unit.empty()) {
      return {};
    }
    const std::string kernel_name =
        ::quantum::kernels_in_translation_unit.back();
    const auto &[arg_types, arg_names] =
        ::quantum::kernel_signatures_in_translation_unit[kernel_name];
    return arg_names;
  }();
  // Tracking the Python scopes by the indent of code blocks
  std::stack<int> scope_block_indent;
  for (const auto &line : lines) {
    // std::cout << "processing line " << line_counter << " of " << lines.size()
    //           << ": " << line.first << ", " << line.second << std::boolalpha
    //           << ", " << !for_loop_indent.empty() << "\n";
    //           << ", " << !scope_block_indent.empty() << "\n";

    pyxasm_visitor visitor(bufferNames);
    // Should we close a 'for' scope after this statement
    pyxasm_visitor visitor(bufferNames, local_vars);
    // Should we close a 'for'/'if' scope after this statement
    // If > 0, indicate the number of for blocks to be closed.
    int close_for_scopes = 0;
    int nb_closing_scopes = 0;
    // If the stack is not empty and this line changed column to an outside
    // scope:
    while (!for_loop_indent.empty() && line.second < for_loop_indent.top()) {
    while (!scope_block_indent.empty() &&
           line.second < scope_block_indent.top()) {
      // Pop the stack and flag to close the scope afterward
      for_loop_indent.pop();
      close_for_scopes++;
      scope_block_indent.pop();
      nb_closing_scopes++;
    }

    // Enter a new for loop -> push to the stack
    if (line.first.find("for ") != std::string::npos) {
      for_loop_indent.push(line.second);
    std::string lineText = line.first;
    // Enter a new for scope block (for/if/etc.) -> push to the stack
    // Note: we rewrite Python if .. elif .. else as follows:
    // Python:
    // if (cond1):
    //   code1
    // elif (cond2):
    //   code2
    // else:
    //   code3
    // ===============
    // C++:
    // if (cond1) {
    //   code1
    // }
    // else if (cond2) {
    //   code2
    // }
    // else {
    //   code3
    // }

    if (line.first.find("for ") != std::string::npos ||
        // Starts with 'if'
        line.first.rfind("if ", 0) == 0) {
      scope_block_indent.push(line.second);
    } else if (line.first == "else:") {
      ss << "else {\n";
      scope_block_indent.push(line.second);
    }
    // Starts with 'elif'
    else if (line.first.rfind("elif ", 0) == 0) {
      // Rewrite it to
      // else if () { }
      ss << "else ";
      scope_block_indent.push(line.second);
      // Remove the first two characters ("el")
      // hence this line will be parsed as an idependent C++ if block:
      lineText.erase(0, 2);
    }

    // is_in_for_loop = line.first.find("for ") != std::string::npos &&
    // line.second >= previous_col;

    ANTLRInputStream input(line.first);
    ANTLRInputStream input(lineText);
    pyxasmLexer lexer(&input);
    CommonTokenStream tokens(&lexer);
    pyxasmParser parser(&tokens);
@@ -139,20 +198,24 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
      ss << visitor.result.first;
    }

    if (close_for_scopes > 0) {
      // std::cout << "Close " << close_for_scopes << " for scopes.\n";
    if (nb_closing_scopes > 0) {
      // std::cout << "Close " << nb_closing_scopes << " for scopes.\n";
      // need to close out the c++ or loop
      for (int i = 0; i < close_for_scopes; ++i) {
      for (int i = 0; i < nb_closing_scopes; ++i) {
        ss << "}\n";
      }
    }
    previous_col = line.second;
    line_counter++;
    if (!visitor.new_var.empty()) {
      // A new local variable was declared, add to the tracking list.
      local_vars.emplace_back(visitor.new_var);
    }
  }
  // If there are open for scope blocks here,
  // i.e. for loops at the end of the function body.
  while (!for_loop_indent.empty()) {
    for_loop_indent.pop();
  // If there are open scope blocks here,
  // e.g. for loops at the end of the function body.
  while (!scope_block_indent.empty()) {
    scope_block_indent.pop();
    ss << "}\n";
  }
}
+75 −7
Original line number Diff line number Diff line
@@ -19,12 +19,17 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
  std::shared_ptr<xacc::IRProvider> provider;
  // List of buffers in the *context* of this XASM visitor
  std::vector<std::string> bufferNames;
  // List of *declared* variables
  std::vector<std::string> declared_var_names;

public:
  pyxasm_visitor(const std::vector<std::string> &buffers = {})
      : provider(xacc::getIRProvider("quantum")), bufferNames(buffers) {}
  pyxasm_visitor(const std::vector<std::string> &buffers = {},
                 const std::vector<std::string> &local_var_names = {})
      : provider(xacc::getIRProvider("quantum")), bufferNames(buffers),
        declared_var_names(local_var_names) {}
  pyxasm_result_type result;

  // New var declared (auto type) after visiting this node.
  std::string new_var;
  bool in_for_loop = false;

  antlrcpp::Any visitAtom_expr(
@@ -175,7 +180,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
        // reassemble the call:
        // Check that the *first* argument is a *qreg* in the current context of
        // *this* kernel.
        if (!context->trailer().empty() &&
        if (!context->trailer().empty() && context->trailer()[0]->arglist() &&
            !context->trailer()[0]->arglist()->argument().empty() &&
            xacc::container::contains(
                bufferNames,
@@ -195,6 +200,16 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
          ss << ");\n";
          result.first = ss.str();
        }
        else {
          if (!context->trailer().empty()) {
            // A classical call-like expression: i.e. not a kernel call:
            // Just output it *as-is* to the C++ stream.
            // We can hook more sophisticated code-gen here if required.
            std::stringstream ss;
            ss << context->getText() << ";\n";
            result.first = ss.str();
          }
        }
      }
    }
    return 0;
@@ -220,8 +235,17 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      // Handle simple assignment: a = expr
      std::stringstream ss;
      const std::string lhs = ctx->testlist_star_expr(0)->getText();
      const std::string rhs = ctx->testlist_star_expr(1)->getText();
      const std::string rhs = replacePythonConstants(
          replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText()));
      
      if (xacc::container::contains(declared_var_names, lhs)) {
        ss << lhs << " = " << rhs << "; \n";
      } else {
        // New variable: need to add *auto*
        ss << "auto " << lhs << " = " << rhs << "; \n";
        new_var = lhs;
      }
      
      result.first = ss.str();
      if (rhs.find("**") != std::string::npos) {
        // keep processing
@@ -259,6 +283,21 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
    return visitChildren(context);
  }

  virtual antlrcpp::Any
  visitIf_stmt(pyxasmParser::If_stmtContext *ctx) override {
    // Only support single clause atm
    if (ctx->test().size() == 1) {
      std::stringstream ss;
      ss << "if ("
         << replacePythonConstants(
                replaceMeasureAssignment(ctx->test(0)->getText()))
         << ") {\n";
      result.first = ss.str();
      return 0;
    }
    return visitChildren(ctx);
  }

 private:
  // Replaces common Python constants, e.g. 'math.pi' or 'numpy.pi'.
  // Note: the library names have been resolved to their original names.
@@ -275,4 +314,33 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
    }
    return newSrc;
  }

  // Assignment of Measure results -> variable or in if conditional statements
  std::string replaceMeasureAssignment(const std::string &in_expr) const {
    if (in_expr.find("Measure") != std::string::npos) {
      // Found measure in an if statement instruction.
      const auto replaceMeasureInst = [](std::string &s,
                                         const std::string &search,
                                         const std::string &replace) {
        for (size_t pos = 0;; pos += replace.length()) {
          pos = s.find(search, pos);
          if (pos == std::string::npos) {
            break;
          }
          if (!isspace(s[pos + search.length()]) &&
              (s[pos + search.length()] != '(')) {
            continue;
          }
          s.erase(pos, search.length());
          s.insert(pos, replace);
        }
      };

      std::string result = in_expr;
      replaceMeasureInst(result, "Measure", "quantum::mz");
      return result;
    } else {
      return in_expr;
    }
  }
};
 No newline at end of file
+33 −0
Original line number Diff line number Diff line
@@ -40,6 +40,39 @@ quantum::mz(qb[i]);
            ss.str());
}

TEST(PyXASMTokenCollectorTester, checkIf) {
  LexerHelper helper;

  auto [tokens, PP] = helper.Lex(R"(
    H(qb[0])
    CX(qb[0],qb[1])
    for i in range(qb.size()):
      if Measure(qb[i]):
        X(qb[i])
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"qb"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(quantum::h(qb[0]);
quantum::cnot(qb[0], qb[1]);
for (auto &i : range(qb.size())) {
if (quantum::mz(qb[i])) {
quantum::x(qb[i]);
}
}
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

int main(int argc, char **argv) {
  xacc::Initialize();
  ::testing::InitGoogleTest(&argc, argv);
+5 −1
Original line number Diff line number Diff line
@@ -16,8 +16,12 @@
#include "qcor_config.hpp"

namespace qcor {
void append_kernel(const std::string name) {
void append_kernel(const std::string name,
                   const std::vector<std::string> &program_arg_types,
                   const std::vector<std::string> &program_parameters) {
  ::quantum::kernels_in_translation_unit.push_back(name);
  ::quantum::kernel_signatures_in_translation_unit[name] =
      std::make_pair(program_arg_types, program_parameters);
}

void set_verbose(bool verbose) { xacc::set_verbose(verbose); }
Loading