Commit b934adc2 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Handle more complex kernel nesting



Flatten and sort the dependency + add a test.

Signed-off-by: Nguyen, Thien Minh's avatarThien Nguyen <nguyentm@ornl.gov>
parent ae1a6c2b
......@@ -5,6 +5,7 @@ import inspect
from typing import List
import typing
import re
from collections import defaultdict
List = typing.List
......@@ -20,6 +21,70 @@ def Y(idx):
def Z(idx):
return xacc.quantum.PauliOperator({idx: 'Z'}, 1.0)
# Simple graph class to help resolve kernel dependency (via topological sort)
class KernelGraph(object):
def __init__(self):
self.graph = defaultdict(list)
self.V = 0
self.kernel_idx_dep_map = {}
self.kernel_name_list = []
def addKernelDependency(self, kernelName, depList):
self.kernel_name_list.append(kernelName)
self.kernel_idx_dep_map[self.V] = []
for dep_ker_name in depList:
self.kernel_idx_dep_map[self.V].append(self.kernel_name_list.index(dep_ker_name))
self.V += 1
def addEdge(self, u, v):
self.graph[u].append(v)
# Topological Sort.
def topologicalSort(self):
self.graph = defaultdict(list)
for sub_ker_idx in self.kernel_idx_dep_map:
for dep_sub_idx in self.kernel_idx_dep_map[sub_ker_idx]:
self.addEdge(dep_sub_idx, sub_ker_idx)
in_degree = [0]*(self.V)
for i in self.graph:
for j in self.graph[i]:
in_degree[j] += 1
queue = []
for i in range(self.V):
if in_degree[i] == 0:
queue.append(i)
cnt = 0
top_order = []
while queue:
u = queue.pop(0)
top_order.append(u)
for i in self.graph[u]:
in_degree[i] -= 1
if in_degree[i] == 0:
queue.append(i)
cnt += 1
sortedDep = []
for sorted_dep_idx in top_order:
sortedDep.append(self.kernel_name_list[sorted_dep_idx])
return sortedDep
def getSortedDependency(self, kernelName):
kernel_idx = self.kernel_name_list.index(kernelName)
# No dependency
if len(self.kernel_idx_dep_map[kernel_idx]) == 0:
return []
sorted_dep = self.topologicalSort()
result_dep = []
for dep_name in sorted_dep:
if dep_name == kernelName:
return result_dep
else:
result_dep.append(dep_name)
class qjit(object):
"""
......@@ -134,15 +199,19 @@ class qjit(object):
if re.search(r"\b" + re.escape(kernelCall), self.src):
dependency.append(kernelName)
self.__kernels__graph.addKernelDependency(self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
# Run the QJIT compile step to store function pointers internally
self._qjit.internal_python_jit_compile(self.src, dependency)
self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep)
self._qjit.write_cache()
self.__compiled__kernels.append(self.function.__name__)
return
# Static list of all kernels compiled
__compiled__kernels = []
__kernels__graph = KernelGraph()
def get_internal_src(self):
"""Return the C++ / embedded python DSL function code that will be passed to QJIT
and the clang syntax handler. This function is primarily to be used for developer purposes. """
......
......@@ -165,5 +165,38 @@ class TestSimpleKernelJIT(unittest.TestCase):
for i in range(10, 15):
self.assertEqual(comp.getInstruction(i).name(), "Measure")
# Make sure that multi-level dependency can be resolved.
def test_nested_kernels(self):
@qjit
def apply_cnot_fwd(q : qreg):
for i in range(q.size() - 1):
CX(q[i], q[i + 1])
@qjit
def make_bell(q : qreg):
H(q[0])
apply_cnot_fwd(q)
@qjit
def measure_all_bits(q : qreg):
for i in range(q.size()):
Measure(q[i])
@qjit
def bell_expr(q : qreg):
# dep: apply_cnot_fwd -> make_bell -> bell_expr
make_bell(q)
measure_all_bits(q)
q = qalloc(5)
comp = bell_expr.extract_composite(q)
# 1 H, 4 CNOT, 5 Measure
self.assertEqual(comp.nInstructions(), 1 + 4 + 5)
self.assertEqual(comp.getInstruction(0).name(), "H")
for i in range(1, 5):
self.assertEqual(comp.getInstruction(i).name(), "CNOT")
for i in range(5, 10):
self.assertEqual(comp.getInstruction(i).name(), "Measure")
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment