Loading handlers/token_collector/pyxasm/pyxasm_visitor.hpp +32 −24 Original line number Diff line number Diff line Loading @@ -135,7 +135,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor { assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3); for (int i = 0; i < subscriptTerms.size(); ++i) { sub_node_translation << subscriptTerms[i]; // Need to cast to prevent compiler errors, // e.g. when using q.size() which returns an int. sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")"; if (i != subscriptTerms.size() - 1) { sub_node_translation << ", "; } Loading Loading @@ -167,28 +169,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ss << context->atom()->getText() << "::" << methodName << "(parent_kernel"; for (int i = 0; i < arg_list->argument().size(); i++) { // Strategy: // Traverse down the tree to see if the there is a potential translation: // i.e. it will populate sub_node_translation stream. // Otherwise, output the argument *as-is* // clear the sub_node_translation sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(arg_list->argument(i)); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { const auto arg_new_str = sub_node_translation.str(); std::cout << arg_list->argument(i)->getText() << " --> " << arg_new_str << "\n"; sub_node_translation.str(std::string()); ss << ", " << arg_new_str; } else { // Use the arg as is: ss << ", " << arg_list->argument(i)->getText(); } ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i))); } ss << ");\n"; Loading Loading @@ -231,7 +212,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { std::vector<std::string> buffer_names; for (int i = 0; i < required_bits; i++) { auto bit_expr = context->trailer()[0]->arglist()->argument()[i]; auto bit_expr_str = bit_expr->getText(); auto bit_expr_str = rewriteFunctionArgument(*bit_expr); auto found_bracket = bit_expr_str.find_first_of("["); if (found_bracket != std::string::npos) { Loading Loading @@ -556,4 +537,31 @@ class pyxasm_visitor : public pyxasmBaseVisitor { return in_expr; } } // A helper to rewrite function argument by traversing the node to see // if there is a potential rewrite. // Use case: inline expressions // e.g. X(q[0:3]) // slicing of the qreg 'q' then call the broadcast X op. // i.e., we need to rewrite the arg to q.extract_range({0, 3}). std::string rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) { // Strategy: try to traverse the argument context to see if there is a // possible rewrite; i.e. it may be another atom_expression that we have a // handler for. Otherwise, use the text as is. // clear the sub_node_translation sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(&in_argContext); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { // Update RHS return sub_node_translation.str(); } // Returns the string as is return in_argContext.getText(); } }; No newline at end of file handlers/token_collector/pyxasm/tests/PyXASMTokenCollectorTester.cpp +40 −0 Original line number Diff line number Diff line Loading @@ -170,6 +170,46 @@ auto slice2 = q.extract_range({0, 5, 2}); EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( X(q.head(q.size()-1)) X(q[0]) X(q) X(q[0:2]) X(q[0:5:2]) Measure(q.head(q.size()-1)) Measure(q[0]) Measure(q) Measure(q[0:2]) Measure(q[0:5:2]) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(quantum::x(q.head(q.size()-1)); quantum::x(q[0]); quantum::x(q); quantum::x(q.extract_range({0, 2})); quantum::x(q.extract_range({0, 5, 2})); quantum::mz(q.head(q.size()-1)); quantum::mz(q[0]); quantum::mz(q); quantum::mz(q.extract_range({0, 2})); quantum::mz(q.extract_range({0, 5, 2})); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } int main(int argc, char **argv) { std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR); std::string qcor_root = std::string(QCOR_INSTALL_DIR); Loading python/tests/test_jit_multi_ctrl.py +86 −0 Original line number Diff line number Diff line Loading @@ -33,5 +33,91 @@ class TestKernelJIT(unittest.TestCase): # q0: 1 --> 0 self.assertTrue('0111' in counts) def test_qreg_head_tail(self): set_qpu('qpp', {'shots':1024}) @qjit def test_cccx_qreg(q : qreg): # Broadcast X(q) # 3 control bits ctrl_qubits = q.tail(q.size() - 1) first_qubit = q.head() X.ctrl(ctrl_qubits, first_qubit) # # Broadcast Measure(q) q = qalloc(4) comp = test_cccx_qreg.extract_composite(q) print(comp) # Run experiment test_cccx_qreg(q) # Print the results q.print() counts = q.counts() print(counts) self.assertEqual(len(counts), 1) # q0: 1 --> 0 self.assertTrue('0111' in counts) def test_qreg_slicing(self): set_qpu('qpp', {'shots':1024}) @qjit def test_cccx_qreg_slice(q : qreg): # Broadcast X(q) # 3 control bits: # q[0], q[1], q[2] ctrl_qubits = q[0:3] last_qubit = q.tail() X.ctrl(ctrl_qubits, last_qubit) # Broadcast Measure(q) q = qalloc(4) comp = test_cccx_qreg_slice.extract_composite(q) print(comp) # Run experiment test_cccx_qreg_slice(q) # Print the results q.print() counts = q.counts() print(counts) self.assertEqual(len(counts), 1) # q3: 1 --> 0 self.assertTrue('1110' in counts) def test_qreg_slicing_inline(self): set_qpu('qpp', {'shots':1024}) @qjit def test_cccx_qreg_slice_inline(q : qreg): # Broadcast via a slice X(q) # Control with slicing inline X.ctrl(q[0:3], q.tail()) # Broadcast Measure(q) q = qalloc(4) comp = test_cccx_qreg_slice_inline.extract_composite(q) print(comp) # Run experiment test_cccx_qreg_slice_inline(q) # Print the results q.print() counts = q.counts() print(counts) self.assertEqual(len(counts), 1) # q3: 1 --> 0 self.assertTrue('1110' in counts) if __name__ == '__main__': unittest.main() No newline at end of file Loading
handlers/token_collector/pyxasm/pyxasm_visitor.hpp +32 −24 Original line number Diff line number Diff line Loading @@ -135,7 +135,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor { assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3); for (int i = 0; i < subscriptTerms.size(); ++i) { sub_node_translation << subscriptTerms[i]; // Need to cast to prevent compiler errors, // e.g. when using q.size() which returns an int. sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")"; if (i != subscriptTerms.size() - 1) { sub_node_translation << ", "; } Loading Loading @@ -167,28 +169,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ss << context->atom()->getText() << "::" << methodName << "(parent_kernel"; for (int i = 0; i < arg_list->argument().size(); i++) { // Strategy: // Traverse down the tree to see if the there is a potential translation: // i.e. it will populate sub_node_translation stream. // Otherwise, output the argument *as-is* // clear the sub_node_translation sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(arg_list->argument(i)); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { const auto arg_new_str = sub_node_translation.str(); std::cout << arg_list->argument(i)->getText() << " --> " << arg_new_str << "\n"; sub_node_translation.str(std::string()); ss << ", " << arg_new_str; } else { // Use the arg as is: ss << ", " << arg_list->argument(i)->getText(); } ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i))); } ss << ");\n"; Loading Loading @@ -231,7 +212,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { std::vector<std::string> buffer_names; for (int i = 0; i < required_bits; i++) { auto bit_expr = context->trailer()[0]->arglist()->argument()[i]; auto bit_expr_str = bit_expr->getText(); auto bit_expr_str = rewriteFunctionArgument(*bit_expr); auto found_bracket = bit_expr_str.find_first_of("["); if (found_bracket != std::string::npos) { Loading Loading @@ -556,4 +537,31 @@ class pyxasm_visitor : public pyxasmBaseVisitor { return in_expr; } } // A helper to rewrite function argument by traversing the node to see // if there is a potential rewrite. // Use case: inline expressions // e.g. X(q[0:3]) // slicing of the qreg 'q' then call the broadcast X op. // i.e., we need to rewrite the arg to q.extract_range({0, 3}). std::string rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) { // Strategy: try to traverse the argument context to see if there is a // possible rewrite; i.e. it may be another atom_expression that we have a // handler for. Otherwise, use the text as is. // clear the sub_node_translation sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(&in_argContext); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { // Update RHS return sub_node_translation.str(); } // Returns the string as is return in_argContext.getText(); } }; No newline at end of file
handlers/token_collector/pyxasm/tests/PyXASMTokenCollectorTester.cpp +40 −0 Original line number Diff line number Diff line Loading @@ -170,6 +170,46 @@ auto slice2 = q.extract_range({0, 5, 2}); EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( X(q.head(q.size()-1)) X(q[0]) X(q) X(q[0:2]) X(q[0:5:2]) Measure(q.head(q.size()-1)) Measure(q[0]) Measure(q) Measure(q[0:2]) Measure(q[0:5:2]) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(quantum::x(q.head(q.size()-1)); quantum::x(q[0]); quantum::x(q); quantum::x(q.extract_range({0, 2})); quantum::x(q.extract_range({0, 5, 2})); quantum::mz(q.head(q.size()-1)); quantum::mz(q[0]); quantum::mz(q); quantum::mz(q.extract_range({0, 2})); quantum::mz(q.extract_range({0, 5, 2})); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } int main(int argc, char **argv) { std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR); std::string qcor_root = std::string(QCOR_INSTALL_DIR); Loading
python/tests/test_jit_multi_ctrl.py +86 −0 Original line number Diff line number Diff line Loading @@ -33,5 +33,91 @@ class TestKernelJIT(unittest.TestCase): # q0: 1 --> 0 self.assertTrue('0111' in counts) def test_qreg_head_tail(self): set_qpu('qpp', {'shots':1024}) @qjit def test_cccx_qreg(q : qreg): # Broadcast X(q) # 3 control bits ctrl_qubits = q.tail(q.size() - 1) first_qubit = q.head() X.ctrl(ctrl_qubits, first_qubit) # # Broadcast Measure(q) q = qalloc(4) comp = test_cccx_qreg.extract_composite(q) print(comp) # Run experiment test_cccx_qreg(q) # Print the results q.print() counts = q.counts() print(counts) self.assertEqual(len(counts), 1) # q0: 1 --> 0 self.assertTrue('0111' in counts) def test_qreg_slicing(self): set_qpu('qpp', {'shots':1024}) @qjit def test_cccx_qreg_slice(q : qreg): # Broadcast X(q) # 3 control bits: # q[0], q[1], q[2] ctrl_qubits = q[0:3] last_qubit = q.tail() X.ctrl(ctrl_qubits, last_qubit) # Broadcast Measure(q) q = qalloc(4) comp = test_cccx_qreg_slice.extract_composite(q) print(comp) # Run experiment test_cccx_qreg_slice(q) # Print the results q.print() counts = q.counts() print(counts) self.assertEqual(len(counts), 1) # q3: 1 --> 0 self.assertTrue('1110' in counts) def test_qreg_slicing_inline(self): set_qpu('qpp', {'shots':1024}) @qjit def test_cccx_qreg_slice_inline(q : qreg): # Broadcast via a slice X(q) # Control with slicing inline X.ctrl(q[0:3], q.tail()) # Broadcast Measure(q) q = qalloc(4) comp = test_cccx_qreg_slice_inline.extract_composite(q) print(comp) # Run experiment test_cccx_qreg_slice_inline(q) # Print the results q.print() counts = q.counts() print(counts) self.assertEqual(len(counts), 1) # q3: 1 --> 0 self.assertTrue('1110' in counts) if __name__ == '__main__': unittest.main() No newline at end of file