Commit 5c4651f2 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

First pass to support nested JIT kernels



Kernel dependency is injected to the generated source before JIT compilation.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 957d90b7
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -117,6 +117,26 @@ public:
                        "' is not currently supported.");
          }
        }
      } else {
        // This *callable* is not an intrinsic instruction, just reassemble the
        // call:
        // TODO: validate that this is a *previously-defined* kernel (i.e. not a
        // classical function call)
        if (!context->trailer().empty()) {
          std::stringstream ss;
          ss << inst_name << "(parent_kernel, ";
          // TODO: We potentially need to handle *inline* expressions in the
          // function call.
          const auto &argList = context->trailer()[0]->arglist()->argument();
          for (size_t i = 0; i < argList.size(); ++i) {
            ss << argList[i]->getText();
            if (i != argList.size() - 1) {
              ss << ", ";
            }
          }
          ss << ");\n";
          result.first = ss.str();
        }
      }
    }
    return 0;
+37 −0
Original line number Diff line number Diff line
# Run this from the command line like this
#
# python3 multiple_kernels.py -shots 100

from qcor import qjit, qalloc, qreg

# To create QCOR quantum kernels in Python one 
# simply creates a Python function, writes Pythonic, 
# XASM-like quantum code, and annotates the kernel 
# to indicate it is meant for QCOR just in time compilation

# NOTE Programmers must type annotate their function arguments

@qjit
def measure_all_qubits(q : qreg):
    for i in range(q.size()):
        Measure(q[i])

# Define a Bell kernel
@qjit
def bell_test(q : qreg):
    H(q[0])
    CX(q[0], q[1])
    # Call other kernels
    measure_all_qubits(q)

# Allocate 2 qubits
q = qalloc(2)

# Inspect the IR
comp = bell_test.extract_composite(q)
print(comp.toString())

# Run the bell experiment
bell_test(q)
# Print the results
q.print()
 No newline at end of file
+3 −2
Original line number Diff line number Diff line
@@ -287,9 +287,10 @@ PYBIND11_MODULE(_pyqcor, m) {
      .def("jit_compile", &qcor::QJIT::jit_compile, "")
      .def(
          "internal_python_jit_compile",
          [](qcor::QJIT &qjit, const std::string src) {
          [](qcor::QJIT &qjit, const std::string src,
             const std::vector<std::string> &dependency = {}) {
            bool turn_on_hetmap_kernel_ctor = true;
            qjit.jit_compile(src, turn_on_hetmap_kernel_ctor);
            qjit.jit_compile(src, turn_on_hetmap_kernel_ctor, dependency);
          },
          "")
      .def("run_syntax_handler", &qcor::QJIT::run_syntax_handler, "")
+15 −2
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import sys
import inspect
from typing import List
import typing
import re

List = typing.List

@@ -124,12 +125,24 @@ class qjit(object):
        self.src = '__qpu__ void '+self.function.__name__ + \
            '('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + globalDeclStr + '\n' + fbody_src +"}\n"
        
        # Handle nested kernels:
        dependency = []
        for kernelName in self.__compiled__kernels:
            kernelCall = kernelName + '('
            # Check that this kernel *calls* a previously-compiled kernel:
            # pattern: "<white space> kernel("
            if re.search(r"\b" + re.escape(kernelCall), self.src):
                dependency.append(kernelName)
        
        # Run the QJIT compile step to store function pointers internally
        self._qjit.internal_python_jit_compile(self.src)
        self._qjit.internal_python_jit_compile(self.src, dependency)
        self._qjit.write_cache()

        self.__compiled__kernels.append(self.function.__name__)
        return

    # Static list of all kernels compiled
    __compiled__kernels = []

    def get_internal_src(self):
        """Return the C++ / embedded python DSL function code that will be passed to QJIT
        and the clang syntax handler. This function is primarily to be used for developer purposes. """
+2 −1
Original line number Diff line number Diff line
@@ -45,7 +45,8 @@ class QJIT {
      const std::string &quantum_kernel_src,
      const bool add_het_map_kernel_ctor = false);
  void jit_compile(const std::string &quantum_kernel_src,
                   const bool add_het_map_kernel_ctor = false);
                   const bool add_het_map_kernel_ctor = false,
                   const std::vector<std::string> &kernel_dependency = {});

  void write_cache();
  
Loading