Commit eaf85ba0 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Handle JIT reverse dependency due to kernel callable



Add a mechanism to inject the dependency. This is suboptimal but worked.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent ca7d0916
Loading
Loading
Loading
Loading
+27 −2
Original line number Diff line number Diff line
@@ -103,7 +103,7 @@ class KernelGraph(object):
        self.kernel_idx_dep_map = {}
        self.kernel_name_list = []

    def addKernelDependency(self, kernelName, depList):
    def createKernelDependency(self, kernelName, depList):
        self.kernel_name_list.append(kernelName)
        self.kernel_idx_dep_map[self.V] = []
        for dep_ker_name in depList:
@@ -111,6 +111,10 @@ class KernelGraph(object):
                self.kernel_name_list.index(dep_ker_name))
        self.V += 1

    def addKernelDependency(self, kernelName, newDep):
        self.kernel_idx_dep_map[self.kernel_name_list.index(kernelName)].append(
                self.kernel_name_list.index(newDep))

    def addEdge(self, u, v):
        self.graph[u].append(v)

@@ -485,7 +489,7 @@ class qjit(object):
            if re.search(r"\b" + re.escape(kernelCall) + '|' + re.escape(kernelAdjCall) + '|' + re.escape(kernelCtrlCall), self.src):
                dependency.append(kernelName)

        self.__kernels__graph.addKernelDependency(
        self.__kernels__graph.createKernelDependency(
            self.function.__name__, dependency)
        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
            self.function.__name__)
@@ -660,6 +664,27 @@ class qjit(object):
                    exit(1)
                
                callable_qjit = args_dict[arg_name]
                
                # Handle runtime dependency:
                # The QJIT arg. was not *known* until invocation,
                # hence, we recompile the this jit kernel taking into account 
                # the KernelSignature argument.
                # TODO: perhaps an optimization that we can make is to
                # skip *eager* compilation for those kernels that have 
                # KernelSignature arguments.
                if callable_qjit.kernel_name() not in self.sorted_kernel_dep:
                    # print('New kernel:', callable_qjit.kernel_name())
                    # IMPORTANT: we cannot release a QJIT object till shut-down.
                    QJIT_OBJ_CACHE.append(self._qjit) 
                    # Create a new QJIT
                    self._qjit = QJIT()
                    # Add a kernel dependency
                    self.__kernels__graph.addKernelDependency(self.function.__name__, callable_qjit.kernel_name())
                    self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
                    # Recompile:
                    self._qjit.internal_python_jit_compile(self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers)
                
                # This should always be successful.
                fn_ptr = self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name())
                if fn_ptr == 0:
                    print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name()))
+24 −16
Original line number Diff line number Diff line
@@ -5,30 +5,38 @@ import unittest
from qcor import *

class TestKernelJIT(unittest.TestCase):
    def test_grover(self):
    def test_kernel_signature(self):
        set_qpu('qpp', {'shots':1024})
        
        @qjit
        def test_kernel(q: qreg, call_var1: KernelSignature(qreg, int, float), call_var2: KernelSignature(qreg, int, float)):
            call_var1(q, 0, 1.0)
            call_var1(q, 1, 2.0)
            call_var2(q, 0, 1.0)
            call_var2(q, 1, 2.0)

        # These kernels are unknown to test_kernel 
        @qjit
        def rx_kernel(q: qreg, idx: int, theta: float):
            Rx(q[idx], theta)

        @qjit
        def test_kernel(q: qreg, call_var: KernelSignature(qreg, int, float)):
            call_var(q, 0, 1.0)
            call_var(q, 1, 2.0)
            # TODO: currently, we don't have the ability to inject
            # new dependency, hence must use rx_kernel here to 
            # pull rx_kernel in.
            rx_kernel(q, 2, 3.0)
        def ry_kernel(q: qreg, idx: int, theta: float):
            Ry(q[idx], theta)

        q = qalloc(3)
        test_kernel(q, rx_kernel)
        comp = test_kernel.extract_composite(q, rx_kernel)
        q = qalloc(2)
        comp = test_kernel.extract_composite(q, rx_kernel, ry_kernel)
        print(comp)
        self.assertEqual(comp.nInstructions(), 3)   
        for i in range(3):
            self.assertEqual(comp.getInstruction(i).name(), "Rx") 
            self.assertAlmostEqual((float)(comp.getInstruction(i).getParameter(0)), i + 1.0)
        self.assertEqual(comp.nInstructions(), 4)   
        counter = 0
        for i in range(2):
            self.assertEqual(comp.getInstruction(counter).name(), "Rx") 
            self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), i + 1.0)
            counter+=1
        for i in range(2):
            self.assertEqual(comp.getInstruction(counter).name(), "Ry") 
            self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), i + 1.0)
            counter+=1

if __name__ == '__main__':
  unittest.main()
 No newline at end of file