Commit b934adc2 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Handle more complex kernel nesting



Flatten and sort the dependency + add a test.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent ae1a6c2b
Loading
Loading
Loading
Loading
+71 −2
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import inspect
from typing import List
import typing
import re
from collections import defaultdict 

List = typing.List

@@ -21,6 +22,70 @@ def Z(idx):
    return xacc.quantum.PauliOperator({idx: 'Z'}, 1.0)

  
# Simple graph class to help resolve kernel dependency (via topological sort)
class KernelGraph(object): 
    def __init__(self): 
        self.graph = defaultdict(list) 
        self.V = 0   
        self.kernel_idx_dep_map = {}
        self.kernel_name_list = []

    def addKernelDependency(self, kernelName, depList):
        self.kernel_name_list.append(kernelName)
        self.kernel_idx_dep_map[self.V] = []
        for dep_ker_name in depList:
            self.kernel_idx_dep_map[self.V].append(self.kernel_name_list.index(dep_ker_name))
        self.V += 1
    
    def addEdge(self, u, v): 
        self.graph[u].append(v) 
  
    # Topological Sort.  
    def topologicalSort(self): 
        self.graph = defaultdict(list) 
        for sub_ker_idx in self.kernel_idx_dep_map:
            for dep_sub_idx in self.kernel_idx_dep_map[sub_ker_idx]:
               self.addEdge(dep_sub_idx, sub_ker_idx)
        
        in_degree = [0]*(self.V) 
        for i in self.graph: 
            for j in self.graph[i]: 
                in_degree[j] += 1
        
        queue = [] 
        for i in range(self.V): 
            if in_degree[i] == 0: 
                queue.append(i)   
        cnt = 0
        top_order = [] 
        while queue: 
            u = queue.pop(0) 
            top_order.append(u) 
            for i in self.graph[u]: 
                in_degree[i] -= 1
                if in_degree[i] == 0: 
                    queue.append(i) 
            cnt += 1
        
        sortedDep = []
        for sorted_dep_idx in top_order:
            sortedDep.append(self.kernel_name_list[sorted_dep_idx]) 
        return sortedDep

    def getSortedDependency(self, kernelName):
        kernel_idx = self.kernel_name_list.index(kernelName)
        # No dependency
        if len(self.kernel_idx_dep_map[kernel_idx]) == 0:
            return []
        
        sorted_dep = self.topologicalSort()
        result_dep = []
        for dep_name in sorted_dep:
            if dep_name == kernelName:
                return result_dep
            else:
                result_dep.append(dep_name)

class qjit(object):
    """
    The qjit class serves a python function decorator that enables 
@@ -134,14 +199,18 @@ class qjit(object):
            if re.search(r"\b" + re.escape(kernelCall), self.src):
                dependency.append(kernelName)
        
        self.__kernels__graph.addKernelDependency(self.function.__name__, dependency)
        sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
        
        # Run the QJIT compile step to store function pointers internally
        self._qjit.internal_python_jit_compile(self.src, dependency)
        self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep)
        self._qjit.write_cache()
        self.__compiled__kernels.append(self.function.__name__)
        return

    # Static list of all kernels compiled
    __compiled__kernels = []
    __kernels__graph = KernelGraph() 
    
    def get_internal_src(self):
        """Return the C++ / embedded python DSL function code that will be passed to QJIT
+33 −0
Original line number Diff line number Diff line
@@ -165,5 +165,38 @@ class TestSimpleKernelJIT(unittest.TestCase):
        for i in range(10, 15):
            self.assertEqual(comp.getInstruction(i).name(), "Measure") 

    # Make sure that multi-level dependency can be resolved.
    def test_nested_kernels(self):
        @qjit
        def apply_cnot_fwd(q : qreg):
            for i in range(q.size() - 1):
                CX(q[i], q[i + 1])
        
        @qjit
        def make_bell(q : qreg):
            H(q[0])
            apply_cnot_fwd(q)

        @qjit
        def measure_all_bits(q : qreg):
            for i in range(q.size()):
                Measure(q[i])

        @qjit
        def bell_expr(q : qreg):
           # dep: apply_cnot_fwd -> make_bell -> bell_expr
           make_bell(q)
           measure_all_bits(q) 
        
        q = qalloc(5)
        comp = bell_expr.extract_composite(q)
        # 1 H, 4 CNOT, 5 Measure
        self.assertEqual(comp.nInstructions(), 1 + 4 + 5)   
        self.assertEqual(comp.getInstruction(0).name(), "H") 
        for i in range(1, 5):
            self.assertEqual(comp.getInstruction(i).name(), "CNOT") 
        for i in range(5, 10):
            self.assertEqual(comp.getInstruction(i).name(), "Measure") 

if __name__ == '__main__':
  unittest.main()
 No newline at end of file