Commit 3c86792b authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

started on supporting python qjit kernel ctrl, adjoint, etc. implement 2**x in...


started on supporting python qjit kernel ctrl, adjoint, etc. implement 2**x in pyxasm visitor, added qft example to tests
Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent e2d51688
Pipeline #123812 passed with stages
in 68 minutes and 40 seconds
......@@ -15,20 +15,20 @@ using pyxasm_result_type =
std::pair<std::string, std::shared_ptr<xacc::Instruction>>;
class pyxasm_visitor : public pyxasmBaseVisitor {
protected:
protected:
std::shared_ptr<xacc::IRProvider> provider;
// List of buffers in the *context* of this XASM visitor
std::vector<std::string> bufferNames;
public:
public:
pyxasm_visitor(const std::vector<std::string> &buffers = {})
: provider(xacc::getIRProvider("quantum")), bufferNames(buffers) {}
pyxasm_result_type result;
bool in_for_loop = false;
antlrcpp::Any
visitAtom_expr(pyxasmParser::Atom_exprContext *context) override {
antlrcpp::Any visitAtom_expr(
pyxasmParser::Atom_exprContext *context) override {
if (context->atom()->NAME() != nullptr) {
auto inst_name = context->atom()->NAME()->getText();
......@@ -69,9 +69,9 @@ public:
auto found_bracket = bit_expr_str.find_first_of("[");
if (found_bracket != std::string::npos) {
auto buffer_name = bit_expr_str.substr(0, found_bracket);
auto bit_idx_expr = bit_expr_str.substr(found_bracket + 1,
bit_expr_str.length() -
found_bracket - 2);
auto bit_idx_expr = bit_expr_str.substr(
found_bracket + 1,
bit_expr_str.length() - found_bracket - 2);
buffer_names.push_back(buffer_name);
inst->setBitExpression(i, bit_idx_expr);
} else {
......@@ -173,13 +173,43 @@ public:
const std::string rhs = ctx->testlist_star_expr(1)->getText();
ss << "auto " << lhs << " = " << rhs << "; \n";
result.first = ss.str();
return 0;
if (rhs.find("**") != std::string::npos) {
// keep processing
return visitChildren(ctx);
} else {
return 0;
}
} else {
return visitChildren(ctx);
}
}
private:
antlrcpp::Any visitPower(pyxasmParser::PowerContext *context) override {
if (context->getText().find("**") != std::string::npos &&
context->factor() != nullptr) {
// Here we handle x**y from parent assignment expression
auto replaceAll = [](std::string &s, const std::string &search,
const std::string &replace) {
for (std::size_t pos = 0;; pos += replace.length()) {
// Locate the substring to replace
pos = s.find(search, pos);
if (pos == std::string::npos) break;
// Replace by erasing and inserting
s.erase(pos, search.length());
s.insert(pos, replace);
}
};
auto factor = context->factor();
auto atom_expr = context->atom_expr();
std::string s =
"std::pow(" + atom_expr->getText() + ", " + factor->getText() + ")";
replaceAll(result.first, context->getText(), s);
return 0;
}
return visitChildren(context);
}
private:
// Replaces common Python constants, e.g. 'math.pi' or 'numpy.pi'.
// Note: the library names have been resolved to their original names.
std::string replacePythonConstants(const std::string &in_pyExpr) const {
......
......@@ -288,6 +288,22 @@ class qjit(object):
staq = xacc.getCompiler('staq')
return staq.translate(kernel)
def print_kernel(self, *args):
"""
Print the QJIT kernel as a QASM-like string
"""
print(self.extract_composite(*args).toString())
def n_instructions(self, *args):
"""
Return the number of quantum instructions in this kernel.
"""
return self.extract_composite(*args).nInstructions()
# def ctrl(self, *args):
def __call__(self, *args):
"""
Execute the decorated quantum kernel. This will directly
......
......@@ -217,7 +217,100 @@ class TestSimpleKernelJIT(unittest.TestCase):
for i in range(0, q.size() * len(list1) * len(list2)):
self.assertEqual(comp.getInstruction(i).name(), "Rx")
for i in range(q.size() * len(list1) * len(list2), comp.nInstructions()):
self.assertEqual(comp.getInstruction(i).name(), "Measure")
self.assertEqual(comp.getInstruction(i).name(), "Measure")
def test_iqft_kernel(self):
import numpy as np
@qjit
def iqft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2):
Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
for i in range(nbQubits-1):
H(q[startIdx+i])
j = i +1
for y in range(i, -1, -1):
theta = -MY_PI / 2**(j-y)
CPhase(q[startIdx+j], q[startIdx + y], theta)
H(q[startIdx+nbQubits-1])
q = qalloc(5)
comp = iqft.extract_composite(q, 0, 5)
print(comp.toString())
self.assertEqual(comp.nInstructions(), 17)
self.assertEqual(comp.getInstruction(0).name(), "Swap")
self.assertEqual(comp.getInstruction(1).name(), "Swap")
self.assertEqual(comp.getInstruction(2).name(), "H")
self.assertEqual(comp.getInstruction(3).name(), "CPhase")
self.assertEqual(comp.getInstruction(4).name(), "H")
for i in range(5, 7):
self.assertEqual(comp.getInstruction(i).name(), "CPhase")
self.assertEqual(comp.getInstruction(7).name(), "H")
for i in range(8, 11):
self.assertEqual(comp.getInstruction(i).name(), "CPhase")
self.assertEqual(comp.getInstruction(11).name(), "H")
for i in range(12, 16):
self.assertEqual(comp.getInstruction(i).name(), "CPhase")
self.assertEqual(comp.getInstruction(16).name(), "H")
# def test_ctrl_kernel(self):
# @qjit
# def qft(q : qreg, startIdx : int, nbQubits : int): # with swap
# for i in range(nbQubits - 1, -1, -1):
# shiftedBitIdx = i + startIdx
# H(q[shiftedBitIdx])
# for j in range(i-1, -1, -1):
# theta = np.pi / 2**(i-j)
# tIdx = j + i
# CPhase(q[shiftedBitIdx], q[tIdx], theta)
# swapCount = 0 if shouldSwap == 0 else 1
# for i in range(nbQubits/2):
# Swap(q[startIdx+i], q[startIdx+nbQubits-i-1])
# @qjit
# def iqft(q : qreg, startIdx : int, nbQubits : int):
# for i in range(nbQubits/2):
# Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
# for i in range(nbQubits-1):
# H(q[startIdx+i])
# j = i +1
# for y in range(i, -1, -1):
# theta = -np.pi / 2**(j-y)
# CPhase(q[startIdx+j], q[startIdx + y], theta)
# H(q[startIdx+nbQubits-1])
# @qjit
# def oracle(q : qreg):
# bit = q.size()-1
# T(q[bit])
# def qpe(q : qreg):
# nq = q.size()
# for i in range(q.size()-1):
# H(q[i])
# bitPrecision = nq-1
# for i in range(bitPrecision):
# nbCalls = 1 << i
# for j in range(nbCalls):
# ctrl_bit = i
# oracle.ctrl(ctrl_bit, q)
# iqft(q, 0, bitPrecision)
# for i in range(bitPrecision):
# Measure(q[i])
# q = qalloc(4)
# qpe(q)
# print(q.counts())
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
......@@ -47,6 +47,10 @@ using Handle = std::future<ResultsBuffer>;
// Sync up a Handle
ResultsBuffer sync(Handle &handle);
// Indicate we have an error with the given message.
// This should abort execution
void error(const std::string &msg);
template <typename T> std::vector<T> linspace(T a, T b, size_t N) {
T h = (b - a) / static_cast<T>(N - 1);
std::vector<T> xs(N);
......@@ -63,6 +67,25 @@ inline std::vector<int> range(int N) {
return vec;
}
inline std::vector<int> range(int start, int stop, int step) {
if (step == 0) {
error("step for range must be non-zero.");
}
int i = start;
std::vector<int> vec;
while ((step > 0) ? (i < stop) : (i > stop)) {
vec.push_back(i);
i+=step;
}
return vec;
}
inline std::vector<int> range(int start, int stop) {
return range(start, stop, 1);
}
// Get size() of any types that have size() implemented.
template <typename T> int len(const T &countable) { return countable.size(); }
template <typename T> int len(T &countable) { return countable.size(); }
......@@ -272,8 +295,5 @@ bool get_verbose();
// Set the shots for a given quantum kernel execution
void set_shots(const int shots);
// Indicate we have an error with the given message.
// This should abort execution
void error(const std::string &msg);
} // namespace qcor
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment