Commit 6105e4ae authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

[WIP] work on ctrl / adjoint in python qjit kernels


Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 3c86792b
......@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
antlrcpp::Any visitAtom_expr(
pyxasmParser::Atom_exprContext *context) override {
// Handle kernel::ctrl(...), kernel::adjoint(...)
if (!context->trailer().empty() && context->trailer()[0]->getText() == ".ctrl") {
std::cout << "HELLO: " << context->getText() << "\n";
std::cout << context->trailer()[0]->getText() << "\n";
std::cout << context->atom()->getText() << "\n";
std::cout << context->trailer()[1]->getText() << "\n";
std::cout << context->trailer()[1]->arglist() << "\n";
auto arg_list = context->trailer()[1]->arglist();
std::stringstream ss;
ss << context->atom()->getText() << "::ctrl(parent_kernel";
for (int i = 0; i < arg_list->argument().size(); i++) {
ss << ", " << arg_list->argument(i)->getText();
}
ss << ");\n";
std::cout << "HELLO SS: " << ss.str() << "\n";
result.first = ss.str();
return 0;
}
if (context->atom()->NAME() != nullptr) {
auto inst_name = context->atom()->NAME()->getText();
......
......@@ -10,6 +10,7 @@ from collections import defaultdict
List = typing.List
PauliOperator = xacc.quantum.PauliOperator
def X(idx):
return xacc.quantum.PauliOperator({idx: 'X'}, 1.0)
......@@ -34,7 +35,8 @@ class KernelGraph(object):
self.kernel_name_list.append(kernelName)
self.kernel_idx_dep_map[self.V] = []
for dep_ker_name in depList:
self.kernel_idx_dep_map[self.V].append(self.kernel_name_list.index(dep_ker_name))
self.kernel_idx_dep_map[self.V].append(
self.kernel_name_list.index(dep_ker_name))
self.V += 1
def addEdge(self, u, v):
......@@ -86,6 +88,7 @@ class KernelGraph(object):
else:
result_dep.append(dep_name)
class qjit(object):
"""
The qjit class serves a python function decorator that enables
......@@ -186,11 +189,13 @@ class qjit(object):
if moduleAlias != importedModules[moduleAlias]:
aliasModuleStr = moduleAlias + '.'
originalModuleStr = importedModules[moduleAlias] + '.'
fbody_src = fbody_src.replace(aliasModuleStr, originalModuleStr)
fbody_src = fbody_src.replace(
aliasModuleStr, originalModuleStr)
# Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
self.src = '__qpu__ void '+self.function.__name__ + \
'('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + globalDeclStr + '\n' + fbody_src +"}\n"
'('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + \
globalDeclStr + '\n' + fbody_src + "}\n"
# Handle nested kernels:
dependency = []
......@@ -201,8 +206,11 @@ class qjit(object):
if re.search(r"\b" + re.escape(kernelCall), self.src):
dependency.append(kernelName)
self.__kernels__graph.addKernelDependency(self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
self.__kernels__graph.addKernelDependency(
self.function.__name__, dependency)
sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
self.function.__name__)
# Run the QJIT compile step to store function pointers internally
self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep)
......@@ -300,9 +308,9 @@ class qjit(object):
"""
return self.extract_composite(*args).nInstructions()
# def ctrl(self, *args):
def ctrl(self, *args):
print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
exit(1)
def __call__(self, *args):
"""
......
......@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase):
self.assertEqual(comp.getInstruction(i).name(), "Measure")
def test_iqft_kernel(self):
import numpy as np
@qjit
def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2):
......@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase):
self.assertEqual(comp.getInstruction(i).name(), "CPhase")
self.assertEqual(comp.getInstruction(16).name(), "H")
# def test_ctrl_kernel(self):
# @qjit
# def qft(q : qreg, startIdx : int, nbQubits : int): # with swap
# for i in range(nbQubits - 1, -1, -1):
# shiftedBitIdx = i + startIdx
# H(q[shiftedBitIdx])
# for j in range(i-1, -1, -1):
# theta = np.pi / 2**(i-j)
# tIdx = j + i
# CPhase(q[shiftedBitIdx], q[tIdx], theta)
# swapCount = 0 if shouldSwap == 0 else 1
# for i in range(nbQubits/2):
# Swap(q[startIdx+i], q[startIdx+nbQubits-i-1])
# @qjit
# def iqft(q : qreg, startIdx : int, nbQubits : int):
# for i in range(nbQubits/2):
# Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
# for i in range(nbQubits-1):
# H(q[startIdx+i])
# j = i +1
# for y in range(i, -1, -1):
# theta = -np.pi / 2**(j-y)
# CPhase(q[startIdx+j], q[startIdx + y], theta)
# H(q[startIdx+nbQubits-1])
# @qjit
# def oracle(q : qreg):
# bit = q.size()-1
# T(q[bit])
# def qpe(q : qreg):
# nq = q.size()
# for i in range(q.size()-1):
# H(q[i])
# bitPrecision = nq-1
# for i in range(bitPrecision):
# nbCalls = 1 << i
# for j in range(nbCalls):
# ctrl_bit = i
# oracle.ctrl(ctrl_bit, q)
# iqft(q, 0, bitPrecision)
# for i in range(bitPrecision):
# Measure(q[i])
# q = qalloc(4)
# qpe(q)
# print(q.counts())
def test_ctrl_kernel(self):
set_qpu('qpp', {'shots':1024})
@qjit
def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2):
Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
for i in range(nbQubits-1):
H(q[startIdx+i])
j = i +1
for y in range(i, -1, -1):
theta = -MY_PI / 2**(j-y)
CPhase(q[startIdx+j], q[startIdx + y], theta)
H(q[startIdx+nbQubits-1])
@qjit
def oracle(q : qreg):
bit = q.size()-1
T(q[bit])
@qjit
def qpe(q : qreg):
nq = q.size()
for i in range(q.size()-1):
H(q[i])
bitPrecision = nq-1
for i in range(bitPrecision):
nbCalls = 1 << i
for j in range(nbCalls):
ctrl_bit = i
oracle.ctrl(ctrl_bit, q)
for i in range(bitPrecision):
Measure(q[i])
q = qalloc(4)
qpe(q)
print(q.counts())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment