Commit 9ad2d539 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

qrt broadcast API to take qreg by value (copies)



This allows passing a qreg slice directly to broadcast instructions.

Refactor the PyXASM token collector impl.

Signed-off-by: Nguyen, Thien Minh's avatarThien Nguyen <nguyentm@ornl.gov>
parent a6938257
......@@ -36,6 +36,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
// Var to keep track of sub-node rewrite:
// e.g., traverse down the AST recursively.
std::stringstream sub_node_translation;
bool is_processing_sub_expr = false;
antlrcpp::Any visitAtom_expr(
pyxasmParser::Atom_exprContext *context) override {
......@@ -53,101 +54,109 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
'{' (dictorsetmaker)? '}' |
NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False');
*/
if (context->atom() && context->atom()->testlist_comp()) {
// Array type expression:
std::cout << "Array atom expression: "
<< context->atom()->testlist_comp()->getText() << "\n";
// Use braces
sub_node_translation << "{";
bool firstElProcessed = false;
for (auto &testNode : context->atom()->testlist_comp()->test()) {
std::cout << "Array elem: " << testNode->getText() << "\n";
// Add comma if needed (there is a previous element)
if (firstElProcessed) {
sub_node_translation << ", ";
// Only processes these for sub-expressesions,
// e.g. re-entries to this function
if (is_processing_sub_expr) {
if (context->atom() && context->atom()->testlist_comp()) {
// Array type expression:
std::cout << "Array atom expression: "
<< context->atom()->testlist_comp()->getText() << "\n";
// Use braces
sub_node_translation << "{";
bool firstElProcessed = false;
for (auto &testNode : context->atom()->testlist_comp()->test()) {
std::cout << "Array elem: " << testNode->getText() << "\n";
// Add comma if needed (there is a previous element)
if (firstElProcessed) {
sub_node_translation << ", ";
}
sub_node_translation << testNode->getText();
firstElProcessed = true;
}
sub_node_translation << testNode->getText();
firstElProcessed = true;
sub_node_translation << "}";
return 0;
}
sub_node_translation << "}";
return 0;
}
if (context->atom() && context->atom()->dictorsetmaker()) {
// Dict:
std::cout << "Dict atom expression: "
<< context->atom()->dictorsetmaker()->getText() << "\n";
// TODO:
return 0;
}
if (context->atom() && context->atom()->dictorsetmaker()) {
// Dict:
std::cout << "Dict atom expression: "
<< context->atom()->dictorsetmaker()->getText() << "\n";
// TODO:
return 0;
}
if (context->atom() && !context->atom()->STRING().empty()) {
// Strings:
for (auto &strNode : context->atom()->STRING()) {
std::string cppStrLiteral = strNode->getText();
// Handle Python single-quotes
if (cppStrLiteral.front() == '\'' && cppStrLiteral.back() == '\'') {
cppStrLiteral.front() = '"';
cppStrLiteral.back() = '"';
if (context->atom() && !context->atom()->STRING().empty()) {
// Strings:
for (auto &strNode : context->atom()->STRING()) {
std::string cppStrLiteral = strNode->getText();
// Handle Python single-quotes
if (cppStrLiteral.front() == '\'' && cppStrLiteral.back() == '\'') {
cppStrLiteral.front() = '"';
cppStrLiteral.back() = '"';
}
sub_node_translation << cppStrLiteral;
std::cout << "String expression: " << strNode->getText() << " --> "
<< cppStrLiteral << "\n";
}
sub_node_translation << cppStrLiteral;
std::cout << "String expression: " << strNode->getText() << " --> " << cppStrLiteral << "\n";
return 0;
}
return 0;
}
const auto isSliceOp =
[](pyxasmParser::Atom_exprContext *atom_expr_context) -> bool {
if (atom_expr_context->trailer().size() == 1) {
auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist();
if (subscriptlist && subscriptlist->subscript().size() == 1) {
auto subscript = subscriptlist->subscript(0);
const auto nbTestTerms = subscript->test().size();
// Multiple test terms (separated by ':')
return (nbTestTerms > 1);
const auto isSliceOp =
[](pyxasmParser::Atom_exprContext *atom_expr_context) -> bool {
if (atom_expr_context->trailer().size() == 1) {
auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist();
if (subscriptlist && subscriptlist->subscript().size() == 1) {
auto subscript = subscriptlist->subscript(0);
const auto nbTestTerms = subscript->test().size();
// Multiple test terms (separated by ':')
return (nbTestTerms > 1);
}
}
}
return false;
};
// Handle slicing operations (multiple array subscriptions separated by ':')
// on a qreg.
if (context->atom() &&
xacc::container::contains(bufferNames, context->atom()->getText()) &&
isSliceOp(context)) {
std::cout << "Slice op: " << context->getText() << "\n";
sub_node_translation << context->atom()->getText() << ".extract_range({";
auto subscripts =
context->trailer(0)->subscriptlist()->subscript(0)->test();
assert(subscripts.size() > 1);
std::vector<std::string> subscriptTerms;
for (auto &test : subscripts) {
subscriptTerms.emplace_back(test->getText());
}
return false;
};
auto sliceOp =
context->trailer(0)->subscriptlist()->subscript(0)->sliceop();
if (sliceOp && sliceOp->test()) {
subscriptTerms.emplace_back(sliceOp->test()->getText());
}
assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3);
for (int i = 0; i < subscriptTerms.size(); ++i) {
// Need to cast to prevent compiler errors,
// e.g. when using q.size() which returns an int.
sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")";
if (i != subscriptTerms.size() - 1) {
sub_node_translation << ", ";
// Handle slicing operations (multiple array subscriptions separated by
// ':') on a qreg.
if (context->atom() &&
xacc::container::contains(bufferNames, context->atom()->getText()) &&
isSliceOp(context)) {
std::cout << "Slice op: " << context->getText() << "\n";
sub_node_translation << context->atom()->getText()
<< ".extract_range({";
auto subscripts =
context->trailer(0)->subscriptlist()->subscript(0)->test();
assert(subscripts.size() > 1);
std::vector<std::string> subscriptTerms;
for (auto &test : subscripts) {
subscriptTerms.emplace_back(test->getText());
}
}
sub_node_translation << "})";
auto sliceOp =
context->trailer(0)->subscriptlist()->subscript(0)->sliceop();
if (sliceOp && sliceOp->test()) {
subscriptTerms.emplace_back(sliceOp->test()->getText());
}
assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3);
for (int i = 0; i < subscriptTerms.size(); ++i) {
// Need to cast to prevent compiler errors,
// e.g. when using q.size() which returns an int.
sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i]
<< ")";
if (i != subscriptTerms.size() - 1) {
sub_node_translation << ", ";
}
}
sub_node_translation << "})";
// convert the slice op to initializer list:
std::cout << "Slice Convert: " << context->getText() << " --> "
<< sub_node_translation.str() << "\n";
return 0;
}
// convert the slice op to initializer list:
std::cout << "Slice Convert: " << context->getText() << " --> "
<< sub_node_translation.str() << "\n";
return 0;
}
......@@ -334,22 +343,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
context->trailer()[0]->arglist()->argument();
ss << inst_name << "(";
for (size_t i = 0; i < argList.size(); ++i) {
// Find rewrite for arguments
sub_node_translation.str(std::string());
// visit arg sub-node:
visitChildren(argList[i]);
// Check if there is a rewrite:
if (!sub_node_translation.str().empty()) {
const auto arg_new_str = sub_node_translation.str();
std::cout << argList[i]->getText() << " --> " << arg_new_str << "\n";
sub_node_translation.str(std::string());
ss << arg_new_str;
}
else {
// Use the arg as is:
ss << argList[i]->getText();
}
ss << rewriteFunctionArgument(*(argList[i]));
if (i != argList.size() - 1) {
ss << ", ";
}
......@@ -417,7 +411,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
} else {
// Strategy: try to traverse the rhs to see if there is a possible rewrite;
// Otherwise, use the text as is.
is_processing_sub_expr = true;
// clear the sub_node_translation
sub_node_translation.str(std::string());
......@@ -549,7 +543,10 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
// Strategy: try to traverse the argument context to see if there is a
// possible rewrite; i.e. it may be another atom_expression that we have a
// handler for. Otherwise, use the text as is.
// We need this flag to prevent parsing quantum instructions as sub-expressions.
// e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be
// processed as instructions.
is_processing_sub_expr = true;
// clear the sub_node_translation
sub_node_translation.str(std::string());
......
......@@ -164,8 +164,8 @@ auto last_qubit = q.tail();
Z::ctrl(parent_kernel, ctrl_qubits, last_qubit);
X::ctrl(parent_kernel, q.head(q.size()-1), q.tail());
auto r = q.extract_range(0,bitPrecision);
auto slice1 = q.extract_range({0, 3});
auto slice2 = q.extract_range({0, 5, 2});
auto slice1 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(3)});
auto slice2 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)});
)#";
EXPECT_EQ(expectedCodeGen, ss.str());
}
......@@ -199,13 +199,37 @@ TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) {
R"#(quantum::x(q.head(q.size()-1));
quantum::x(q[0]);
quantum::x(q);
quantum::x(q.extract_range({0, 2}));
quantum::x(q.extract_range({0, 5, 2}));
quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)}));
quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}));
quantum::mz(q.head(q.size()-1));
quantum::mz(q[0]);
quantum::mz(q);
quantum::mz(q.extract_range({0, 2}));
quantum::mz(q.extract_range({0, 5, 2}));
quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)}));
quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}));
)#";
EXPECT_EQ(expectedCodeGen, ss.str());
}
TEST(PyXASMTokenCollectorTester, checkQcorOperators) {
LexerHelper helper;
auto [tokens, PP] = helper.Lex(R"(
exponent_op = X(0) * Y(1) - Y(0) * X(1)
exp_i_theta(q, theta, exponent_op)
)");
clang::CachedTokens cached;
for (auto &t : tokens) {
cached.push_back(t);
}
std::stringstream ss;
auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
std::cout << "heres the test\n";
std::cout << ss.str() << "\n";
const std::string expectedCodeGen =
R"#(auto exponent_op = X(0)*Y(1)-Y(0)*X(1);
quantum::exp(q, theta, exponent_op);
)#";
EXPECT_EQ(expectedCodeGen, ss.str());
}
......
......@@ -179,5 +179,38 @@ class TestKernelJIT(unittest.TestCase):
self.assertEqual(comp2.getInstruction(1).name(), "Z")
self.assertEqual(comp3.getInstruction(1).name(), "T")
def test_instBroadCast(self):
set_qpu('qpp', {'shots':1024})
@qjit
def broadCastTest(q : qreg):
# Simple broadcast
X(q)
# broadcast by slice
Z(q[0:q.size()])
# Even qubits
Y(q[0:q.size():2])
q = qalloc(6)
comp = broadCastTest.extract_composite(q)
counter = 0
for i in range(q.size()):
self.assertEqual(comp.getInstruction(counter).name(), "X")
self.assertEqual(comp.getInstruction(counter).bits()[0], i)
counter += 1
for i in range(q.size()):
self.assertEqual(comp.getInstruction(counter).name(), "Z")
self.assertEqual(comp.getInstruction(counter).bits()[0], i)
counter += 1
for i in range(0, q.size(), 2):
self.assertEqual(comp.getInstruction(counter).name(), "Y")
self.assertEqual(comp.getInstruction(counter).bits()[0], i)
counter += 1
self.assertEqual(comp.nInstructions(), counter)
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
......@@ -225,80 +225,80 @@ void persistBitstring(xacc::AcceleratorBuffer *buffer) {
}
}
void h(qreg &q) {
void h(qreg q) {
for (int i = 0; i < q.size(); i++) {
h(q[i]);
}
}
void x(qreg &q) {
void x(qreg q) {
for (int i = 0; i < q.size(); i++) {
x(q[i]);
}
}
void y(qreg &q) {
void y(qreg q) {
for (int i = 0; i < q.size(); i++) {
y(q[i]);
}
}
void z(qreg &q) {
void z(qreg q) {
for (int i = 0; i < q.size(); i++) {
z(q[i]);
}
}
void t(qreg &q) {
void t(qreg q) {
for (int i = 0; i < q.size(); i++) {
t(q[i]);
}
}
void tdg(qreg &q) {
void tdg(qreg q) {
for (int i = 0; i < q.size(); i++) {
tdg(q[i]);
}
}
void s(qreg &q) {
void s(qreg q) {
for (int i = 0; i < q.size(); i++) {
s(q[i]);
}
}
void sdg(qreg &q) {
void sdg(qreg q) {
for (int i = 0; i < q.size(); i++) {
sdg(q[i]);
}
}
void mz(qreg &q) {
void mz(qreg q) {
for (int i = 0; i < q.size(); i++) {
mz(q[i]);
}
}
void rx(qreg &q, const double theta) {
void rx(qreg q, const double theta) {
for (int i = 0; i < q.size(); i++) {
rx(q[i], theta);
}
}
void ry(qreg &q, const double theta) {
void ry(qreg q, const double theta) {
for (int i = 0; i < q.size(); i++) {
ry(q[i], theta);
}
}
void rz(qreg &q, const double theta) {
void rz(qreg q, const double theta) {
for (int i = 0; i < q.size(); i++) {
rz(q[i], theta);
}
}
// U1(theta) gate
void u1(qreg &q, const double theta) {
void u1(qreg q, const double theta) {
for (int i = 0; i < q.size(); i++) {
u1(q[i], theta);
}
}
void u3(qreg &q, const double theta, const double phi, const double lambda) {
void u3(qreg q, const double theta, const double phi, const double lambda) {
for (int i = 0; i < q.size(); i++) {
u3(q[i], theta, phi, lambda);
}
}
void reset(qreg &q) {
void reset(qreg q) {
for (int i = 0; i < q.size(); i++) {
reset(q[i]);
}
......
......@@ -125,15 +125,15 @@ void sdg(const qubit &qidx);
void reset(const qubit &qidx);
// broadcast across qreg
void h(qreg &q);
void x(qreg &q);
void y(qreg &q);
void z(qreg &q);
void t(qreg &q);
void tdg(qreg &q);
void s(qreg &q);
void sdg(qreg &q);
void reset(qreg &qidx);
void h(qreg q);
void x(qreg q);
void y(qreg q);
void z(qreg q);
void t(qreg q);
void tdg(qreg q);
void s(qreg q);
void sdg(qreg q);
void reset(qreg qidx);
// Common single-qubit, parameterized instructions
void rx(const qubit &qidx, const double theta);
......@@ -145,17 +145,17 @@ void u3(const qubit &qidx, const double theta, const double phi,
const double lambda);
// broadcast rotations across qubits
void rx(qreg &qidx, const double theta);
void ry(qreg &qidx, const double theta);
void rz(qreg &qidx, const double theta);
void rx(qreg qidx, const double theta);
void ry(qreg qidx, const double theta);
void rz(qreg qidx, const double theta);
// U1(theta) gate
void u1(qreg &qidx, const double theta);
void u3(qreg &qidx, const double theta, const double phi,
void u1(qreg qidx, const double theta);
void u3(qreg qidx, const double theta, const double phi,
const double lambda);
// Measure-Z and broadcast mz
bool mz(const qubit &qidx);
void mz(qreg &q);
void mz(qreg q);
// Common two-qubit gates.
void cnot(const qubit &src_idx, const qubit &tgt_idx);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment