Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
ORNL Quantum Computing Institute
qcor
Commits
6105e4ae
Commit
6105e4ae
authored
Nov 03, 2020
by
Mccaskey, Alex
Browse files
[WIP] work on ctrl / adjoint in python qjit kernels
Signed-off-by:
Alex McCaskey
<
mccaskeyaj@ornl.gov
>
parent
3c86792b
Changes
3
Hide whitespace changes
Inline
Side-by-side
handlers/token_collector/pyxasm/pyxasm_visitor.hpp
View file @
6105e4ae
...
...
@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
antlrcpp
::
Any
visitAtom_expr
(
pyxasmParser
::
Atom_exprContext
*
context
)
override
{
// Handle kernel::ctrl(...), kernel::adjoint(...)
if
(
!
context
->
trailer
().
empty
()
&&
context
->
trailer
()[
0
]
->
getText
()
==
".ctrl"
)
{
std
::
cout
<<
"HELLO: "
<<
context
->
getText
()
<<
"
\n
"
;
std
::
cout
<<
context
->
trailer
()[
0
]
->
getText
()
<<
"
\n
"
;
std
::
cout
<<
context
->
atom
()
->
getText
()
<<
"
\n
"
;
std
::
cout
<<
context
->
trailer
()[
1
]
->
getText
()
<<
"
\n
"
;
std
::
cout
<<
context
->
trailer
()[
1
]
->
arglist
()
<<
"
\n
"
;
auto
arg_list
=
context
->
trailer
()[
1
]
->
arglist
();
std
::
stringstream
ss
;
ss
<<
context
->
atom
()
->
getText
()
<<
"::ctrl(parent_kernel"
;
for
(
int
i
=
0
;
i
<
arg_list
->
argument
().
size
();
i
++
)
{
ss
<<
", "
<<
arg_list
->
argument
(
i
)
->
getText
();
}
ss
<<
");
\n
"
;
std
::
cout
<<
"HELLO SS: "
<<
ss
.
str
()
<<
"
\n
"
;
result
.
first
=
ss
.
str
();
return
0
;
}
if
(
context
->
atom
()
->
NAME
()
!=
nullptr
)
{
auto
inst_name
=
context
->
atom
()
->
NAME
()
->
getText
();
...
...
python/qcor.py
View file @
6105e4ae
...
...
@@ -5,11 +5,12 @@ import inspect
from
typing
import
List
import
typing
import
re
from
collections
import
defaultdict
from
collections
import
defaultdict
List
=
typing
.
List
PauliOperator
=
xacc
.
quantum
.
PauliOperator
def
X
(
idx
):
return
xacc
.
quantum
.
PauliOperator
({
idx
:
'X'
},
1.0
)
...
...
@@ -21,12 +22,12 @@ def Y(idx):
def
Z
(
idx
):
return
xacc
.
quantum
.
PauliOperator
({
idx
:
'Z'
},
1.0
)
# Simple graph class to help resolve kernel dependency (via topological sort)
class
KernelGraph
(
object
):
def
__init__
(
self
):
self
.
graph
=
defaultdict
(
list
)
self
.
V
=
0
class
KernelGraph
(
object
):
def
__init__
(
self
):
self
.
graph
=
defaultdict
(
list
)
self
.
V
=
0
self
.
kernel_idx_dep_map
=
{}
self
.
kernel_name_list
=
[]
...
...
@@ -34,42 +35,43 @@ class KernelGraph(object):
self
.
kernel_name_list
.
append
(
kernelName
)
self
.
kernel_idx_dep_map
[
self
.
V
]
=
[]
for
dep_ker_name
in
depList
:
self
.
kernel_idx_dep_map
[
self
.
V
].
append
(
self
.
kernel_name_list
.
index
(
dep_ker_name
))
self
.
kernel_idx_dep_map
[
self
.
V
].
append
(
self
.
kernel_name_list
.
index
(
dep_ker_name
))
self
.
V
+=
1
def
addEdge
(
self
,
u
,
v
):
self
.
graph
[
u
].
append
(
v
)
# Topological Sort.
def
topologicalSort
(
self
):
self
.
graph
=
defaultdict
(
list
)
def
addEdge
(
self
,
u
,
v
):
self
.
graph
[
u
].
append
(
v
)
# Topological Sort.
def
topologicalSort
(
self
):
self
.
graph
=
defaultdict
(
list
)
for
sub_ker_idx
in
self
.
kernel_idx_dep_map
:
for
dep_sub_idx
in
self
.
kernel_idx_dep_map
[
sub_ker_idx
]:
self
.
addEdge
(
dep_sub_idx
,
sub_ker_idx
)
in_degree
=
[
0
]
*
(
self
.
V
)
for
i
in
self
.
graph
:
for
j
in
self
.
graph
[
i
]:
self
.
addEdge
(
dep_sub_idx
,
sub_ker_idx
)
in_degree
=
[
0
]
*
(
self
.
V
)
for
i
in
self
.
graph
:
for
j
in
self
.
graph
[
i
]:
in_degree
[
j
]
+=
1
queue
=
[]
for
i
in
range
(
self
.
V
):
if
in_degree
[
i
]
==
0
:
queue
.
append
(
i
)
queue
=
[]
for
i
in
range
(
self
.
V
):
if
in_degree
[
i
]
==
0
:
queue
.
append
(
i
)
cnt
=
0
top_order
=
[]
while
queue
:
u
=
queue
.
pop
(
0
)
top_order
.
append
(
u
)
for
i
in
self
.
graph
[
u
]:
top_order
=
[]
while
queue
:
u
=
queue
.
pop
(
0
)
top_order
.
append
(
u
)
for
i
in
self
.
graph
[
u
]:
in_degree
[
i
]
-=
1
if
in_degree
[
i
]
==
0
:
queue
.
append
(
i
)
if
in_degree
[
i
]
==
0
:
queue
.
append
(
i
)
cnt
+=
1
sortedDep
=
[]
for
sorted_dep_idx
in
top_order
:
sortedDep
.
append
(
self
.
kernel_name_list
[
sorted_dep_idx
])
sortedDep
.
append
(
self
.
kernel_name_list
[
sorted_dep_idx
])
return
sortedDep
def
getSortedDependency
(
self
,
kernelName
):
...
...
@@ -77,7 +79,7 @@ class KernelGraph(object):
# No dependency
if
len
(
self
.
kernel_idx_dep_map
[
kernel_idx
])
==
0
:
return
[]
sorted_dep
=
self
.
topologicalSort
()
result_dep
=
[]
for
dep_name
in
sorted_dep
:
...
...
@@ -86,6 +88,7 @@ class KernelGraph(object):
else
:
result_dep
.
append
(
dep_name
)
class
qjit
(
object
):
"""
The qjit class serves a python function decorator that enables
...
...
@@ -126,8 +129,8 @@ class qjit(object):
self
.
kwargs
=
kwargs
self
.
function
=
function
self
.
allowed_type_cpp_map
=
{
'<class
\'
_pyqcor.qreg
\'
>'
:
'qreg'
,
'<class
\'
float
\'
>'
:
'double'
,
'typing.List[float]'
:
'std::vector<double>'
,
'<class
\'
int
\'
>'
:
'int'
,
'<class
\'
float
\'
>'
:
'double'
,
'typing.List[float]'
:
'std::vector<double>'
,
'<class
\'
int
\'
>'
:
'int'
,
'<class
\'
_pyxacc.quantum.PauliOperator
\'
>'
:
'qcor::PauliOperator'
}
self
.
__dict__
.
update
(
kwargs
)
...
...
@@ -173,25 +176,27 @@ class qjit(object):
# Only support float atm
if
(
isinstance
(
globalVars
[
key
],
float
)):
globalVarDecl
.
append
(
key
+
" = "
+
str
(
globalVars
[
key
]))
# Inject these global declarations into the function body.
separator
=
"
\n
"
globalDeclStr
=
separator
.
join
(
globalVarDecl
)
# Handle common modules like numpy or math
# e.g. if seeing `import numpy as np`, we'll have <'np' -> 'numpy'> in the importedModules dict.
# e.g. if seeing `import numpy as np`, we'll have <'np' -> 'numpy'> in the importedModules dict.
# We'll replace any module alias by its original name,
# i.e. 'np.pi' -> 'numpy.pi', etc.
for
moduleAlias
in
importedModules
:
if
moduleAlias
!=
importedModules
[
moduleAlias
]:
aliasModuleStr
=
moduleAlias
+
'.'
originalModuleStr
=
importedModules
[
moduleAlias
]
+
'.'
fbody_src
=
fbody_src
.
replace
(
aliasModuleStr
,
originalModuleStr
)
fbody_src
=
fbody_src
.
replace
(
aliasModuleStr
,
originalModuleStr
)
# Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
self
.
src
=
'__qpu__ void '
+
self
.
function
.
__name__
+
\
'('
+
cpp_arg_str
+
') {
\n
using qcor::pyxasm;
\n
'
+
globalDeclStr
+
'
\n
'
+
fbody_src
+
"}
\n
"
'('
+
cpp_arg_str
+
') {
\n
using qcor::pyxasm;
\n
'
+
\
globalDeclStr
+
'
\n
'
+
fbody_src
+
"}
\n
"
# Handle nested kernels:
dependency
=
[]
for
kernelName
in
self
.
__compiled__kernels
:
...
...
@@ -200,10 +205,13 @@ class qjit(object):
# pattern: "<white space> kernel("
if
re
.
search
(
r
"\b"
+
re
.
escape
(
kernelCall
),
self
.
src
):
dependency
.
append
(
kernelName
)
self
.
__kernels__graph
.
addKernelDependency
(
self
.
function
.
__name__
,
dependency
)
sorted_kernel_dep
=
self
.
__kernels__graph
.
getSortedDependency
(
self
.
function
.
__name__
)
self
.
__kernels__graph
.
addKernelDependency
(
self
.
function
.
__name__
,
dependency
)
sorted_kernel_dep
=
self
.
__kernels__graph
.
getSortedDependency
(
self
.
function
.
__name__
)
# Run the QJIT compile step to store function pointers internally
self
.
_qjit
.
internal_python_jit_compile
(
self
.
src
,
sorted_kernel_dep
)
self
.
_qjit
.
write_cache
()
...
...
@@ -212,8 +220,8 @@ class qjit(object):
# Static list of all kernels compiled
__compiled__kernels
=
[]
__kernels__graph
=
KernelGraph
()
__kernels__graph
=
KernelGraph
()
def
get_internal_src
(
self
):
"""Return the C++ / embedded python DSL function code that will be passed to QJIT
and the clang syntax handler. This function is primarily to be used for developer purposes. """
...
...
@@ -293,16 +301,16 @@ class qjit(object):
Print the QJIT kernel as a QASM-like string
"""
print
(
self
.
extract_composite
(
*
args
).
toString
())
def
n_instructions
(
self
,
*
args
):
"""
Return the number of quantum instructions in this kernel.
"""
return
self
.
extract_composite
(
*
args
).
nInstructions
()
#
def ctrl(self, *args):
def
ctrl
(
self
,
*
args
):
print
(
'This is an internal API call and will be translated to C++ via the QJIT.
\n
It can only be called from within another quantum kernel.'
)
exit
(
1
)
def
__call__
(
self
,
*
args
):
"""
...
...
python/tests/test_kernel_jit.py
View file @
6105e4ae
...
...
@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase):
self
.
assertEqual
(
comp
.
getInstruction
(
i
).
name
(),
"Measure"
)
def
test_iqft_kernel
(
self
):
import
numpy
as
np
@
qjit
def
iqft
(
q
:
qreg
,
startIdx
:
int
,
nbQubits
:
int
):
for
i
in
range
(
nbQubits
/
2
):
...
...
@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase):
self
.
assertEqual
(
comp
.
getInstruction
(
i
).
name
(),
"CPhase"
)
self
.
assertEqual
(
comp
.
getInstruction
(
16
).
name
(),
"H"
)
# def test_ctrl_kernel(self):
# @qjit
# def qft(q : qreg, startIdx : int, nbQubits : int): # with swap
# for i in range(nbQubits - 1, -1, -1):
# shiftedBitIdx = i + startIdx
# H(q[shiftedBitIdx])
# for j in range(i-1, -1, -1):
# theta = np.pi / 2**(i-j)
# tIdx = j + i
# CPhase(q[shiftedBitIdx], q[tIdx], theta)
# swapCount = 0 if shouldSwap == 0 else 1
# for i in range(nbQubits/2):
# Swap(q[startIdx+i], q[startIdx+nbQubits-i-1])
def
test_ctrl_kernel
(
self
):
# @qjit
# def iqft(q : qreg, startIdx : int, nbQubits : int):
# for i in range(nbQubits/2):
# Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
set_qpu
(
'qpp'
,
{
'shots'
:
1024
})
@
qjit
def
iqft
(
q
:
qreg
,
startIdx
:
int
,
nbQubits
:
int
):
for
i
in
range
(
nbQubits
/
2
):
Swap
(
q
[
startIdx
+
i
],
q
[
startIdx
+
nbQubits
-
i
-
1
])
#
for i in range(nbQubits-1):
#
H(q[startIdx+i])
#
j = i +1
#
for y in range(i, -1, -1):
#
theta = -
np.pi
/ 2**(j-y)
#
CPhase(q[startIdx+j], q[startIdx + y], theta)
for
i
in
range
(
nbQubits
-
1
):
H
(
q
[
startIdx
+
i
])
j
=
i
+
1
for
y
in
range
(
i
,
-
1
,
-
1
):
theta
=
-
MY_PI
/
2
**
(
j
-
y
)
CPhase
(
q
[
startIdx
+
j
],
q
[
startIdx
+
y
],
theta
)
#
H(q[startIdx+nbQubits-1])
H
(
q
[
startIdx
+
nbQubits
-
1
])
#
@qjit
#
def oracle(q : qreg):
#
bit = q.size()-1
#
T(q[bit])
@
qjit
def
oracle
(
q
:
qreg
):
bit
=
q
.
size
()
-
1
T
(
q
[
bit
])
# def qpe(q : qreg):
# nq = q.size()
@
qjit
def
qpe
(
q
:
qreg
):
nq
=
q
.
size
()
#
for i in range(q.size()-1):
#
H(q[i])
for
i
in
range
(
q
.
size
()
-
1
):
H
(
q
[
i
])
#
bitPrecision = nq-1
#
for i in range(bitPrecision):
#
nbCalls = 1 << i
#
for j in range(nbCalls):
#
ctrl_bit = i
#
oracle.ctrl(ctrl_bit, q)
bitPrecision
=
nq
-
1
for
i
in
range
(
bitPrecision
):
nbCalls
=
1
<<
i
for
j
in
range
(
nbCalls
):
ctrl_bit
=
i
oracle
.
ctrl
(
ctrl_bit
,
q
)
# iqft(q, 0, bitPrecision)
# for i in range(bitPrecision):
# Measure(q[i])
for
i
in
range
(
bitPrecision
):
Measure
(
q
[
i
])
#
q = qalloc(4)
#
qpe(q)
#
print(q.counts())
q
=
qalloc
(
4
)
qpe
(
q
)
print
(
q
.
counts
())
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment