Commit 6105e4ae authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

[WIP] work on ctrl / adjoint in python qjit kernels


Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 3c86792b
...@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ...@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
antlrcpp::Any visitAtom_expr( antlrcpp::Any visitAtom_expr(
pyxasmParser::Atom_exprContext *context) override { pyxasmParser::Atom_exprContext *context) override {
// Handle kernel::ctrl(...), kernel::adjoint(...)
if (!context->trailer().empty() && context->trailer()[0]->getText() == ".ctrl") {
std::cout << "HELLO: " << context->getText() << "\n";
std::cout << context->trailer()[0]->getText() << "\n";
std::cout << context->atom()->getText() << "\n";
std::cout << context->trailer()[1]->getText() << "\n";
std::cout << context->trailer()[1]->arglist() << "\n";
auto arg_list = context->trailer()[1]->arglist();
std::stringstream ss;
ss << context->atom()->getText() << "::ctrl(parent_kernel";
for (int i = 0; i < arg_list->argument().size(); i++) {
ss << ", " << arg_list->argument(i)->getText();
}
ss << ");\n";
std::cout << "HELLO SS: " << ss.str() << "\n";
result.first = ss.str();
return 0;
}
if (context->atom()->NAME() != nullptr) { if (context->atom()->NAME() != nullptr) {
auto inst_name = context->atom()->NAME()->getText(); auto inst_name = context->atom()->NAME()->getText();
......
...@@ -5,11 +5,12 @@ import inspect ...@@ -5,11 +5,12 @@ import inspect
from typing import List from typing import List
import typing import typing
import re import re
from collections import defaultdict from collections import defaultdict
List = typing.List List = typing.List
PauliOperator = xacc.quantum.PauliOperator PauliOperator = xacc.quantum.PauliOperator
def X(idx): def X(idx):
return xacc.quantum.PauliOperator({idx: 'X'}, 1.0) return xacc.quantum.PauliOperator({idx: 'X'}, 1.0)
...@@ -21,12 +22,12 @@ def Y(idx): ...@@ -21,12 +22,12 @@ def Y(idx):
def Z(idx): def Z(idx):
return xacc.quantum.PauliOperator({idx: 'Z'}, 1.0) return xacc.quantum.PauliOperator({idx: 'Z'}, 1.0)
# Simple graph class to help resolve kernel dependency (via topological sort) # Simple graph class to help resolve kernel dependency (via topological sort)
class KernelGraph(object): class KernelGraph(object):
def __init__(self): def __init__(self):
self.graph = defaultdict(list) self.graph = defaultdict(list)
self.V = 0 self.V = 0
self.kernel_idx_dep_map = {} self.kernel_idx_dep_map = {}
self.kernel_name_list = [] self.kernel_name_list = []
...@@ -34,42 +35,43 @@ class KernelGraph(object): ...@@ -34,42 +35,43 @@ class KernelGraph(object):
self.kernel_name_list.append(kernelName) self.kernel_name_list.append(kernelName)
self.kernel_idx_dep_map[self.V] = [] self.kernel_idx_dep_map[self.V] = []
for dep_ker_name in depList: for dep_ker_name in depList:
self.kernel_idx_dep_map[self.V].append(self.kernel_name_list.index(dep_ker_name)) self.kernel_idx_dep_map[self.V].append(
self.kernel_name_list.index(dep_ker_name))
self.V += 1 self.V += 1
def addEdge(self, u, v): def addEdge(self, u, v):
self.graph[u].append(v) self.graph[u].append(v)
# Topological Sort. # Topological Sort.
def topologicalSort(self): def topologicalSort(self):
self.graph = defaultdict(list) self.graph = defaultdict(list)
for sub_ker_idx in self.kernel_idx_dep_map: for sub_ker_idx in self.kernel_idx_dep_map:
for dep_sub_idx in self.kernel_idx_dep_map[sub_ker_idx]: for dep_sub_idx in self.kernel_idx_dep_map[sub_ker_idx]:
self.addEdge(dep_sub_idx, sub_ker_idx) self.addEdge(dep_sub_idx, sub_ker_idx)
in_degree = [0]*(self.V) in_degree = [0]*(self.V)
for i in self.graph: for i in self.graph:
for j in self.graph[i]: for j in self.graph[i]:
in_degree[j] += 1 in_degree[j] += 1
queue = [] queue = []
for i in range(self.V): for i in range(self.V):
if in_degree[i] == 0: if in_degree[i] == 0:
queue.append(i) queue.append(i)
cnt = 0 cnt = 0
top_order = [] top_order = []
while queue: while queue:
u = queue.pop(0) u = queue.pop(0)
top_order.append(u) top_order.append(u)
for i in self.graph[u]: for i in self.graph[u]:
in_degree[i] -= 1 in_degree[i] -= 1
if in_degree[i] == 0: if in_degree[i] == 0:
queue.append(i) queue.append(i)
cnt += 1 cnt += 1
sortedDep = [] sortedDep = []
for sorted_dep_idx in top_order: for sorted_dep_idx in top_order:
sortedDep.append(self.kernel_name_list[sorted_dep_idx]) sortedDep.append(self.kernel_name_list[sorted_dep_idx])
return sortedDep return sortedDep
def getSortedDependency(self, kernelName): def getSortedDependency(self, kernelName):
...@@ -77,7 +79,7 @@ class KernelGraph(object): ...@@ -77,7 +79,7 @@ class KernelGraph(object):
# No dependency # No dependency
if len(self.kernel_idx_dep_map[kernel_idx]) == 0: if len(self.kernel_idx_dep_map[kernel_idx]) == 0:
return [] return []
sorted_dep = self.topologicalSort() sorted_dep = self.topologicalSort()
result_dep = [] result_dep = []
for dep_name in sorted_dep: for dep_name in sorted_dep:
...@@ -86,6 +88,7 @@ class KernelGraph(object): ...@@ -86,6 +88,7 @@ class KernelGraph(object):
else: else:
result_dep.append(dep_name) result_dep.append(dep_name)
class qjit(object): class qjit(object):
""" """
The qjit class serves a python function decorator that enables The qjit class serves a python function decorator that enables
...@@ -126,8 +129,8 @@ class qjit(object): ...@@ -126,8 +129,8 @@ class qjit(object):
self.kwargs = kwargs self.kwargs = kwargs
self.function = function self.function = function
self.allowed_type_cpp_map = {'<class \'_pyqcor.qreg\'>': 'qreg', self.allowed_type_cpp_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
'<class \'float\'>': 'double', 'typing.List[float]': 'std::vector<double>', '<class \'float\'>': 'double', 'typing.List[float]': 'std::vector<double>',
'<class \'int\'>': 'int', '<class \'int\'>': 'int',
'<class \'_pyxacc.quantum.PauliOperator\'>': 'qcor::PauliOperator'} '<class \'_pyxacc.quantum.PauliOperator\'>': 'qcor::PauliOperator'}
self.__dict__.update(kwargs) self.__dict__.update(kwargs)
...@@ -173,25 +176,27 @@ class qjit(object): ...@@ -173,25 +176,27 @@ class qjit(object):
# Only support float atm # Only support float atm
if (isinstance(globalVars[key], float)): if (isinstance(globalVars[key], float)):
globalVarDecl.append(key + " = " + str(globalVars[key])) globalVarDecl.append(key + " = " + str(globalVars[key]))
# Inject these global declarations into the function body. # Inject these global declarations into the function body.
separator = "\n" separator = "\n"
globalDeclStr = separator.join(globalVarDecl) globalDeclStr = separator.join(globalVarDecl)
# Handle common modules like numpy or math # Handle common modules like numpy or math
# e.g. if seeing `import numpy as np`, we'll have <'np' -> 'numpy'> in the importedModules dict. # e.g. if seeing `import numpy as np`, we'll have <'np' -> 'numpy'> in the importedModules dict.
# We'll replace any module alias by its original name, # We'll replace any module alias by its original name,
# i.e. 'np.pi' -> 'numpy.pi', etc. # i.e. 'np.pi' -> 'numpy.pi', etc.
for moduleAlias in importedModules: for moduleAlias in importedModules:
if moduleAlias != importedModules[moduleAlias]: if moduleAlias != importedModules[moduleAlias]:
aliasModuleStr = moduleAlias + '.' aliasModuleStr = moduleAlias + '.'
originalModuleStr = importedModules[moduleAlias] + '.' originalModuleStr = importedModules[moduleAlias] + '.'
fbody_src = fbody_src.replace(aliasModuleStr, originalModuleStr) fbody_src = fbody_src.replace(
aliasModuleStr, originalModuleStr)
# Create the qcor quantum kernel function src for QJIT and the Clang syntax handler # Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
self.src = '__qpu__ void '+self.function.__name__ + \ self.src = '__qpu__ void '+self.function.__name__ + \
'('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + globalDeclStr + '\n' + fbody_src +"}\n" '('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + \
globalDeclStr + '\n' + fbody_src + "}\n"
# Handle nested kernels: # Handle nested kernels:
dependency = [] dependency = []
for kernelName in self.__compiled__kernels: for kernelName in self.__compiled__kernels:
...@@ -200,10 +205,13 @@ class qjit(object): ...@@ -200,10 +205,13 @@ class qjit(object):
# pattern: "<white space> kernel(" # pattern: "<white space> kernel("
if re.search(r"\b" + re.escape(kernelCall), self.src): if re.search(r"\b" + re.escape(kernelCall), self.src):
dependency.append(kernelName) dependency.append(kernelName)
self.__kernels__graph.addKernelDependency(self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__) self.__kernels__graph.addKernelDependency(
self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
self.function.__name__)
# Run the QJIT compile step to store function pointers internally # Run the QJIT compile step to store function pointers internally
self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep) self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep)
self._qjit.write_cache() self._qjit.write_cache()
...@@ -212,8 +220,8 @@ class qjit(object): ...@@ -212,8 +220,8 @@ class qjit(object):
# Static list of all kernels compiled # Static list of all kernels compiled
__compiled__kernels = [] __compiled__kernels = []
__kernels__graph = KernelGraph() __kernels__graph = KernelGraph()
def get_internal_src(self): def get_internal_src(self):
"""Return the C++ / embedded python DSL function code that will be passed to QJIT """Return the C++ / embedded python DSL function code that will be passed to QJIT
and the clang syntax handler. This function is primarily to be used for developer purposes. """ and the clang syntax handler. This function is primarily to be used for developer purposes. """
...@@ -293,16 +301,16 @@ class qjit(object): ...@@ -293,16 +301,16 @@ class qjit(object):
Print the QJIT kernel as a QASM-like string Print the QJIT kernel as a QASM-like string
""" """
print(self.extract_composite(*args).toString()) print(self.extract_composite(*args).toString())
def n_instructions(self, *args): def n_instructions(self, *args):
""" """
Return the number of quantum instructions in this kernel. Return the number of quantum instructions in this kernel.
""" """
return self.extract_composite(*args).nInstructions() return self.extract_composite(*args).nInstructions()
# def ctrl(self, *args): def ctrl(self, *args):
print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
exit(1)
def __call__(self, *args): def __call__(self, *args):
""" """
......
...@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase): ...@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase):
self.assertEqual(comp.getInstruction(i).name(), "Measure") self.assertEqual(comp.getInstruction(i).name(), "Measure")
def test_iqft_kernel(self): def test_iqft_kernel(self):
import numpy as np
@qjit @qjit
def iqft(q : qreg, startIdx : int, nbQubits : int): def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2): for i in range(nbQubits/2):
...@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase): ...@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase):
self.assertEqual(comp.getInstruction(i).name(), "CPhase") self.assertEqual(comp.getInstruction(i).name(), "CPhase")
self.assertEqual(comp.getInstruction(16).name(), "H") self.assertEqual(comp.getInstruction(16).name(), "H")
# def test_ctrl_kernel(self): def test_ctrl_kernel(self):
# @qjit
# def qft(q : qreg, startIdx : int, nbQubits : int): # with swap
# for i in range(nbQubits - 1, -1, -1):
# shiftedBitIdx = i + startIdx
# H(q[shiftedBitIdx])
# for j in range(i-1, -1, -1):
# theta = np.pi / 2**(i-j)
# tIdx = j + i
# CPhase(q[shiftedBitIdx], q[tIdx], theta)
# swapCount = 0 if shouldSwap == 0 else 1
# for i in range(nbQubits/2):
# Swap(q[startIdx+i], q[startIdx+nbQubits-i-1])
# @qjit set_qpu('qpp', {'shots':1024})
# def iqft(q : qreg, startIdx : int, nbQubits : int):
# for i in range(nbQubits/2): @qjit
# Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1]) def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2):
Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
# for i in range(nbQubits-1): for i in range(nbQubits-1):
# H(q[startIdx+i]) H(q[startIdx+i])
# j = i +1 j = i +1
# for y in range(i, -1, -1): for y in range(i, -1, -1):
# theta = -np.pi / 2**(j-y) theta = -MY_PI / 2**(j-y)
# CPhase(q[startIdx+j], q[startIdx + y], theta) CPhase(q[startIdx+j], q[startIdx + y], theta)
# H(q[startIdx+nbQubits-1]) H(q[startIdx+nbQubits-1])
# @qjit @qjit
# def oracle(q : qreg): def oracle(q : qreg):
# bit = q.size()-1 bit = q.size()-1
# T(q[bit]) T(q[bit])
# def qpe(q : qreg): @qjit
# nq = q.size() def qpe(q : qreg):
nq = q.size()
# for i in range(q.size()-1): for i in range(q.size()-1):
# H(q[i]) H(q[i])
# bitPrecision = nq-1 bitPrecision = nq-1
# for i in range(bitPrecision): for i in range(bitPrecision):
# nbCalls = 1 << i nbCalls = 1 << i
# for j in range(nbCalls): for j in range(nbCalls):
# ctrl_bit = i ctrl_bit = i
# oracle.ctrl(ctrl_bit, q) oracle.ctrl(ctrl_bit, q)
# iqft(q, 0, bitPrecision) for i in range(bitPrecision):
# for i in range(bitPrecision): Measure(q[i])
# Measure(q[i])
# q = qalloc(4) q = qalloc(4)
# qpe(q) qpe(q)
# print(q.counts()) print(q.counts())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment