Unverified Commit 6bae7628 authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #118 from tnguyen-ornl/tnguyen/pyxasm-update

Update PyXASM
parents 6de35cf2 eaf85ba0
Loading
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
@@ -399,6 +399,25 @@ void QCORSyntaxHandler::GetReplacement(
        // We just pass this copied var to the ctor
        // where it expects a reference type.
        arg_ctor_list.emplace_back(new_var_name);
      } else if (program_arg_types[i].rfind("KernelSignature", 0) == 0) {
        // This is a KernelSignature argument.
        // The one in HetMap is the function pointer represented as a hex string.
        const std::string new_var_name =
            "__temp_kernel_ptr_var__" + std::to_string(var_counter++);
        // Retrieve the function pointer from the HetMap
        // ref_type_copy_decl_ss << "std::cout << args.getString(\""
        //                       << program_parameters[i] << "\").c_str() << std::endl;\n";
        ref_type_copy_decl_ss << "void* " << new_var_name << " = "
                              << "(void *) strtoull(args.getString(\""
                              << program_parameters[i] << "\").c_str(), nullptr, 16);\n";
        // ref_type_copy_decl_ss << "std::cout << " << new_var_name << " << std::endl;\n";
        // Construct the KernelSignature
        const std::string kernel_signature_var_name =
            "__temp_kernel_signature_var__" + std::to_string(var_counter++);
        ref_type_copy_decl_ss << program_arg_types[i] << " "
                              << kernel_signature_var_name << "("
                              << new_var_name << ");\n";
        arg_ctor_list.emplace_back(kernel_signature_var_name);
      } else {
        // Otherwise, just unpack the arg inline in the ctor call.
        std::stringstream ss;
+194 −26
Original line number Diff line number Diff line
@@ -33,9 +33,136 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
  // New var declared (auto type) after visiting this node.
  std::string new_var;
  bool in_for_loop = false;
  // Var to keep track of sub-node rewrite:
  // e.g., traverse down the AST recursively.
  std::stringstream sub_node_translation;
  bool is_processing_sub_expr = false;

  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
    // std::cout << "Atom_exprContext: " << context->getText() << "\n";
    // Strategy:
    // At the top level, we analyze the trailer to determine the 
    // list of function call arguments.
    // Then, traverse down the arg. node to see if there is a potential rewrite rules
    // e.g. for arrays (as testlist_comp nodes)
    // Otherwise, just get the argument text as is.
    /*
    atom_expr: (AWAIT)? atom trailer*;
    atom: ('(' (yield_expr|testlist_comp)? ')' |
       '[' (testlist_comp)? ']' |
       '{' (dictorsetmaker)? '}' |
       NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False');
    */
    // Only processes these for sub-expressesions, 
    // e.g. re-entries to this function
    if (is_processing_sub_expr) {
      if (context->atom() && context->atom()->OPEN_BRACK() &&
          context->atom()->CLOSE_BRACK() && context->atom()->testlist_comp()) {
        // Array type expression:
        // std::cout << "Array atom expression: "
        //           << context->atom()->testlist_comp()->getText() << "\n";
        // Use braces
        sub_node_translation << "{";
        bool firstElProcessed = false;
        for (auto &testNode : context->atom()->testlist_comp()->test()) {
          // std::cout << "Array elem: " << testNode->getText() << "\n";
          // Add comma if needed (there is a previous element)
          if (firstElProcessed) {
            sub_node_translation << ", ";
          }
          sub_node_translation << testNode->getText();
          firstElProcessed = true;
        }
        sub_node_translation << "}";
        return 0;
      }

      // We don't have a re-write rule for this one (py::dict)
      if (context->atom() && context->atom()->OPEN_BRACE() &&
          context->atom()->CLOSE_BRACE() && context->atom()->dictorsetmaker()) {
        // Dict:
        // std::cout << "Dict atom expression: "
        //           << context->atom()->dictorsetmaker()->getText() << "\n";
        // TODO:
        return 0;
      }

      if (context->atom() && !context->atom()->STRING().empty()) {
        // Strings:
        for (auto &strNode : context->atom()->STRING()) {
          std::string cppStrLiteral = strNode->getText();
          // Handle Python single-quotes
          if (cppStrLiteral.front() == '\'' && cppStrLiteral.back() == '\'') {
            cppStrLiteral.front() = '"';
            cppStrLiteral.back() = '"';
          }
          sub_node_translation << cppStrLiteral;
          // std::cout << "String expression: " << strNode->getText() << " --> "
          //           << cppStrLiteral << "\n";
        }
        return 0;
      }

      const auto isSliceOp =
          [](pyxasmParser::Atom_exprContext *atom_expr_context) -> bool {
        if (atom_expr_context->trailer().size() == 1) {
          auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist();
          if (subscriptlist && subscriptlist->subscript().size() == 1) {
            auto subscript = subscriptlist->subscript(0);
            const auto nbTestTerms = subscript->test().size();
            // Multiple test terms (separated by ':')
            return (nbTestTerms > 1);
          }
        }

        return false;
      };

      // Handle slicing operations (multiple array subscriptions separated by
      // ':') on a qreg.
      if (context->atom() &&
          xacc::container::contains(bufferNames, context->atom()->getText()) &&
          isSliceOp(context)) {
        // std::cout << "Slice op: " << context->getText() << "\n";
        sub_node_translation << context->atom()->getText()
                             << ".extract_range({";
        auto subscripts =
            context->trailer(0)->subscriptlist()->subscript(0)->test();
        assert(subscripts.size() > 1);
        std::vector<std::string> subscriptTerms;
        for (auto &test : subscripts) {
          subscriptTerms.emplace_back(test->getText());
        }

        auto sliceOp =
            context->trailer(0)->subscriptlist()->subscript(0)->sliceop();
        if (sliceOp && sliceOp->test()) {
          subscriptTerms.emplace_back(sliceOp->test()->getText());
        }
        assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3);

        for (int i = 0; i < subscriptTerms.size(); ++i) {
          // Need to cast to prevent compiler errors,
          // e.g. when using q.size() which returns an int.
          sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i]
                               << ")";
          if (i != subscriptTerms.size() - 1) {
            sub_node_translation << ", ";
          }
        }

        sub_node_translation << "})";

        // convert the slice op to initializer list:
        // std::cout << "Slice Convert: " << context->getText() << " --> "
        //           << sub_node_translation.str() << "\n";
        return 0;
      }

      return 0;
    }

    // Handle kernel::ctrl(...), kernel::adjoint(...)
    if (!context->trailer().empty() &&
        (context->trailer()[0]->getText() == ".ctrl" ||
@@ -54,7 +181,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      ss << context->atom()->getText() << "::" << methodName
         << "(parent_kernel";
      for (int i = 0; i < arg_list->argument().size(); i++) {
        ss << ", " << arg_list->argument(i)->getText();
        ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i)));
      }
      ss << ");\n";

@@ -97,7 +224,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
            std::vector<std::string> buffer_names;
            for (int i = 0; i < required_bits; i++) {
              auto bit_expr = context->trailer()[0]->arglist()->argument()[i];
              auto bit_expr_str = bit_expr->getText();
              auto bit_expr_str = rewriteFunctionArgument(*bit_expr);

              auto found_bracket = bit_expr_str.find_first_of("[");
              if (found_bracket != std::string::npos) {
@@ -210,8 +337,24 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
            // A classical call-like expression: i.e. not a kernel call:
            // Just output it *as-is* to the C++ stream.
            // We can hook more sophisticated code-gen here if required.
            // std::cout << "Callable: " << context->getText() << "\n";
            std::stringstream ss;

            if (context->trailer()[0]->arglist() &&
                !context->trailer()[0]->arglist()->argument().empty()) {
              const auto &argList =
                  context->trailer()[0]->arglist()->argument();
              ss << inst_name << "(";
              for (size_t i = 0; i < argList.size(); ++i) {                
                ss << rewriteFunctionArgument(*(argList[i]));                
                if (i != argList.size() - 1) {
                  ss << ", ";
                }
              }
              ss << ");\n";
            } else {
              ss << context->getText() << ";\n";
            }
            result.first = ss.str();
          }
        }
@@ -250,7 +393,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      // Handle simple assignment: a = expr
      std::stringstream ss;
      const std::string lhs = ctx->testlist_star_expr(0)->getText();
      const std::string rhs = replacePythonConstants(
      std::string rhs = replacePythonConstants(
          replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText()));

      if (lhs.find(",") != std::string::npos) {
@@ -269,32 +412,27 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
          }
        }
      } else {
        // Handle Python list assignment:
        // i.e. lhs = [a, b, c] => { a, b, c }
        // NOTE: we only support simple lists, i.e. no nested.
        const auto transformListAssignmentIfAny =
            [](const std::string &in_expr) -> std::string {
          const auto whitespace = " ";
          // Trim leading and trailing spaces:
          const auto strBegin = in_expr.find_first_not_of(whitespace);
          const auto strEnd = in_expr.find_last_not_of(whitespace);
          const auto strRange = strEnd - strBegin + 1;
          const auto trim_expr = in_expr.substr(strBegin, strRange);

          if (trim_expr.front() == '[' && trim_expr.back() == ']') {
            return "{" + trim_expr.substr(1, trim_expr.size() - 2) + "}";
          }

          // Returns the original expression:
          return in_expr;
        };
        // Strategy: try to traverse the rhs to see if there is a possible rewrite;
        // Otherwise, use the text as is.
        is_processing_sub_expr = true;
        // clear the sub_node_translation  
        sub_node_translation.str(std::string());

        // visit arg sub-node:
        visitChildren(ctx->testlist_star_expr(1));

        // Check if there is a rewrite:
        if (!sub_node_translation.str().empty()) {
          // Update RHS
          rhs = replacePythonConstants(
              replaceMeasureAssignment(sub_node_translation.str()));
        }

        if (xacc::container::contains(declared_var_names, lhs)) {
          ss << lhs << " = " << transformListAssignmentIfAny(rhs) << "; \n";
          ss << lhs << " = " << rhs << "; \n";
        } else {
          // New variable: need to add *auto*
          ss << "auto " << lhs << " = " << transformListAssignmentIfAny(rhs)
             << "; \n";
          ss << "auto " << lhs << " = " << rhs << "; \n";
          new_var = lhs;
        }
      }
@@ -396,4 +534,34 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      return in_expr;
    }
  }

  // A helper to rewrite function argument by traversing the node to see
  // if there is a potential rewrite.
  // Use case: inline expressions
  // e.g. X(q[0:3])
  // slicing of the qreg 'q' then call the broadcast X op.
  // i.e., we need to rewrite the arg to q.extract_range({0, 3}).
  std::string
  rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) {
    // Strategy: try to traverse the argument context to see if there is a
    // possible rewrite; i.e. it may be another atom_expression that we have a
    // handler for. Otherwise, use the text as is.
    // We need this flag to prevent parsing quantum instructions as sub-expressions.
    // e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be 
    // processed as instructions.
    is_processing_sub_expr = true;
    // clear the sub_node_translation
    sub_node_translation.str(std::string());

    // visit arg sub-node:
    visitChildren(&in_argContext);

    // Check if there is a rewrite:
    if (!sub_node_translation.str().empty()) {
      // Update RHS
      return sub_node_translation.str();
    }
    // Returns the string as is
    return in_argContext.getText();
  }
};
 No newline at end of file
+185 −0
Original line number Diff line number Diff line
@@ -75,6 +75,191 @@ quantum::x(qb[i]);
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkPythonList) {
  LexerHelper helper;

  auto [tokens, PP] = helper.Lex(R"(
    # inline initializer list
    apply_X_at_idx.ctrl([q[1], q[2]], q[0])
    # array var assignement
    array_val = [q[1], q[2]]
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(apply_X_at_idx::ctrl(parent_kernel, {q[1], q[2]}, q[0]);
auto array_val = {q[1], q[2]}; 
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkStringLiteral) {
  LexerHelper helper;

  auto [tokens, PP] = helper.Lex(R"(
    # Cpp style strings
    print("hello", 1, "world")
    # Python style
    print('howdy', 1, 'abc')
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(print("hello", 1, "world");
print("howdy", 1, "abc");
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkQregMethods) {
  LexerHelper helper;
  auto [tokens, PP] = helper.Lex(R"(
    ctrl_qubits = q.head(q.size()-1)
    last_qubit = q.tail()
    Z.ctrl(ctrl_qubits, last_qubit)
    
    # inline
    X.ctrl(q.head(q.size()-1), q.tail())

    # range:
    # API
    r = q.extract_range(0, bitPrecision)
    # Python style
    slice1 = q[0:3]
    # step size
    slice2 = q[0:5:2]
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(auto ctrl_qubits = q.head(q.size()-1); 
auto last_qubit = q.tail(); 
Z::ctrl(parent_kernel, ctrl_qubits, last_qubit);
X::ctrl(parent_kernel, q.head(q.size()-1), q.tail());
auto r = q.extract_range(0,bitPrecision); 
auto slice1 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(3)}); 
auto slice2 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}); 
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) {
  LexerHelper helper;
  auto [tokens, PP] = helper.Lex(R"(
    X(q.head(q.size()-1))
    X(q[0])
    X(q)
    X(q[0:2])
    X(q[0:5:2])
    Measure(q.head(q.size()-1))
    Measure(q[0])
    Measure(q)
    Measure(q[0:2])
    Measure(q[0:5:2])
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(quantum::x(q.head(q.size()-1));
quantum::x(q[0]);
quantum::x(q);
quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)}));
quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}));
quantum::mz(q.head(q.size()-1));
quantum::mz(q[0]);
quantum::mz(q);
quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)}));
quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}));
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkQcorOperators) {
  LexerHelper helper;
  auto [tokens, PP] = helper.Lex(R"(
    exponent_op = X(0) * Y(1) - Y(0) * X(1)
    exp_i_theta(q, theta, exponent_op)
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(auto exponent_op = X(0)*Y(1)-Y(0)*X(1); 
quantum::exp(q, theta, exponent_op);
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkCommonMath) {
  LexerHelper helper;
  auto [tokens, PP] = helper.Lex(R"(
    out_parity = oneCount - 2 * (oneCount / 2)
    # Power
    index = 2**n 
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {""}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(auto out_parity = oneCount-2*(oneCount/2); 
auto index = std::pow(2, n); 
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}


int main(int argc, char **argv) {
  std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR);
  std::string qcor_root = std::string(QCOR_INSTALL_DIR);
+7 −1
Original line number Diff line number Diff line
@@ -596,7 +596,13 @@ PYBIND11_MODULE(_pyqcor, m) {
               }
             }
             return visitor.getMat();
           });
           })
      .def(
          "get_kernel_function_ptr",
          [](qcor::QJIT &qjit, const std::string &kernel_name) {
            return qjit.get_kernel_function_ptr(kernel_name);
          },
          "");

  py::class_<qcor::ObjectiveFunction, std::shared_ptr<qcor::ObjectiveFunction>>(
      m, "ObjectiveFunction", "")
+82 −13
Original line number Diff line number Diff line
@@ -16,6 +16,14 @@ from collections import defaultdict
List = typing.List
Tuple = typing.Tuple
MethodType = types.MethodType
Callable = typing.Callable

# KernelSignature type annotation:
# Usage: annotate an function argument as a KernelSignature by:
# varName: KernelSignature(qreg, ...)
# Kernel always returns void (None)
def KernelSignature(*args):
    return Callable[list(args), None]

# Static cache of all Python QJIT objects that have been created.
# There seems to be a bug when a Python interpreter tried to create a new QJIT
@@ -95,7 +103,7 @@ class KernelGraph(object):
        self.kernel_idx_dep_map = {}
        self.kernel_name_list = []

    def addKernelDependency(self, kernelName, depList):
    def createKernelDependency(self, kernelName, depList):
        self.kernel_name_list.append(kernelName)
        self.kernel_idx_dep_map[self.V] = []
        for dep_ker_name in depList:
@@ -103,6 +111,10 @@ class KernelGraph(object):
                self.kernel_name_list.index(dep_ker_name))
        self.V += 1

    def addKernelDependency(self, kernelName, newDep):
        self.kernel_idx_dep_map[self.kernel_name_list.index(kernelName)].append(
                self.kernel_name_list.index(newDep))

    def addEdge(self, u, v):
        self.graph[u].append(v)

@@ -396,6 +408,23 @@ class qjit(object):
                cpp_arg_str += ',' + \
                    'int& ' + arg
                continue
            if str(_type).startswith('typing.Callable'):
                cpp_type_str = 'KernelSignature<'
                for i in range(len(_type.__args__) - 1):
                    # print("input type:", _type.__args__[i])
                    arg_type = _type.__args__[i]
                    if str(arg_type) not in self.allowed_type_cpp_map:
                        print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                        exit(1)
                    cpp_type_str += self.allowed_type_cpp_map[str(arg_type)]
                    cpp_type_str += ','
                
                cpp_type_str = cpp_type_str[:-1]
                cpp_type_str += '>'
                # print("cpp type", cpp_type_str)
                cpp_arg_str += ',' + cpp_type_str + ' ' + arg
                continue

            if str(_type) not in self.allowed_type_cpp_map:
                print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                exit(1)
@@ -460,7 +489,7 @@ class qjit(object):
            if re.search(r"\b" + re.escape(kernelCall) + '|' + re.escape(kernelAdjCall) + '|' + re.escape(kernelCtrlCall), self.src):
                dependency.append(kernelName)

        self.__kernels__graph.addKernelDependency(
        self.__kernels__graph.createKernelDependency(
            self.function.__name__, dependency)
        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
            self.function.__name__)
@@ -554,9 +583,7 @@ class qjit(object):
        """
        assert len(args) == len(self.arg_names), "Cannot create CompositeInstruction, you did not provided the correct kernel arguments."
        # Create a dictionary for the function arguments
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
            args_dict[arg_name] = list(args)[i]
        args_dict = self.construct_arg_dict(*args)
        return self._qjit.extract_composite(self.function.__name__, args_dict)

    def observe(self, observable, *args):
@@ -590,9 +617,7 @@ class qjit(object):
        return self.extract_composite(*args).nInstructions()
    
    def as_unitary_matrix(self, *args):
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
            args_dict[arg_name] = list(args)[i]
        args_dict = self.construct_arg_dict(*args)
        return self._qjit.internal_as_unitary(self.function.__name__, args_dict)
    
    def ctrl(self, *args):
@@ -622,16 +647,60 @@ class qjit(object):
    def qir(self, *args, **kwargs):
        return llvm_ir(*args, **kwargs)

    def __call__(self, *args):
        """
        Execute the decorated quantum kernel. This will directly 
        invoke the corresponding LLVM JITed function pointer. 
        """
    # Helper to construct the arg_dict (HetMap)
    # e.g. perform any additional type conversion if required.
    def construct_arg_dict(self, *args):
        # Create a dictionary for the function arguments
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
            args_dict[arg_name] = list(args)[i]
            arg_type_str = str(self.type_annotations[arg_name])
            if arg_type_str.startswith('typing.Callable'):
                # print("callable:", arg_name)
                # print("arg:", type(args_dict[arg_name]))
                # the arg must be a qjit
                if not isinstance(args_dict[arg_name], qjit):
                    print('Invalid argument type for {}. A quantum kernel (qjit) is expected.'.format(arg_name))
                    exit(1)
                
                callable_qjit = args_dict[arg_name]
                
                # Handle runtime dependency:
                # The QJIT arg. was not *known* until invocation,
                # hence, we recompile the this jit kernel taking into account 
                # the KernelSignature argument.
                # TODO: perhaps an optimization that we can make is to
                # skip *eager* compilation for those kernels that have 
                # KernelSignature arguments.
                if callable_qjit.kernel_name() not in self.sorted_kernel_dep:
                    # print('New kernel:', callable_qjit.kernel_name())
                    # IMPORTANT: we cannot release a QJIT object till shut-down.
                    QJIT_OBJ_CACHE.append(self._qjit) 
                    # Create a new QJIT
                    self._qjit = QJIT()
                    # Add a kernel dependency
                    self.__kernels__graph.addKernelDependency(self.function.__name__, callable_qjit.kernel_name())
                    self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
                    # Recompile:
                    self._qjit.internal_python_jit_compile(self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers)
                
                # This should always be successful.
                fn_ptr = self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name())
                if fn_ptr == 0:
                    print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name()))
                    exit(1)
                # Replace the argument (in the dict) with the function pointer
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = hex(fn_ptr)
            
        return args_dict

    def __call__(self, *args):
        """
        Execute the decorated quantum kernel. This will directly 
        invoke the corresponding LLVM JITed function pointer. 
        """
        args_dict = self.construct_arg_dict(*args)
        # Invoke the JITed function
        self._qjit.invoke(self.function.__name__, args_dict)

Loading