Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
ORNL Quantum Computing Institute
qcor
Commits
6105e4ae
Commit
6105e4ae
authored
Nov 03, 2020
by
Mccaskey, Alex
Browse files
[WIP] work on ctrl / adjoint in python qjit kernels
Signed-off-by:
Alex McCaskey
<
mccaskeyaj@ornl.gov
>
parent
3c86792b
Changes
3
Hide whitespace changes
Inline
Side-by-side
handlers/token_collector/pyxasm/pyxasm_visitor.hpp
View file @
6105e4ae
...
@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
...
@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
antlrcpp
::
Any
visitAtom_expr
(
antlrcpp
::
Any
visitAtom_expr
(
pyxasmParser
::
Atom_exprContext
*
context
)
override
{
pyxasmParser
::
Atom_exprContext
*
context
)
override
{
// Handle kernel::ctrl(...), kernel::adjoint(...)
if
(
!
context
->
trailer
().
empty
()
&&
context
->
trailer
()[
0
]
->
getText
()
==
".ctrl"
)
{
std
::
cout
<<
"HELLO: "
<<
context
->
getText
()
<<
"
\n
"
;
std
::
cout
<<
context
->
trailer
()[
0
]
->
getText
()
<<
"
\n
"
;
std
::
cout
<<
context
->
atom
()
->
getText
()
<<
"
\n
"
;
std
::
cout
<<
context
->
trailer
()[
1
]
->
getText
()
<<
"
\n
"
;
std
::
cout
<<
context
->
trailer
()[
1
]
->
arglist
()
<<
"
\n
"
;
auto
arg_list
=
context
->
trailer
()[
1
]
->
arglist
();
std
::
stringstream
ss
;
ss
<<
context
->
atom
()
->
getText
()
<<
"::ctrl(parent_kernel"
;
for
(
int
i
=
0
;
i
<
arg_list
->
argument
().
size
();
i
++
)
{
ss
<<
", "
<<
arg_list
->
argument
(
i
)
->
getText
();
}
ss
<<
");
\n
"
;
std
::
cout
<<
"HELLO SS: "
<<
ss
.
str
()
<<
"
\n
"
;
result
.
first
=
ss
.
str
();
return
0
;
}
if
(
context
->
atom
()
->
NAME
()
!=
nullptr
)
{
if
(
context
->
atom
()
->
NAME
()
!=
nullptr
)
{
auto
inst_name
=
context
->
atom
()
->
NAME
()
->
getText
();
auto
inst_name
=
context
->
atom
()
->
NAME
()
->
getText
();
...
...
python/qcor.py
View file @
6105e4ae
...
@@ -5,11 +5,12 @@ import inspect
...
@@ -5,11 +5,12 @@ import inspect
from
typing
import
List
from
typing
import
List
import
typing
import
typing
import
re
import
re
from
collections
import
defaultdict
from
collections
import
defaultdict
List
=
typing
.
List
List
=
typing
.
List
PauliOperator
=
xacc
.
quantum
.
PauliOperator
PauliOperator
=
xacc
.
quantum
.
PauliOperator
def
X
(
idx
):
def
X
(
idx
):
return
xacc
.
quantum
.
PauliOperator
({
idx
:
'X'
},
1.0
)
return
xacc
.
quantum
.
PauliOperator
({
idx
:
'X'
},
1.0
)
...
@@ -21,12 +22,12 @@ def Y(idx):
...
@@ -21,12 +22,12 @@ def Y(idx):
def
Z
(
idx
):
def
Z
(
idx
):
return
xacc
.
quantum
.
PauliOperator
({
idx
:
'Z'
},
1.0
)
return
xacc
.
quantum
.
PauliOperator
({
idx
:
'Z'
},
1.0
)
# Simple graph class to help resolve kernel dependency (via topological sort)
# Simple graph class to help resolve kernel dependency (via topological sort)
class
KernelGraph
(
object
):
class
KernelGraph
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
graph
=
defaultdict
(
list
)
self
.
graph
=
defaultdict
(
list
)
self
.
V
=
0
self
.
V
=
0
self
.
kernel_idx_dep_map
=
{}
self
.
kernel_idx_dep_map
=
{}
self
.
kernel_name_list
=
[]
self
.
kernel_name_list
=
[]
...
@@ -34,42 +35,43 @@ class KernelGraph(object):
...
@@ -34,42 +35,43 @@ class KernelGraph(object):
self
.
kernel_name_list
.
append
(
kernelName
)
self
.
kernel_name_list
.
append
(
kernelName
)
self
.
kernel_idx_dep_map
[
self
.
V
]
=
[]
self
.
kernel_idx_dep_map
[
self
.
V
]
=
[]
for
dep_ker_name
in
depList
:
for
dep_ker_name
in
depList
:
self
.
kernel_idx_dep_map
[
self
.
V
].
append
(
self
.
kernel_name_list
.
index
(
dep_ker_name
))
self
.
kernel_idx_dep_map
[
self
.
V
].
append
(
self
.
kernel_name_list
.
index
(
dep_ker_name
))
self
.
V
+=
1
self
.
V
+=
1
def
addEdge
(
self
,
u
,
v
):
def
addEdge
(
self
,
u
,
v
):
self
.
graph
[
u
].
append
(
v
)
self
.
graph
[
u
].
append
(
v
)
# Topological Sort.
# Topological Sort.
def
topologicalSort
(
self
):
def
topologicalSort
(
self
):
self
.
graph
=
defaultdict
(
list
)
self
.
graph
=
defaultdict
(
list
)
for
sub_ker_idx
in
self
.
kernel_idx_dep_map
:
for
sub_ker_idx
in
self
.
kernel_idx_dep_map
:
for
dep_sub_idx
in
self
.
kernel_idx_dep_map
[
sub_ker_idx
]:
for
dep_sub_idx
in
self
.
kernel_idx_dep_map
[
sub_ker_idx
]:
self
.
addEdge
(
dep_sub_idx
,
sub_ker_idx
)
self
.
addEdge
(
dep_sub_idx
,
sub_ker_idx
)
in_degree
=
[
0
]
*
(
self
.
V
)
in_degree
=
[
0
]
*
(
self
.
V
)
for
i
in
self
.
graph
:
for
i
in
self
.
graph
:
for
j
in
self
.
graph
[
i
]:
for
j
in
self
.
graph
[
i
]:
in_degree
[
j
]
+=
1
in_degree
[
j
]
+=
1
queue
=
[]
queue
=
[]
for
i
in
range
(
self
.
V
):
for
i
in
range
(
self
.
V
):
if
in_degree
[
i
]
==
0
:
if
in_degree
[
i
]
==
0
:
queue
.
append
(
i
)
queue
.
append
(
i
)
cnt
=
0
cnt
=
0
top_order
=
[]
top_order
=
[]
while
queue
:
while
queue
:
u
=
queue
.
pop
(
0
)
u
=
queue
.
pop
(
0
)
top_order
.
append
(
u
)
top_order
.
append
(
u
)
for
i
in
self
.
graph
[
u
]:
for
i
in
self
.
graph
[
u
]:
in_degree
[
i
]
-=
1
in_degree
[
i
]
-=
1
if
in_degree
[
i
]
==
0
:
if
in_degree
[
i
]
==
0
:
queue
.
append
(
i
)
queue
.
append
(
i
)
cnt
+=
1
cnt
+=
1
sortedDep
=
[]
sortedDep
=
[]
for
sorted_dep_idx
in
top_order
:
for
sorted_dep_idx
in
top_order
:
sortedDep
.
append
(
self
.
kernel_name_list
[
sorted_dep_idx
])
sortedDep
.
append
(
self
.
kernel_name_list
[
sorted_dep_idx
])
return
sortedDep
return
sortedDep
def
getSortedDependency
(
self
,
kernelName
):
def
getSortedDependency
(
self
,
kernelName
):
...
@@ -77,7 +79,7 @@ class KernelGraph(object):
...
@@ -77,7 +79,7 @@ class KernelGraph(object):
# No dependency
# No dependency
if
len
(
self
.
kernel_idx_dep_map
[
kernel_idx
])
==
0
:
if
len
(
self
.
kernel_idx_dep_map
[
kernel_idx
])
==
0
:
return
[]
return
[]
sorted_dep
=
self
.
topologicalSort
()
sorted_dep
=
self
.
topologicalSort
()
result_dep
=
[]
result_dep
=
[]
for
dep_name
in
sorted_dep
:
for
dep_name
in
sorted_dep
:
...
@@ -86,6 +88,7 @@ class KernelGraph(object):
...
@@ -86,6 +88,7 @@ class KernelGraph(object):
else
:
else
:
result_dep
.
append
(
dep_name
)
result_dep
.
append
(
dep_name
)
class
qjit
(
object
):
class
qjit
(
object
):
"""
"""
The qjit class serves a python function decorator that enables
The qjit class serves a python function decorator that enables
...
@@ -126,8 +129,8 @@ class qjit(object):
...
@@ -126,8 +129,8 @@ class qjit(object):
self
.
kwargs
=
kwargs
self
.
kwargs
=
kwargs
self
.
function
=
function
self
.
function
=
function
self
.
allowed_type_cpp_map
=
{
'<class
\'
_pyqcor.qreg
\'
>'
:
'qreg'
,
self
.
allowed_type_cpp_map
=
{
'<class
\'
_pyqcor.qreg
\'
>'
:
'qreg'
,
'<class
\'
float
\'
>'
:
'double'
,
'typing.List[float]'
:
'std::vector<double>'
,
'<class
\'
float
\'
>'
:
'double'
,
'typing.List[float]'
:
'std::vector<double>'
,
'<class
\'
int
\'
>'
:
'int'
,
'<class
\'
int
\'
>'
:
'int'
,
'<class
\'
_pyxacc.quantum.PauliOperator
\'
>'
:
'qcor::PauliOperator'
}
'<class
\'
_pyxacc.quantum.PauliOperator
\'
>'
:
'qcor::PauliOperator'
}
self
.
__dict__
.
update
(
kwargs
)
self
.
__dict__
.
update
(
kwargs
)
...
@@ -173,25 +176,27 @@ class qjit(object):
...
@@ -173,25 +176,27 @@ class qjit(object):
# Only support float atm
# Only support float atm
if
(
isinstance
(
globalVars
[
key
],
float
)):
if
(
isinstance
(
globalVars
[
key
],
float
)):
globalVarDecl
.
append
(
key
+
" = "
+
str
(
globalVars
[
key
]))
globalVarDecl
.
append
(
key
+
" = "
+
str
(
globalVars
[
key
]))
# Inject these global declarations into the function body.
# Inject these global declarations into the function body.
separator
=
"
\n
"
separator
=
"
\n
"
globalDeclStr
=
separator
.
join
(
globalVarDecl
)
globalDeclStr
=
separator
.
join
(
globalVarDecl
)
# Handle common modules like numpy or math
# Handle common modules like numpy or math
# e.g. if seeing `import numpy as np`, we'll have <'np' -> 'numpy'> in the importedModules dict.
# e.g. if seeing `import numpy as np`, we'll have <'np' -> 'numpy'> in the importedModules dict.
# We'll replace any module alias by its original name,
# We'll replace any module alias by its original name,
# i.e. 'np.pi' -> 'numpy.pi', etc.
# i.e. 'np.pi' -> 'numpy.pi', etc.
for
moduleAlias
in
importedModules
:
for
moduleAlias
in
importedModules
:
if
moduleAlias
!=
importedModules
[
moduleAlias
]:
if
moduleAlias
!=
importedModules
[
moduleAlias
]:
aliasModuleStr
=
moduleAlias
+
'.'
aliasModuleStr
=
moduleAlias
+
'.'
originalModuleStr
=
importedModules
[
moduleAlias
]
+
'.'
originalModuleStr
=
importedModules
[
moduleAlias
]
+
'.'
fbody_src
=
fbody_src
.
replace
(
aliasModuleStr
,
originalModuleStr
)
fbody_src
=
fbody_src
.
replace
(
aliasModuleStr
,
originalModuleStr
)
# Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
# Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
self
.
src
=
'__qpu__ void '
+
self
.
function
.
__name__
+
\
self
.
src
=
'__qpu__ void '
+
self
.
function
.
__name__
+
\
'('
+
cpp_arg_str
+
') {
\n
using qcor::pyxasm;
\n
'
+
globalDeclStr
+
'
\n
'
+
fbody_src
+
"}
\n
"
'('
+
cpp_arg_str
+
') {
\n
using qcor::pyxasm;
\n
'
+
\
globalDeclStr
+
'
\n
'
+
fbody_src
+
"}
\n
"
# Handle nested kernels:
# Handle nested kernels:
dependency
=
[]
dependency
=
[]
for
kernelName
in
self
.
__compiled__kernels
:
for
kernelName
in
self
.
__compiled__kernels
:
...
@@ -200,10 +205,13 @@ class qjit(object):
...
@@ -200,10 +205,13 @@ class qjit(object):
# pattern: "<white space> kernel("
# pattern: "<white space> kernel("
if
re
.
search
(
r
"\b"
+
re
.
escape
(
kernelCall
),
self
.
src
):
if
re
.
search
(
r
"\b"
+
re
.
escape
(
kernelCall
),
self
.
src
):
dependency
.
append
(
kernelName
)
dependency
.
append
(
kernelName
)
self
.
__kernels__graph
.
addKernelDependency
(
self
.
function
.
__name__
,
dependency
)
sorted_kernel_dep
=
self
.
__kernels__graph
.
getSortedDependency
(
self
.
function
.
__name__
)
self
.
__kernels__graph
.
addKernelDependency
(
self
.
function
.
__name__
,
dependency
)
sorted_kernel_dep
=
self
.
__kernels__graph
.
getSortedDependency
(
self
.
function
.
__name__
)
# Run the QJIT compile step to store function pointers internally
# Run the QJIT compile step to store function pointers internally
self
.
_qjit
.
internal_python_jit_compile
(
self
.
src
,
sorted_kernel_dep
)
self
.
_qjit
.
internal_python_jit_compile
(
self
.
src
,
sorted_kernel_dep
)
self
.
_qjit
.
write_cache
()
self
.
_qjit
.
write_cache
()
...
@@ -212,8 +220,8 @@ class qjit(object):
...
@@ -212,8 +220,8 @@ class qjit(object):
# Static list of all kernels compiled
# Static list of all kernels compiled
__compiled__kernels
=
[]
__compiled__kernels
=
[]
__kernels__graph
=
KernelGraph
()
__kernels__graph
=
KernelGraph
()
def
get_internal_src
(
self
):
def
get_internal_src
(
self
):
"""Return the C++ / embedded python DSL function code that will be passed to QJIT
"""Return the C++ / embedded python DSL function code that will be passed to QJIT
and the clang syntax handler. This function is primarily to be used for developer purposes. """
and the clang syntax handler. This function is primarily to be used for developer purposes. """
...
@@ -293,16 +301,16 @@ class qjit(object):
...
@@ -293,16 +301,16 @@ class qjit(object):
Print the QJIT kernel as a QASM-like string
Print the QJIT kernel as a QASM-like string
"""
"""
print
(
self
.
extract_composite
(
*
args
).
toString
())
print
(
self
.
extract_composite
(
*
args
).
toString
())
def
n_instructions
(
self
,
*
args
):
def
n_instructions
(
self
,
*
args
):
"""
"""
Return the number of quantum instructions in this kernel.
Return the number of quantum instructions in this kernel.
"""
"""
return
self
.
extract_composite
(
*
args
).
nInstructions
()
return
self
.
extract_composite
(
*
args
).
nInstructions
()
#
def ctrl(self, *args):
def
ctrl
(
self
,
*
args
):
print
(
'This is an internal API call and will be translated to C++ via the QJIT.
\n
It can only be called from within another quantum kernel.'
)
exit
(
1
)
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
*
args
):
"""
"""
...
...
python/tests/test_kernel_jit.py
View file @
6105e4ae
...
@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase):
...
@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase):
self
.
assertEqual
(
comp
.
getInstruction
(
i
).
name
(),
"Measure"
)
self
.
assertEqual
(
comp
.
getInstruction
(
i
).
name
(),
"Measure"
)
def
test_iqft_kernel
(
self
):
def
test_iqft_kernel
(
self
):
import
numpy
as
np
@
qjit
@
qjit
def
iqft
(
q
:
qreg
,
startIdx
:
int
,
nbQubits
:
int
):
def
iqft
(
q
:
qreg
,
startIdx
:
int
,
nbQubits
:
int
):
for
i
in
range
(
nbQubits
/
2
):
for
i
in
range
(
nbQubits
/
2
):
...
@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase):
...
@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase):
self
.
assertEqual
(
comp
.
getInstruction
(
i
).
name
(),
"CPhase"
)
self
.
assertEqual
(
comp
.
getInstruction
(
i
).
name
(),
"CPhase"
)
self
.
assertEqual
(
comp
.
getInstruction
(
16
).
name
(),
"H"
)
self
.
assertEqual
(
comp
.
getInstruction
(
16
).
name
(),
"H"
)
# def test_ctrl_kernel(self):
def
test_ctrl_kernel
(
self
):
# @qjit
# def qft(q : qreg, startIdx : int, nbQubits : int): # with swap
# for i in range(nbQubits - 1, -1, -1):
# shiftedBitIdx = i + startIdx
# H(q[shiftedBitIdx])
# for j in range(i-1, -1, -1):
# theta = np.pi / 2**(i-j)
# tIdx = j + i
# CPhase(q[shiftedBitIdx], q[tIdx], theta)
# swapCount = 0 if shouldSwap == 0 else 1
# for i in range(nbQubits/2):
# Swap(q[startIdx+i], q[startIdx+nbQubits-i-1])
# @qjit
set_qpu
(
'qpp'
,
{
'shots'
:
1024
})
# def iqft(q : qreg, startIdx : int, nbQubits : int):
# for i in range(nbQubits/2):
@
qjit
# Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
def
iqft
(
q
:
qreg
,
startIdx
:
int
,
nbQubits
:
int
):
for
i
in
range
(
nbQubits
/
2
):
Swap
(
q
[
startIdx
+
i
],
q
[
startIdx
+
nbQubits
-
i
-
1
])
#
for i in range(nbQubits-1):
for
i
in
range
(
nbQubits
-
1
):
#
H(q[startIdx+i])
H
(
q
[
startIdx
+
i
])
#
j = i +1
j
=
i
+
1
#
for y in range(i, -1, -1):
for
y
in
range
(
i
,
-
1
,
-
1
):
#
theta = -
np.pi
/ 2**(j-y)
theta
=
-
MY_PI
/
2
**
(
j
-
y
)
#
CPhase(q[startIdx+j], q[startIdx + y], theta)
CPhase
(
q
[
startIdx
+
j
],
q
[
startIdx
+
y
],
theta
)
#
H(q[startIdx+nbQubits-1])
H
(
q
[
startIdx
+
nbQubits
-
1
])
#
@qjit
@
qjit
#
def oracle(q : qreg):
def
oracle
(
q
:
qreg
):
#
bit = q.size()-1
bit
=
q
.
size
()
-
1
#
T(q[bit])
T
(
q
[
bit
])
# def qpe(q : qreg):
@
qjit
# nq = q.size()
def
qpe
(
q
:
qreg
):
nq
=
q
.
size
()
#
for i in range(q.size()-1):
for
i
in
range
(
q
.
size
()
-
1
):
#
H(q[i])
H
(
q
[
i
])
#
bitPrecision = nq-1
bitPrecision
=
nq
-
1
#
for i in range(bitPrecision):
for
i
in
range
(
bitPrecision
):
#
nbCalls = 1 << i
nbCalls
=
1
<<
i
#
for j in range(nbCalls):
for
j
in
range
(
nbCalls
):
#
ctrl_bit = i
ctrl_bit
=
i
#
oracle.ctrl(ctrl_bit, q)
oracle
.
ctrl
(
ctrl_bit
,
q
)
# iqft(q, 0, bitPrecision)
for
i
in
range
(
bitPrecision
):
# for i in range(bitPrecision):
Measure
(
q
[
i
])
# Measure(q[i])
#
q = qalloc(4)
q
=
qalloc
(
4
)
#
qpe(q)
qpe
(
q
)
#
print(q.counts())
print
(
q
.
counts
())
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment