Commit 759df083 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

support List[Tuple[int,int]] in python kernel args

parent 747ff814
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ add_library(${LIBRARY_NAME} SHARED ${SRC})
target_compile_options(${LIBRARY_NAME} PRIVATE "-Wno-attributes")
target_include_directories(
  ${LIBRARY_NAME}
  PUBLIC . .. ${XACC_ROOT}/include/antlr4-runtime generated ${CLANG_INCLUDE_DIRS})
  PUBLIC . .. ${CMAKE_SOURCE_DIR}/runtime/utils ${XACC_ROOT}/include/antlr4-runtime generated ${CLANG_INCLUDE_DIRS})

target_link_libraries(${LIBRARY_NAME} PUBLIC ${CLANG_LIBS} ${LLVM_LIBS} ${ANTLR_LIB} qrt xacc::xacc)

+32 −14
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@
#include "pyxasmBaseVisitor.h"
#include "qrt.hpp"
#include "xacc.hpp"
#include "qcor_utils.hpp"

using namespace pyxasm;

@@ -25,7 +26,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
 public:
  pyxasm_visitor(const std::vector<std::string> &buffers = {},
                 const std::vector<std::string> &local_var_names = {})
      : provider(xacc::getIRProvider("quantum")), bufferNames(buffers),
      : provider(xacc::getIRProvider("quantum")),
        bufferNames(buffers),
        declared_var_names(local_var_names) {}
  pyxasm_result_type result;
  // New var declared (auto type) after visiting this node.
@@ -34,7 +36,6 @@ public:

  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
      
    // Handle kernel::ctrl(...), kernel::adjoint(...)
    if (!context->trailer().empty() &&
        (context->trailer()[0]->getText() == ".ctrl" ||
@@ -199,8 +200,7 @@ public:
          }
          ss << ");\n";
          result.first = ss.str();
        }
        else {
        } else {
          if (!context->trailer().empty()) {
            // A classical call-like expression: i.e. not a kernel call:
            // Just output it *as-is* to the C++ stream.
@@ -232,12 +232,29 @@ public:

  antlrcpp::Any visitExpr_stmt(pyxasmParser::Expr_stmtContext *ctx) override {
    if (ctx->ASSIGN().size() == 1 && ctx->testlist_star_expr().size() == 2) {

      // Handle simple assignment: a = expr
      std::stringstream ss;
      const std::string lhs = ctx->testlist_star_expr(0)->getText();
      const std::string rhs = replacePythonConstants(
          replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText()));

      if (lhs.find(",") != std::string::npos) {
        // this is
        // var1, var2, ... = some_tuple_thing
        // We only support var1, var2 = ... for now
        // where ... is a pair-like object
        std::vector<std::string> suffix{".first", ".second"};
        auto vars = xacc::split(lhs, ',');
        for (auto [i, var] : qcor::enumerate(vars)) {
          if (xacc::container::contains(declared_var_names, var)) {
            ss << var << " = " << rhs << suffix[i] << ";\n";
          } else {
            ss << "auto " << var << " = " << rhs << suffix[i] << ";\n";
            new_var = lhs;
          }
        }
      } else {
        if (xacc::container::contains(declared_var_names, lhs)) {
          ss << lhs << " = " << rhs << "; \n";
        } else {
@@ -245,6 +262,7 @@ public:
          ss << "auto " << lhs << " = " << rhs << "; \n";
          new_var = lhs;
        }
      }

      result.first = ss.str();
      if (rhs.find("**") != std::string::npos) {
@@ -283,8 +301,8 @@ public:
    return visitChildren(context);
  }

  virtual antlrcpp::Any
  visitIf_stmt(pyxasmParser::If_stmtContext *ctx) override {
  virtual antlrcpp::Any visitIf_stmt(
      pyxasmParser::If_stmtContext *ctx) override {
    // Only support single clause atm
    if (ctx->test().size() == 1) {
      std::stringstream ss;
+1 −1
Original line number Diff line number Diff line
@@ -54,7 +54,7 @@ namespace {
// Here we enumerate them as a Variant
using AllowedKernelArgTypes =
    xacc::Variant<bool, int, double, std::string, xacc::internal_compiler::qreg,
                  std::vector<double>, std::vector<int>, qcor::PauliOperator>;
                  std::vector<double>, std::vector<int>, qcor::PauliOperator, qcor::PairList<int>>;

// We will take as input a mapping of arg variable names to the argument itself.
using KernelArgDict = std::map<std::string, AllowedKernelArgTypes>;
+4 −1
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@ import itertools
from collections import defaultdict

List = typing.List
Tuple = typing.Tuple

PauliOperator = xacc.quantum.PauliOperator
FLOAT_REF = typing.NewType('value', float)
INT_REF = typing.NewType('value', int)
@@ -156,7 +158,8 @@ class qjit(object):
        self.allowed_type_cpp_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
                                     '<class \'float\'>': 'double', 'typing.List[float]': 'std::vector<double>',
                                     '<class \'int\'>': 'int', 'typing.List[int]': 'std::vector<int>',
                                     '<class \'_pyxacc.quantum.PauliOperator\'>': 'qcor::PauliOperator'}
                                     '<class \'_pyxacc.quantum.PauliOperator\'>': 'qcor::PauliOperator',
                                     'typing.List[typing.Tuple[int, int]]': 'PairList<int>'}
        self.__dict__.update(kwargs)

        # Create the qcor just in time engine
+2 −2
Original line number Diff line number Diff line
@@ -88,8 +88,8 @@ std::shared_ptr<Observable> createOperator(const std::string &name,

std::shared_ptr<Observable> operatorTransform(const std::string &type,
                                              qcor::Observable &op) {
  // return xacc::getService<xacc::ObservableTransform>(type)->transform(
  //     xacc::as_shared_ptr(*&op));
  return xacc::getService<xacc::ObservableTransform>(type)->transform(
      xacc::as_shared_ptr(&op));
}
std::shared_ptr<Observable> operatorTransform(const std::string &type,
                                              std::shared_ptr<Observable> op) {
Loading