Commit 6105e4ae authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

[WIP] work on ctrl / adjoint in python qjit kernels


Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 3c86792b
......@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
antlrcpp::Any visitAtom_expr(
pyxasmParser::Atom_exprContext *context) override {
// Handle kernel::ctrl(...), kernel::adjoint(...)
if (!context->trailer().empty() && context->trailer()[0]->getText() == ".ctrl") {
std::cout << "HELLO: " << context->getText() << "\n";
std::cout << context->trailer()[0]->getText() << "\n";
std::cout << context->atom()->getText() << "\n";
std::cout << context->trailer()[1]->getText() << "\n";
std::cout << context->trailer()[1]->arglist() << "\n";
auto arg_list = context->trailer()[1]->arglist();
std::stringstream ss;
ss << context->atom()->getText() << "::ctrl(parent_kernel";
for (int i = 0; i < arg_list->argument().size(); i++) {
ss << ", " << arg_list->argument(i)->getText();
}
ss << ");\n";
std::cout << "HELLO SS: " << ss.str() << "\n";
result.first = ss.str();
return 0;
}
if (context->atom()->NAME() != nullptr) {
auto inst_name = context->atom()->NAME()->getText();
......
......@@ -5,11 +5,12 @@ import inspect
from typing import List
import typing
import re
from collections import defaultdict
from collections import defaultdict
List = typing.List
PauliOperator = xacc.quantum.PauliOperator
def X(idx):
return xacc.quantum.PauliOperator({idx: 'X'}, 1.0)
......@@ -21,12 +22,12 @@ def Y(idx):
def Z(idx):
return xacc.quantum.PauliOperator({idx: 'Z'}, 1.0)
# Simple graph class to help resolve kernel dependency (via topological sort)
class KernelGraph(object):
def __init__(self):
self.graph = defaultdict(list)
self.V = 0
class KernelGraph(object):
def __init__(self):
self.graph = defaultdict(list)
self.V = 0
self.kernel_idx_dep_map = {}
self.kernel_name_list = []
......@@ -34,42 +35,43 @@ class KernelGraph(object):
self.kernel_name_list.append(kernelName)
self.kernel_idx_dep_map[self.V] = []
for dep_ker_name in depList:
self.kernel_idx_dep_map[self.V].append(self.kernel_name_list.index(dep_ker_name))
self.kernel_idx_dep_map[self.V].append(
self.kernel_name_list.index(dep_ker_name))
self.V += 1
def addEdge(self, u, v):
self.graph[u].append(v)
# Topological Sort.
def topologicalSort(self):
self.graph = defaultdict(list)
def addEdge(self, u, v):
self.graph[u].append(v)
# Topological Sort.
def topologicalSort(self):
self.graph = defaultdict(list)
for sub_ker_idx in self.kernel_idx_dep_map:
for dep_sub_idx in self.kernel_idx_dep_map[sub_ker_idx]:
self.addEdge(dep_sub_idx, sub_ker_idx)
in_degree = [0]*(self.V)
for i in self.graph:
for j in self.graph[i]:
self.addEdge(dep_sub_idx, sub_ker_idx)
in_degree = [0]*(self.V)
for i in self.graph:
for j in self.graph[i]:
in_degree[j] += 1
queue = []
for i in range(self.V):
if in_degree[i] == 0:
queue.append(i)
queue = []
for i in range(self.V):
if in_degree[i] == 0:
queue.append(i)
cnt = 0
top_order = []
while queue:
u = queue.pop(0)
top_order.append(u)
for i in self.graph[u]:
top_order = []
while queue:
u = queue.pop(0)
top_order.append(u)
for i in self.graph[u]:
in_degree[i] -= 1
if in_degree[i] == 0:
queue.append(i)
if in_degree[i] == 0:
queue.append(i)
cnt += 1
sortedDep = []
for sorted_dep_idx in top_order:
sortedDep.append(self.kernel_name_list[sorted_dep_idx])
sortedDep.append(self.kernel_name_list[sorted_dep_idx])
return sortedDep
def getSortedDependency(self, kernelName):
......@@ -77,7 +79,7 @@ class KernelGraph(object):
# No dependency
if len(self.kernel_idx_dep_map[kernel_idx]) == 0:
return []
sorted_dep = self.topologicalSort()
result_dep = []
for dep_name in sorted_dep:
......@@ -86,6 +88,7 @@ class KernelGraph(object):
else:
result_dep.append(dep_name)
class qjit(object):
"""
The qjit class serves a python function decorator that enables
......@@ -126,8 +129,8 @@ class qjit(object):
self.kwargs = kwargs
self.function = function
self.allowed_type_cpp_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
'<class \'float\'>': 'double', 'typing.List[float]': 'std::vector<double>',
'<class \'int\'>': 'int',
'<class \'float\'>': 'double', 'typing.List[float]': 'std::vector<double>',
'<class \'int\'>': 'int',
'<class \'_pyxacc.quantum.PauliOperator\'>': 'qcor::PauliOperator'}
self.__dict__.update(kwargs)
......@@ -173,25 +176,27 @@ class qjit(object):
# Only support float atm
if (isinstance(globalVars[key], float)):
globalVarDecl.append(key + " = " + str(globalVars[key]))
# Inject these global declarations into the function body.
separator = "\n"
globalDeclStr = separator.join(globalVarDecl)
# Handle common modules like numpy or math
# e.g. if seeing `import numpy as np`, we'll have <'np' -> 'numpy'> in the importedModules dict.
# e.g. if seeing `import numpy as np`, we'll have <'np' -> 'numpy'> in the importedModules dict.
# We'll replace any module alias by its original name,
# i.e. 'np.pi' -> 'numpy.pi', etc.
for moduleAlias in importedModules:
if moduleAlias != importedModules[moduleAlias]:
aliasModuleStr = moduleAlias + '.'
originalModuleStr = importedModules[moduleAlias] + '.'
fbody_src = fbody_src.replace(aliasModuleStr, originalModuleStr)
fbody_src = fbody_src.replace(
aliasModuleStr, originalModuleStr)
# Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
self.src = '__qpu__ void '+self.function.__name__ + \
'('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + globalDeclStr + '\n' + fbody_src +"}\n"
'('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + \
globalDeclStr + '\n' + fbody_src + "}\n"
# Handle nested kernels:
dependency = []
for kernelName in self.__compiled__kernels:
......@@ -200,10 +205,13 @@ class qjit(object):
# pattern: "<white space> kernel("
if re.search(r"\b" + re.escape(kernelCall), self.src):
dependency.append(kernelName)
self.__kernels__graph.addKernelDependency(self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
self.__kernels__graph.addKernelDependency(
self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
self.function.__name__)
# Run the QJIT compile step to store function pointers internally
self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep)
self._qjit.write_cache()
......@@ -212,8 +220,8 @@ class qjit(object):
# Static list of all kernels compiled
__compiled__kernels = []
__kernels__graph = KernelGraph()
__kernels__graph = KernelGraph()
def get_internal_src(self):
"""Return the C++ / embedded python DSL function code that will be passed to QJIT
and the clang syntax handler. This function is primarily to be used for developer purposes. """
......@@ -293,16 +301,16 @@ class qjit(object):
Print the QJIT kernel as a QASM-like string
"""
print(self.extract_composite(*args).toString())
def n_instructions(self, *args):
"""
Return the number of quantum instructions in this kernel.
"""
return self.extract_composite(*args).nInstructions()
# def ctrl(self, *args):
def ctrl(self, *args):
print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
exit(1)
def __call__(self, *args):
"""
......
......@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase):
self.assertEqual(comp.getInstruction(i).name(), "Measure")
def test_iqft_kernel(self):
import numpy as np
@qjit
def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2):
......@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase):
self.assertEqual(comp.getInstruction(i).name(), "CPhase")
self.assertEqual(comp.getInstruction(16).name(), "H")
# def test_ctrl_kernel(self):
# @qjit
# def qft(q : qreg, startIdx : int, nbQubits : int): # with swap
# for i in range(nbQubits - 1, -1, -1):
# shiftedBitIdx = i + startIdx
# H(q[shiftedBitIdx])
# for j in range(i-1, -1, -1):
# theta = np.pi / 2**(i-j)
# tIdx = j + i
# CPhase(q[shiftedBitIdx], q[tIdx], theta)
# swapCount = 0 if shouldSwap == 0 else 1
# for i in range(nbQubits/2):
# Swap(q[startIdx+i], q[startIdx+nbQubits-i-1])
def test_ctrl_kernel(self):
# @qjit
# def iqft(q : qreg, startIdx : int, nbQubits : int):
# for i in range(nbQubits/2):
# Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
set_qpu('qpp', {'shots':1024})
@qjit
def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2):
Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
# for i in range(nbQubits-1):
# H(q[startIdx+i])
# j = i +1
# for y in range(i, -1, -1):
# theta = -np.pi / 2**(j-y)
# CPhase(q[startIdx+j], q[startIdx + y], theta)
for i in range(nbQubits-1):
H(q[startIdx+i])
j = i +1
for y in range(i, -1, -1):
theta = -MY_PI / 2**(j-y)
CPhase(q[startIdx+j], q[startIdx + y], theta)
# H(q[startIdx+nbQubits-1])
H(q[startIdx+nbQubits-1])
# @qjit
# def oracle(q : qreg):
# bit = q.size()-1
# T(q[bit])
@qjit
def oracle(q : qreg):
bit = q.size()-1
T(q[bit])
# def qpe(q : qreg):
# nq = q.size()
@qjit
def qpe(q : qreg):
nq = q.size()
# for i in range(q.size()-1):
# H(q[i])
for i in range(q.size()-1):
H(q[i])
# bitPrecision = nq-1
# for i in range(bitPrecision):
# nbCalls = 1 << i
# for j in range(nbCalls):
# ctrl_bit = i
# oracle.ctrl(ctrl_bit, q)
bitPrecision = nq-1
for i in range(bitPrecision):
nbCalls = 1 << i
for j in range(nbCalls):
ctrl_bit = i
oracle.ctrl(ctrl_bit, q)
# iqft(q, 0, bitPrecision)
# for i in range(bitPrecision):
# Measure(q[i])
for i in range(bitPrecision):
Measure(q[i])
# q = qalloc(4)
# qpe(q)
# print(q.counts())
q = qalloc(4)
qpe(q)
print(q.counts())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment