Loading handlers/token_collector/pyxasm/pyxasm_visitor.hpp +97 −100 Original line number Diff line number Diff line Loading @@ -36,6 +36,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // Var to keep track of sub-node rewrite: // e.g., traverse down the AST recursively. std::stringstream sub_node_translation; bool is_processing_sub_expr = false; antlrcpp::Any visitAtom_expr( pyxasmParser::Atom_exprContext *context) override { Loading @@ -53,7 +54,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor { '{' (dictorsetmaker)? '}' | NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False'); */ // Only processes these for sub-expressesions, // e.g. re-entries to this function if (is_processing_sub_expr) { if (context->atom() && context->atom()->testlist_comp()) { // Array type expression: std::cout << "Array atom expression: " Loading Loading @@ -92,7 +95,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor { cppStrLiteral.back() = '"'; } sub_node_translation << cppStrLiteral; std::cout << "String expression: " << strNode->getText() << " --> " << cppStrLiteral << "\n"; std::cout << "String expression: " << strNode->getText() << " --> " << cppStrLiteral << "\n"; } return 0; } Loading @@ -112,13 +116,14 @@ class pyxasm_visitor : public pyxasmBaseVisitor { return false; }; // Handle slicing operations (multiple array subscriptions separated by ':') // on a qreg. // Handle slicing operations (multiple array subscriptions separated by // ':') on a qreg. if (context->atom() && xacc::container::contains(bufferNames, context->atom()->getText()) && isSliceOp(context)) { std::cout << "Slice op: " << context->getText() << "\n"; sub_node_translation << context->atom()->getText() << ".extract_range({"; sub_node_translation << context->atom()->getText() << ".extract_range({"; auto subscripts = context->trailer(0)->subscriptlist()->subscript(0)->test(); assert(subscripts.size() > 1); Loading @@ -137,7 +142,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor { for (int i = 0; i < subscriptTerms.size(); ++i) { // Need to cast to prevent compiler errors, // e.g. when using q.size() which returns an int. sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")"; sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")"; if (i != subscriptTerms.size() - 1) { sub_node_translation << ", "; } Loading @@ -151,6 +157,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor { return 0; } return 0; } // Handle kernel::ctrl(...), kernel::adjoint(...) if (!context->trailer().empty() && (context->trailer()[0]->getText() == ".ctrl" || Loading Loading @@ -334,22 +343,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { context->trailer()[0]->arglist()->argument(); ss << inst_name << "("; for (size_t i = 0; i < argList.size(); ++i) { // Find rewrite for arguments sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(argList[i]); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { const auto arg_new_str = sub_node_translation.str(); std::cout << argList[i]->getText() << " --> " << arg_new_str << "\n"; sub_node_translation.str(std::string()); ss << arg_new_str; } else { // Use the arg as is: ss << argList[i]->getText(); } ss << rewriteFunctionArgument(*(argList[i])); if (i != argList.size() - 1) { ss << ", "; } Loading Loading @@ -417,7 +411,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { } else { // Strategy: try to traverse the rhs to see if there is a possible rewrite; // Otherwise, use the text as is. is_processing_sub_expr = true; // clear the sub_node_translation sub_node_translation.str(std::string()); Loading Loading @@ -549,7 +543,10 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // Strategy: try to traverse the argument context to see if there is a // possible rewrite; i.e. it may be another atom_expression that we have a // handler for. Otherwise, use the text as is. // We need this flag to prevent parsing quantum instructions as sub-expressions. // e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be // processed as instructions. is_processing_sub_expr = true; // clear the sub_node_translation sub_node_translation.str(std::string()); Loading handlers/token_collector/pyxasm/tests/PyXASMTokenCollectorTester.cpp +30 −6 Original line number Diff line number Diff line Loading @@ -164,8 +164,8 @@ auto last_qubit = q.tail(); Z::ctrl(parent_kernel, ctrl_qubits, last_qubit); X::ctrl(parent_kernel, q.head(q.size()-1), q.tail()); auto r = q.extract_range(0,bitPrecision); auto slice1 = q.extract_range({0, 3}); auto slice2 = q.extract_range({0, 5, 2}); auto slice1 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(3)}); auto slice2 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } Loading Loading @@ -199,13 +199,37 @@ TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) { R"#(quantum::x(q.head(q.size()-1)); quantum::x(q[0]); quantum::x(q); quantum::x(q.extract_range({0, 2})); quantum::x(q.extract_range({0, 5, 2})); quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)})); quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)})); quantum::mz(q.head(q.size()-1)); quantum::mz(q[0]); quantum::mz(q); quantum::mz(q.extract_range({0, 2})); quantum::mz(q.extract_range({0, 5, 2})); quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)})); quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)})); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkQcorOperators) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( exponent_op = X(0) * Y(1) - Y(0) * X(1) exp_i_theta(q, theta, exponent_op) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(auto exponent_op = X(0)*Y(1)-Y(0)*X(1); quantum::exp(q, theta, exponent_op); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } Loading python/tests/test_jit_simple.py +33 −0 Original line number Diff line number Diff line Loading @@ -179,5 +179,38 @@ class TestKernelJIT(unittest.TestCase): self.assertEqual(comp2.getInstruction(1).name(), "Z") self.assertEqual(comp3.getInstruction(1).name(), "T") def test_instBroadCast(self): set_qpu('qpp', {'shots':1024}) @qjit def broadCastTest(q : qreg): # Simple broadcast X(q) # broadcast by slice Z(q[0:q.size()]) # Even qubits Y(q[0:q.size():2]) q = qalloc(6) comp = broadCastTest.extract_composite(q) counter = 0 for i in range(q.size()): self.assertEqual(comp.getInstruction(counter).name(), "X") self.assertEqual(comp.getInstruction(counter).bits()[0], i) counter += 1 for i in range(q.size()): self.assertEqual(comp.getInstruction(counter).name(), "Z") self.assertEqual(comp.getInstruction(counter).bits()[0], i) counter += 1 for i in range(0, q.size(), 2): self.assertEqual(comp.getInstruction(counter).name(), "Y") self.assertEqual(comp.getInstruction(counter).bits()[0], i) counter += 1 self.assertEqual(comp.nInstructions(), counter) if __name__ == '__main__': unittest.main() No newline at end of file runtime/qrt/qrt.cpp +15 −15 Original line number Diff line number Diff line Loading @@ -225,80 +225,80 @@ void persistBitstring(xacc::AcceleratorBuffer *buffer) { } } void h(qreg &q) { void h(qreg q) { for (int i = 0; i < q.size(); i++) { h(q[i]); } } void x(qreg &q) { void x(qreg q) { for (int i = 0; i < q.size(); i++) { x(q[i]); } } void y(qreg &q) { void y(qreg q) { for (int i = 0; i < q.size(); i++) { y(q[i]); } } void z(qreg &q) { void z(qreg q) { for (int i = 0; i < q.size(); i++) { z(q[i]); } } void t(qreg &q) { void t(qreg q) { for (int i = 0; i < q.size(); i++) { t(q[i]); } } void tdg(qreg &q) { void tdg(qreg q) { for (int i = 0; i < q.size(); i++) { tdg(q[i]); } } void s(qreg &q) { void s(qreg q) { for (int i = 0; i < q.size(); i++) { s(q[i]); } } void sdg(qreg &q) { void sdg(qreg q) { for (int i = 0; i < q.size(); i++) { sdg(q[i]); } } void mz(qreg &q) { void mz(qreg q) { for (int i = 0; i < q.size(); i++) { mz(q[i]); } } void rx(qreg &q, const double theta) { void rx(qreg q, const double theta) { for (int i = 0; i < q.size(); i++) { rx(q[i], theta); } } void ry(qreg &q, const double theta) { void ry(qreg q, const double theta) { for (int i = 0; i < q.size(); i++) { ry(q[i], theta); } } void rz(qreg &q, const double theta) { void rz(qreg q, const double theta) { for (int i = 0; i < q.size(); i++) { rz(q[i], theta); } } // U1(theta) gate void u1(qreg &q, const double theta) { void u1(qreg q, const double theta) { for (int i = 0; i < q.size(); i++) { u1(q[i], theta); } } void u3(qreg &q, const double theta, const double phi, const double lambda) { void u3(qreg q, const double theta, const double phi, const double lambda) { for (int i = 0; i < q.size(); i++) { u3(q[i], theta, phi, lambda); } } void reset(qreg &q) { void reset(qreg q) { for (int i = 0; i < q.size(); i++) { reset(q[i]); } Loading runtime/qrt/qrt.hpp +15 −15 Original line number Diff line number Diff line Loading @@ -125,15 +125,15 @@ void sdg(const qubit &qidx); void reset(const qubit &qidx); // broadcast across qreg void h(qreg &q); void x(qreg &q); void y(qreg &q); void z(qreg &q); void t(qreg &q); void tdg(qreg &q); void s(qreg &q); void sdg(qreg &q); void reset(qreg &qidx); void h(qreg q); void x(qreg q); void y(qreg q); void z(qreg q); void t(qreg q); void tdg(qreg q); void s(qreg q); void sdg(qreg q); void reset(qreg qidx); // Common single-qubit, parameterized instructions void rx(const qubit &qidx, const double theta); Loading @@ -145,17 +145,17 @@ void u3(const qubit &qidx, const double theta, const double phi, const double lambda); // broadcast rotations across qubits void rx(qreg &qidx, const double theta); void ry(qreg &qidx, const double theta); void rz(qreg &qidx, const double theta); void rx(qreg qidx, const double theta); void ry(qreg qidx, const double theta); void rz(qreg qidx, const double theta); // U1(theta) gate void u1(qreg &qidx, const double theta); void u3(qreg &qidx, const double theta, const double phi, void u1(qreg qidx, const double theta); void u3(qreg qidx, const double theta, const double phi, const double lambda); // Measure-Z and broadcast mz bool mz(const qubit &qidx); void mz(qreg &q); void mz(qreg q); // Common two-qubit gates. void cnot(const qubit &src_idx, const qubit &tgt_idx); Loading Loading
handlers/token_collector/pyxasm/pyxasm_visitor.hpp +97 −100 Original line number Diff line number Diff line Loading @@ -36,6 +36,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // Var to keep track of sub-node rewrite: // e.g., traverse down the AST recursively. std::stringstream sub_node_translation; bool is_processing_sub_expr = false; antlrcpp::Any visitAtom_expr( pyxasmParser::Atom_exprContext *context) override { Loading @@ -53,7 +54,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor { '{' (dictorsetmaker)? '}' | NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False'); */ // Only processes these for sub-expressesions, // e.g. re-entries to this function if (is_processing_sub_expr) { if (context->atom() && context->atom()->testlist_comp()) { // Array type expression: std::cout << "Array atom expression: " Loading Loading @@ -92,7 +95,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor { cppStrLiteral.back() = '"'; } sub_node_translation << cppStrLiteral; std::cout << "String expression: " << strNode->getText() << " --> " << cppStrLiteral << "\n"; std::cout << "String expression: " << strNode->getText() << " --> " << cppStrLiteral << "\n"; } return 0; } Loading @@ -112,13 +116,14 @@ class pyxasm_visitor : public pyxasmBaseVisitor { return false; }; // Handle slicing operations (multiple array subscriptions separated by ':') // on a qreg. // Handle slicing operations (multiple array subscriptions separated by // ':') on a qreg. if (context->atom() && xacc::container::contains(bufferNames, context->atom()->getText()) && isSliceOp(context)) { std::cout << "Slice op: " << context->getText() << "\n"; sub_node_translation << context->atom()->getText() << ".extract_range({"; sub_node_translation << context->atom()->getText() << ".extract_range({"; auto subscripts = context->trailer(0)->subscriptlist()->subscript(0)->test(); assert(subscripts.size() > 1); Loading @@ -137,7 +142,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor { for (int i = 0; i < subscriptTerms.size(); ++i) { // Need to cast to prevent compiler errors, // e.g. when using q.size() which returns an int. sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")"; sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")"; if (i != subscriptTerms.size() - 1) { sub_node_translation << ", "; } Loading @@ -151,6 +157,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor { return 0; } return 0; } // Handle kernel::ctrl(...), kernel::adjoint(...) if (!context->trailer().empty() && (context->trailer()[0]->getText() == ".ctrl" || Loading Loading @@ -334,22 +343,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { context->trailer()[0]->arglist()->argument(); ss << inst_name << "("; for (size_t i = 0; i < argList.size(); ++i) { // Find rewrite for arguments sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(argList[i]); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { const auto arg_new_str = sub_node_translation.str(); std::cout << argList[i]->getText() << " --> " << arg_new_str << "\n"; sub_node_translation.str(std::string()); ss << arg_new_str; } else { // Use the arg as is: ss << argList[i]->getText(); } ss << rewriteFunctionArgument(*(argList[i])); if (i != argList.size() - 1) { ss << ", "; } Loading Loading @@ -417,7 +411,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { } else { // Strategy: try to traverse the rhs to see if there is a possible rewrite; // Otherwise, use the text as is. is_processing_sub_expr = true; // clear the sub_node_translation sub_node_translation.str(std::string()); Loading Loading @@ -549,7 +543,10 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // Strategy: try to traverse the argument context to see if there is a // possible rewrite; i.e. it may be another atom_expression that we have a // handler for. Otherwise, use the text as is. // We need this flag to prevent parsing quantum instructions as sub-expressions. // e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be // processed as instructions. is_processing_sub_expr = true; // clear the sub_node_translation sub_node_translation.str(std::string()); Loading
handlers/token_collector/pyxasm/tests/PyXASMTokenCollectorTester.cpp +30 −6 Original line number Diff line number Diff line Loading @@ -164,8 +164,8 @@ auto last_qubit = q.tail(); Z::ctrl(parent_kernel, ctrl_qubits, last_qubit); X::ctrl(parent_kernel, q.head(q.size()-1), q.tail()); auto r = q.extract_range(0,bitPrecision); auto slice1 = q.extract_range({0, 3}); auto slice2 = q.extract_range({0, 5, 2}); auto slice1 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(3)}); auto slice2 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } Loading Loading @@ -199,13 +199,37 @@ TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) { R"#(quantum::x(q.head(q.size()-1)); quantum::x(q[0]); quantum::x(q); quantum::x(q.extract_range({0, 2})); quantum::x(q.extract_range({0, 5, 2})); quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)})); quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)})); quantum::mz(q.head(q.size()-1)); quantum::mz(q[0]); quantum::mz(q); quantum::mz(q.extract_range({0, 2})); quantum::mz(q.extract_range({0, 5, 2})); quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)})); quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)})); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkQcorOperators) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( exponent_op = X(0) * Y(1) - Y(0) * X(1) exp_i_theta(q, theta, exponent_op) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(auto exponent_op = X(0)*Y(1)-Y(0)*X(1); quantum::exp(q, theta, exponent_op); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } Loading
python/tests/test_jit_simple.py +33 −0 Original line number Diff line number Diff line Loading @@ -179,5 +179,38 @@ class TestKernelJIT(unittest.TestCase): self.assertEqual(comp2.getInstruction(1).name(), "Z") self.assertEqual(comp3.getInstruction(1).name(), "T") def test_instBroadCast(self): set_qpu('qpp', {'shots':1024}) @qjit def broadCastTest(q : qreg): # Simple broadcast X(q) # broadcast by slice Z(q[0:q.size()]) # Even qubits Y(q[0:q.size():2]) q = qalloc(6) comp = broadCastTest.extract_composite(q) counter = 0 for i in range(q.size()): self.assertEqual(comp.getInstruction(counter).name(), "X") self.assertEqual(comp.getInstruction(counter).bits()[0], i) counter += 1 for i in range(q.size()): self.assertEqual(comp.getInstruction(counter).name(), "Z") self.assertEqual(comp.getInstruction(counter).bits()[0], i) counter += 1 for i in range(0, q.size(), 2): self.assertEqual(comp.getInstruction(counter).name(), "Y") self.assertEqual(comp.getInstruction(counter).bits()[0], i) counter += 1 self.assertEqual(comp.nInstructions(), counter) if __name__ == '__main__': unittest.main() No newline at end of file
runtime/qrt/qrt.cpp +15 −15 Original line number Diff line number Diff line Loading @@ -225,80 +225,80 @@ void persistBitstring(xacc::AcceleratorBuffer *buffer) { } } void h(qreg &q) { void h(qreg q) { for (int i = 0; i < q.size(); i++) { h(q[i]); } } void x(qreg &q) { void x(qreg q) { for (int i = 0; i < q.size(); i++) { x(q[i]); } } void y(qreg &q) { void y(qreg q) { for (int i = 0; i < q.size(); i++) { y(q[i]); } } void z(qreg &q) { void z(qreg q) { for (int i = 0; i < q.size(); i++) { z(q[i]); } } void t(qreg &q) { void t(qreg q) { for (int i = 0; i < q.size(); i++) { t(q[i]); } } void tdg(qreg &q) { void tdg(qreg q) { for (int i = 0; i < q.size(); i++) { tdg(q[i]); } } void s(qreg &q) { void s(qreg q) { for (int i = 0; i < q.size(); i++) { s(q[i]); } } void sdg(qreg &q) { void sdg(qreg q) { for (int i = 0; i < q.size(); i++) { sdg(q[i]); } } void mz(qreg &q) { void mz(qreg q) { for (int i = 0; i < q.size(); i++) { mz(q[i]); } } void rx(qreg &q, const double theta) { void rx(qreg q, const double theta) { for (int i = 0; i < q.size(); i++) { rx(q[i], theta); } } void ry(qreg &q, const double theta) { void ry(qreg q, const double theta) { for (int i = 0; i < q.size(); i++) { ry(q[i], theta); } } void rz(qreg &q, const double theta) { void rz(qreg q, const double theta) { for (int i = 0; i < q.size(); i++) { rz(q[i], theta); } } // U1(theta) gate void u1(qreg &q, const double theta) { void u1(qreg q, const double theta) { for (int i = 0; i < q.size(); i++) { u1(q[i], theta); } } void u3(qreg &q, const double theta, const double phi, const double lambda) { void u3(qreg q, const double theta, const double phi, const double lambda) { for (int i = 0; i < q.size(); i++) { u3(q[i], theta, phi, lambda); } } void reset(qreg &q) { void reset(qreg q) { for (int i = 0; i < q.size(); i++) { reset(q[i]); } Loading
runtime/qrt/qrt.hpp +15 −15 Original line number Diff line number Diff line Loading @@ -125,15 +125,15 @@ void sdg(const qubit &qidx); void reset(const qubit &qidx); // broadcast across qreg void h(qreg &q); void x(qreg &q); void y(qreg &q); void z(qreg &q); void t(qreg &q); void tdg(qreg &q); void s(qreg &q); void sdg(qreg &q); void reset(qreg &qidx); void h(qreg q); void x(qreg q); void y(qreg q); void z(qreg q); void t(qreg q); void tdg(qreg q); void s(qreg q); void sdg(qreg q); void reset(qreg qidx); // Common single-qubit, parameterized instructions void rx(const qubit &qidx, const double theta); Loading @@ -145,17 +145,17 @@ void u3(const qubit &qidx, const double theta, const double phi, const double lambda); // broadcast rotations across qubits void rx(qreg &qidx, const double theta); void ry(qreg &qidx, const double theta); void rz(qreg &qidx, const double theta); void rx(qreg qidx, const double theta); void ry(qreg qidx, const double theta); void rz(qreg qidx, const double theta); // U1(theta) gate void u1(qreg &qidx, const double theta); void u3(qreg &qidx, const double theta, const double phi, void u1(qreg qidx, const double theta); void u3(qreg qidx, const double theta, const double phi, const double lambda); // Measure-Z and broadcast mz bool mz(const qubit &qidx); void mz(qreg &q); void mz(qreg q); // Common two-qubit gates. void cnot(const qubit &src_idx, const qubit &tgt_idx); Loading