Commit 498d73a4 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

update python bindings to support qreg extraction and qubit kernel args

parent 40834d0a
Loading
Loading
Loading
Loading
+14 −9
Original line number Diff line number Diff line
@@ -108,7 +108,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
                buffer_names.push_back(buffer_name);
                inst->setBitExpression(i, bit_idx_expr);
              } else {
                xacc::error("Must provide qreg[IDX] and not just qreg.");
                // Indicate this is a qubit(-1) or a qreg(-2)
                inst->setBitExpression(-1, bit_expr_str);
                buffer_names.push_back(bit_expr_str);
              }
            }
            inst->setBufferNames(buffer_names);
@@ -154,8 +156,11 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
          // Note: these circuits (except exp_i_theta) don't have QRT
          // equivalents.
          // Condition: first argument is a qubit register
          else if (!context->trailer()[0]->arglist()->argument().empty() &&
                   xacc::container::contains(bufferNames, context->trailer()[0]
          else if (xacc::container::contains(
                       ::quantum::kernels_in_translation_unit, inst_name) ||
                   !context->trailer()[0]->arglist()->argument().empty() &&
                       xacc::container::contains(bufferNames,
                                                 context->trailer()[0]
                                                     ->arglist()
                                                     ->argument(0)
                                                     ->getText())) {
@@ -234,8 +239,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
    }

    std::stringstream ss;
    ss << "for (auto " << counter_expr << " : " << iter_container
       << ") {\n";
    ss << "for (auto " << counter_expr << " : " << iter_container << ") {\n";
    result.first = ss.str();
    in_for_loop = true;
    return 0;
@@ -289,7 +293,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
          ss << lhs << " = " << transformListAssignmentIfAny(rhs) << "; \n";
        } else {
          // New variable: need to add *auto*
          ss << "auto " << lhs << " = " << transformListAssignmentIfAny(rhs) << "; \n";
          ss << "auto " << lhs << " = " << transformListAssignmentIfAny(rhs)
             << "; \n";
          new_var = lhs;
        }
      }
+6 −0
Original line number Diff line number Diff line
@@ -508,10 +508,16 @@ PYBIND11_MODULE(_pyqcor, m) {
      py::arg("placement_name"), "Set the placement strategy.");

  m.def("qalloc", &::qalloc, py::return_value_policy::reference, "");
  py::class_<xacc::internal_compiler::qubit>(m, "qubit", "");
  py::class_<xacc::internal_compiler::qreg>(m, "qreg", "")
      .def("size", &xacc::internal_compiler::qreg::size, "")
      .def("print", &xacc::internal_compiler::qreg::print, "")
      .def("counts", &xacc::internal_compiler::qreg::counts, "")
      .def("extract_range", [](xacc::internal_compiler::qreg& q, std::size_t start, std::size_t end){
        std::vector<std::size_t> r{start, end};
        return q.extract_range(r);
      }, "")
      // .def("extract_qubits", &xacc::internal_compiler::qreg::extract_qubits, "")
      .def("exp_val_z", &xacc::internal_compiler::qreg::exp_val_z, "")
      .def("results", [](xacc::internal_compiler::qreg& q){
        auto buffer = q.results_shared();
+3 −0
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ FLOAT_REF = typing.NewType('value', float)
INT_REF = typing.NewType('value', int)

typing_to_simple_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
                            '<class \'_pyqcor.qubit\'>': 'qubit',
                            '<class \'float\'>': 'float', 'typing.List[float]': 'List[float]',
                            '<class \'int\'>': 'int', 'typing.List[int]': 'List[int]',
                            '<class \'_pyxacc.quantum.PauliOperator\'>': 'PauliOperator',
@@ -192,6 +193,7 @@ class qjit(object):
        self.kwargs = kwargs
        self.function = function
        self.allowed_type_cpp_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
                                     '<class \'_pyqcor.qubit\'>': 'qubit',
                                     '<class \'float\'>': 'double', 'typing.List[float]': 'std::vector<double>',
                                     '<class \'int\'>': 'int', 'typing.List[int]': 'std::vector<int>',
                                     '<class \'_pyxacc.quantum.PauliOperator\'>': 'qcor::PauliOperator',
@@ -463,6 +465,7 @@ class qjit(object):
        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
            self.function.__name__)

        # print(self.src)
        # Run the QJIT compile step to store function pointers internally
        self._qjit.internal_python_jit_compile(
            self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers)
+2 −0
Original line number Diff line number Diff line
@@ -190,6 +190,8 @@ const std::pair<std::string, std::string> QJIT::run_syntax_handler(
    auto arg_var = split(arg, ' ');
    if (arg_var[0] == "qreg") {
      bufferNames.push_back(arg_var[1]);
    } else if (arg_var[0] == "qubit") {
      bufferNames.push_back(arg_var[1]);
    }
    arg_types.push_back(arg_var[0]);
    arg_vars.push_back(arg_var[1]);