Loading handlers/token_collector/pyxasm/pyxasm_visitor.hpp +9 −1 Original line number Diff line number Diff line Loading @@ -178,7 +178,15 @@ class pyxasm_visitor : public pyxasmBaseVisitor { std::stringstream ss; // Remove the first '.' character const std::string methodName = context->trailer()[0]->getText().substr(1); ss << context->atom()->getText() << "::" << methodName // If this is a *variable*, then using '.' for control/adjoint. // Otherwise, use '::' (global scope kernel names) const std::string separator = (xacc::container::contains(declared_var_names, context->atom()->getText())) ? "." : "::"; ss << context->atom()->getText() << separator << methodName << "(parent_kernel"; for (int i = 0; i < arg_list->argument().size(); i++) { ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i))); Loading handlers/token_collector/pyxasm/tests/PyXASMTokenCollectorTester.cpp +27 −0 Original line number Diff line number Diff line Loading @@ -259,6 +259,33 @@ auto index = std::pow(2, n); EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkKernelSignature) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( # fake local var creation # (should be from the function args if compiling full source.) callable = createCallable(a,b,c) callable.ctrl([q[1], q[2]], q[0]) callable.adjoint(q) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string code_gen_str = ss.str(); // Rewrite to '.' EXPECT_TRUE(code_gen_str.find("callable.ctrl(parent_kernel, {q[1], q[2]}, q[0]);") != std::string::npos); EXPECT_TRUE(code_gen_str.find("callable.adjoint(parent_kernel, q);") != std::string::npos); } int main(int argc, char **argv) { std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR); Loading python/tests/test_jit_kernel_signature.py +33 −0 Original line number Diff line number Diff line Loading @@ -38,5 +38,38 @@ class TestKernelJIT(unittest.TestCase): self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), i + 1.0) counter+=1 def test_kernel_signature_ctrl_adj(self): set_qpu('qpp', {'shots':1024}) @qjit def test_kernel1(q: qreg, call_var1: KernelSignature(qreg, int, float), call_var2: KernelSignature(qubit)): call_var1.adjoint(q, 0, 1.0) call_var1.adjoint(q, 1, 2.0) call_var2.ctrl(q[1], q[0]) # These kernels are unknown to test_kernel @qjit def rz_kernel(q: qreg, idx: int, theta: float): Rz(q[idx], theta) @qjit def x_kernel(q: qubit): X(q) q = qalloc(2) comp = test_kernel1.extract_composite(q, rz_kernel, x_kernel) print(comp) self.assertEqual(comp.nInstructions(), 3) counter = 0 for i in range(2): self.assertEqual(comp.getInstruction(counter).name(), "Rz") # Minus due to adjoint self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), -(i + 1.0)) counter+=1 self.assertEqual(comp.getInstruction(2).name(), "CNOT") self.assertEqual(comp.getInstruction(2).bits()[0], 1) self.assertEqual(comp.getInstruction(2).bits()[1], 0) if __name__ == '__main__': unittest.main() No newline at end of file Loading
handlers/token_collector/pyxasm/pyxasm_visitor.hpp +9 −1 Original line number Diff line number Diff line Loading @@ -178,7 +178,15 @@ class pyxasm_visitor : public pyxasmBaseVisitor { std::stringstream ss; // Remove the first '.' character const std::string methodName = context->trailer()[0]->getText().substr(1); ss << context->atom()->getText() << "::" << methodName // If this is a *variable*, then using '.' for control/adjoint. // Otherwise, use '::' (global scope kernel names) const std::string separator = (xacc::container::contains(declared_var_names, context->atom()->getText())) ? "." : "::"; ss << context->atom()->getText() << separator << methodName << "(parent_kernel"; for (int i = 0; i < arg_list->argument().size(); i++) { ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i))); Loading
handlers/token_collector/pyxasm/tests/PyXASMTokenCollectorTester.cpp +27 −0 Original line number Diff line number Diff line Loading @@ -259,6 +259,33 @@ auto index = std::pow(2, n); EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkKernelSignature) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( # fake local var creation # (should be from the function args if compiling full source.) callable = createCallable(a,b,c) callable.ctrl([q[1], q[2]], q[0]) callable.adjoint(q) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string code_gen_str = ss.str(); // Rewrite to '.' EXPECT_TRUE(code_gen_str.find("callable.ctrl(parent_kernel, {q[1], q[2]}, q[0]);") != std::string::npos); EXPECT_TRUE(code_gen_str.find("callable.adjoint(parent_kernel, q);") != std::string::npos); } int main(int argc, char **argv) { std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR); Loading
python/tests/test_jit_kernel_signature.py +33 −0 Original line number Diff line number Diff line Loading @@ -38,5 +38,38 @@ class TestKernelJIT(unittest.TestCase): self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), i + 1.0) counter+=1 def test_kernel_signature_ctrl_adj(self): set_qpu('qpp', {'shots':1024}) @qjit def test_kernel1(q: qreg, call_var1: KernelSignature(qreg, int, float), call_var2: KernelSignature(qubit)): call_var1.adjoint(q, 0, 1.0) call_var1.adjoint(q, 1, 2.0) call_var2.ctrl(q[1], q[0]) # These kernels are unknown to test_kernel @qjit def rz_kernel(q: qreg, idx: int, theta: float): Rz(q[idx], theta) @qjit def x_kernel(q: qubit): X(q) q = qalloc(2) comp = test_kernel1.extract_composite(q, rz_kernel, x_kernel) print(comp) self.assertEqual(comp.nInstructions(), 3) counter = 0 for i in range(2): self.assertEqual(comp.getInstruction(counter).name(), "Rz") # Minus due to adjoint self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), -(i + 1.0)) counter+=1 self.assertEqual(comp.getInstruction(2).name(), "CNOT") self.assertEqual(comp.getInstruction(2).bits()[0], 1) self.assertEqual(comp.getInstruction(2).bits()[1], 0) if __name__ == '__main__': unittest.main() No newline at end of file