Commit 6105e4ae authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

[WIP] work on ctrl / adjoint in python qjit kernels



Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 3c86792b
...@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ...@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
antlrcpp::Any visitAtom_expr( antlrcpp::Any visitAtom_expr(
pyxasmParser::Atom_exprContext *context) override { pyxasmParser::Atom_exprContext *context) override {
// Handle kernel::ctrl(...), kernel::adjoint(...)
if (!context->trailer().empty() && context->trailer()[0]->getText() == ".ctrl") {
std::cout << "HELLO: " << context->getText() << "\n";
std::cout << context->trailer()[0]->getText() << "\n";
std::cout << context->atom()->getText() << "\n";
std::cout << context->trailer()[1]->getText() << "\n";
std::cout << context->trailer()[1]->arglist() << "\n";
auto arg_list = context->trailer()[1]->arglist();
std::stringstream ss;
ss << context->atom()->getText() << "::ctrl(parent_kernel";
for (int i = 0; i < arg_list->argument().size(); i++) {
ss << ", " << arg_list->argument(i)->getText();
}
ss << ");\n";
std::cout << "HELLO SS: " << ss.str() << "\n";
result.first = ss.str();
return 0;
}
if (context->atom()->NAME() != nullptr) { if (context->atom()->NAME() != nullptr) {
auto inst_name = context->atom()->NAME()->getText(); auto inst_name = context->atom()->NAME()->getText();
......
...@@ -10,6 +10,7 @@ from collections import defaultdict ...@@ -10,6 +10,7 @@ from collections import defaultdict
List = typing.List List = typing.List
PauliOperator = xacc.quantum.PauliOperator PauliOperator = xacc.quantum.PauliOperator
def X(idx): def X(idx):
return xacc.quantum.PauliOperator({idx: 'X'}, 1.0) return xacc.quantum.PauliOperator({idx: 'X'}, 1.0)
...@@ -34,7 +35,8 @@ class KernelGraph(object): ...@@ -34,7 +35,8 @@ class KernelGraph(object):
self.kernel_name_list.append(kernelName) self.kernel_name_list.append(kernelName)
self.kernel_idx_dep_map[self.V] = [] self.kernel_idx_dep_map[self.V] = []
for dep_ker_name in depList: for dep_ker_name in depList:
self.kernel_idx_dep_map[self.V].append(self.kernel_name_list.index(dep_ker_name)) self.kernel_idx_dep_map[self.V].append(
self.kernel_name_list.index(dep_ker_name))
self.V += 1 self.V += 1
def addEdge(self, u, v): def addEdge(self, u, v):
...@@ -86,6 +88,7 @@ class KernelGraph(object): ...@@ -86,6 +88,7 @@ class KernelGraph(object):
else: else:
result_dep.append(dep_name) result_dep.append(dep_name)
class qjit(object): class qjit(object):
""" """
The qjit class serves a python function decorator that enables The qjit class serves a python function decorator that enables
...@@ -186,11 +189,13 @@ class qjit(object): ...@@ -186,11 +189,13 @@ class qjit(object):
if moduleAlias != importedModules[moduleAlias]: if moduleAlias != importedModules[moduleAlias]:
aliasModuleStr = moduleAlias + '.' aliasModuleStr = moduleAlias + '.'
originalModuleStr = importedModules[moduleAlias] + '.' originalModuleStr = importedModules[moduleAlias] + '.'
fbody_src = fbody_src.replace(aliasModuleStr, originalModuleStr) fbody_src = fbody_src.replace(
aliasModuleStr, originalModuleStr)
# Create the qcor quantum kernel function src for QJIT and the Clang syntax handler # Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
self.src = '__qpu__ void '+self.function.__name__ + \ self.src = '__qpu__ void '+self.function.__name__ + \
'('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + globalDeclStr + '\n' + fbody_src +"}\n" '('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + \
globalDeclStr + '\n' + fbody_src + "}\n"
# Handle nested kernels: # Handle nested kernels:
dependency = [] dependency = []
...@@ -201,8 +206,11 @@ class qjit(object): ...@@ -201,8 +206,11 @@ class qjit(object):
if re.search(r"\b" + re.escape(kernelCall), self.src): if re.search(r"\b" + re.escape(kernelCall), self.src):
dependency.append(kernelName) dependency.append(kernelName)
self.__kernels__graph.addKernelDependency(self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__) self.__kernels__graph.addKernelDependency(
self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
self.function.__name__)
# Run the QJIT compile step to store function pointers internally # Run the QJIT compile step to store function pointers internally
self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep) self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep)
...@@ -300,9 +308,9 @@ class qjit(object): ...@@ -300,9 +308,9 @@ class qjit(object):
""" """
return self.extract_composite(*args).nInstructions() return self.extract_composite(*args).nInstructions()
# def ctrl(self, *args): def ctrl(self, *args):
print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
exit(1)
def __call__(self, *args): def __call__(self, *args):
""" """
......
...@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase): ...@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase):
self.assertEqual(comp.getInstruction(i).name(), "Measure") self.assertEqual(comp.getInstruction(i).name(), "Measure")
def test_iqft_kernel(self): def test_iqft_kernel(self):
import numpy as np
@qjit @qjit
def iqft(q : qreg, startIdx : int, nbQubits : int): def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2): for i in range(nbQubits/2):
...@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase): ...@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase):
self.assertEqual(comp.getInstruction(i).name(), "CPhase") self.assertEqual(comp.getInstruction(i).name(), "CPhase")
self.assertEqual(comp.getInstruction(16).name(), "H") self.assertEqual(comp.getInstruction(16).name(), "H")
# def test_ctrl_kernel(self): def test_ctrl_kernel(self):
# @qjit
# def qft(q : qreg, startIdx : int, nbQubits : int): # with swap set_qpu('qpp', {'shots':1024})
# for i in range(nbQubits - 1, -1, -1):
# shiftedBitIdx = i + startIdx @qjit
# H(q[shiftedBitIdx]) def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2):
# for j in range(i-1, -1, -1): Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
# theta = np.pi / 2**(i-j)
# tIdx = j + i for i in range(nbQubits-1):
# CPhase(q[shiftedBitIdx], q[tIdx], theta) H(q[startIdx+i])
j = i +1
# swapCount = 0 if shouldSwap == 0 else 1 for y in range(i, -1, -1):
# for i in range(nbQubits/2): theta = -MY_PI / 2**(j-y)
# Swap(q[startIdx+i], q[startIdx+nbQubits-i-1]) CPhase(q[startIdx+j], q[startIdx + y], theta)
# @qjit H(q[startIdx+nbQubits-1])
# def iqft(q : qreg, startIdx : int, nbQubits : int):
# for i in range(nbQubits/2): @qjit
# Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1]) def oracle(q : qreg):
bit = q.size()-1
# for i in range(nbQubits-1): T(q[bit])
# H(q[startIdx+i])
# j = i +1 @qjit
# for y in range(i, -1, -1): def qpe(q : qreg):
# theta = -np.pi / 2**(j-y) nq = q.size()
# CPhase(q[startIdx+j], q[startIdx + y], theta)
for i in range(q.size()-1):
# H(q[startIdx+nbQubits-1]) H(q[i])
# @qjit bitPrecision = nq-1
# def oracle(q : qreg): for i in range(bitPrecision):
# bit = q.size()-1 nbCalls = 1 << i
# T(q[bit]) for j in range(nbCalls):
ctrl_bit = i
# def qpe(q : qreg): oracle.ctrl(ctrl_bit, q)
# nq = q.size()
for i in range(bitPrecision):
# for i in range(q.size()-1): Measure(q[i])
# H(q[i])
q = qalloc(4)
# bitPrecision = nq-1 qpe(q)
# for i in range(bitPrecision): print(q.counts())
# nbCalls = 1 << i
# for j in range(nbCalls):
# ctrl_bit = i
# oracle.ctrl(ctrl_bit, q)
# iqft(q, 0, bitPrecision)
# for i in range(bitPrecision):
# Measure(q[i])
# q = qalloc(4)
# qpe(q)
# print(q.counts())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment