Loading handlers/qcor_syntax_handler.cpp +19 −0 Original line number Diff line number Diff line Loading @@ -399,6 +399,25 @@ void QCORSyntaxHandler::GetReplacement( // We just pass this copied var to the ctor // where it expects a reference type. arg_ctor_list.emplace_back(new_var_name); } else if (program_arg_types[i].rfind("KernelSignature", 0) == 0) { // This is a KernelSignature argument. // The one in HetMap is the function pointer represented as a hex string. const std::string new_var_name = "__temp_kernel_ptr_var__" + std::to_string(var_counter++); // Retrieve the function pointer from the HetMap // ref_type_copy_decl_ss << "std::cout << args.getString(\"" // << program_parameters[i] << "\").c_str() << std::endl;\n"; ref_type_copy_decl_ss << "void* " << new_var_name << " = " << "(void *) strtoull(args.getString(\"" << program_parameters[i] << "\").c_str(), nullptr, 16);\n"; // ref_type_copy_decl_ss << "std::cout << " << new_var_name << " << std::endl;\n"; // Construct the KernelSignature const std::string kernel_signature_var_name = "__temp_kernel_signature_var__" + std::to_string(var_counter++); ref_type_copy_decl_ss << program_arg_types[i] << " " << kernel_signature_var_name << "(" << new_var_name << ");\n"; arg_ctor_list.emplace_back(kernel_signature_var_name); } else { // Otherwise, just unpack the arg inline in the ctor call. std::stringstream ss; Loading handlers/token_collector/pyxasm/pyxasm_visitor.hpp +194 −26 Original line number Diff line number Diff line Loading @@ -33,9 +33,136 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // New var declared (auto type) after visiting this node. std::string new_var; bool in_for_loop = false; // Var to keep track of sub-node rewrite: // e.g., traverse down the AST recursively. std::stringstream sub_node_translation; bool is_processing_sub_expr = false; antlrcpp::Any visitAtom_expr( pyxasmParser::Atom_exprContext *context) override { // std::cout << "Atom_exprContext: " << context->getText() << "\n"; // Strategy: // At the top level, we analyze the trailer to determine the // list of function call arguments. // Then, traverse down the arg. node to see if there is a potential rewrite rules // e.g. for arrays (as testlist_comp nodes) // Otherwise, just get the argument text as is. /* atom_expr: (AWAIT)? atom trailer*; atom: ('(' (yield_expr|testlist_comp)? ')' | '[' (testlist_comp)? ']' | '{' (dictorsetmaker)? '}' | NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False'); */ // Only processes these for sub-expressesions, // e.g. re-entries to this function if (is_processing_sub_expr) { if (context->atom() && context->atom()->OPEN_BRACK() && context->atom()->CLOSE_BRACK() && context->atom()->testlist_comp()) { // Array type expression: // std::cout << "Array atom expression: " // << context->atom()->testlist_comp()->getText() << "\n"; // Use braces sub_node_translation << "{"; bool firstElProcessed = false; for (auto &testNode : context->atom()->testlist_comp()->test()) { // std::cout << "Array elem: " << testNode->getText() << "\n"; // Add comma if needed (there is a previous element) if (firstElProcessed) { sub_node_translation << ", "; } sub_node_translation << testNode->getText(); firstElProcessed = true; } sub_node_translation << "}"; return 0; } // We don't have a re-write rule for this one (py::dict) if (context->atom() && context->atom()->OPEN_BRACE() && context->atom()->CLOSE_BRACE() && context->atom()->dictorsetmaker()) { // Dict: // std::cout << "Dict atom expression: " // << context->atom()->dictorsetmaker()->getText() << "\n"; // TODO: return 0; } if (context->atom() && !context->atom()->STRING().empty()) { // Strings: for (auto &strNode : context->atom()->STRING()) { std::string cppStrLiteral = strNode->getText(); // Handle Python single-quotes if (cppStrLiteral.front() == '\'' && cppStrLiteral.back() == '\'') { cppStrLiteral.front() = '"'; cppStrLiteral.back() = '"'; } sub_node_translation << cppStrLiteral; // std::cout << "String expression: " << strNode->getText() << " --> " // << cppStrLiteral << "\n"; } return 0; } const auto isSliceOp = [](pyxasmParser::Atom_exprContext *atom_expr_context) -> bool { if (atom_expr_context->trailer().size() == 1) { auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist(); if (subscriptlist && subscriptlist->subscript().size() == 1) { auto subscript = subscriptlist->subscript(0); const auto nbTestTerms = subscript->test().size(); // Multiple test terms (separated by ':') return (nbTestTerms > 1); } } return false; }; // Handle slicing operations (multiple array subscriptions separated by // ':') on a qreg. if (context->atom() && xacc::container::contains(bufferNames, context->atom()->getText()) && isSliceOp(context)) { // std::cout << "Slice op: " << context->getText() << "\n"; sub_node_translation << context->atom()->getText() << ".extract_range({"; auto subscripts = context->trailer(0)->subscriptlist()->subscript(0)->test(); assert(subscripts.size() > 1); std::vector<std::string> subscriptTerms; for (auto &test : subscripts) { subscriptTerms.emplace_back(test->getText()); } auto sliceOp = context->trailer(0)->subscriptlist()->subscript(0)->sliceop(); if (sliceOp && sliceOp->test()) { subscriptTerms.emplace_back(sliceOp->test()->getText()); } assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3); for (int i = 0; i < subscriptTerms.size(); ++i) { // Need to cast to prevent compiler errors, // e.g. when using q.size() which returns an int. sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")"; if (i != subscriptTerms.size() - 1) { sub_node_translation << ", "; } } sub_node_translation << "})"; // convert the slice op to initializer list: // std::cout << "Slice Convert: " << context->getText() << " --> " // << sub_node_translation.str() << "\n"; return 0; } return 0; } // Handle kernel::ctrl(...), kernel::adjoint(...) if (!context->trailer().empty() && (context->trailer()[0]->getText() == ".ctrl" || Loading @@ -54,7 +181,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ss << context->atom()->getText() << "::" << methodName << "(parent_kernel"; for (int i = 0; i < arg_list->argument().size(); i++) { ss << ", " << arg_list->argument(i)->getText(); ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i))); } ss << ");\n"; Loading Loading @@ -97,7 +224,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { std::vector<std::string> buffer_names; for (int i = 0; i < required_bits; i++) { auto bit_expr = context->trailer()[0]->arglist()->argument()[i]; auto bit_expr_str = bit_expr->getText(); auto bit_expr_str = rewriteFunctionArgument(*bit_expr); auto found_bracket = bit_expr_str.find_first_of("["); if (found_bracket != std::string::npos) { Loading Loading @@ -210,8 +337,24 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // A classical call-like expression: i.e. not a kernel call: // Just output it *as-is* to the C++ stream. // We can hook more sophisticated code-gen here if required. // std::cout << "Callable: " << context->getText() << "\n"; std::stringstream ss; if (context->trailer()[0]->arglist() && !context->trailer()[0]->arglist()->argument().empty()) { const auto &argList = context->trailer()[0]->arglist()->argument(); ss << inst_name << "("; for (size_t i = 0; i < argList.size(); ++i) { ss << rewriteFunctionArgument(*(argList[i])); if (i != argList.size() - 1) { ss << ", "; } } ss << ");\n"; } else { ss << context->getText() << ";\n"; } result.first = ss.str(); } } Loading Loading @@ -250,7 +393,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // Handle simple assignment: a = expr std::stringstream ss; const std::string lhs = ctx->testlist_star_expr(0)->getText(); const std::string rhs = replacePythonConstants( std::string rhs = replacePythonConstants( replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText())); if (lhs.find(",") != std::string::npos) { Loading @@ -269,32 +412,27 @@ class pyxasm_visitor : public pyxasmBaseVisitor { } } } else { // Handle Python list assignment: // i.e. lhs = [a, b, c] => { a, b, c } // NOTE: we only support simple lists, i.e. no nested. const auto transformListAssignmentIfAny = [](const std::string &in_expr) -> std::string { const auto whitespace = " "; // Trim leading and trailing spaces: const auto strBegin = in_expr.find_first_not_of(whitespace); const auto strEnd = in_expr.find_last_not_of(whitespace); const auto strRange = strEnd - strBegin + 1; const auto trim_expr = in_expr.substr(strBegin, strRange); if (trim_expr.front() == '[' && trim_expr.back() == ']') { return "{" + trim_expr.substr(1, trim_expr.size() - 2) + "}"; } // Returns the original expression: return in_expr; }; // Strategy: try to traverse the rhs to see if there is a possible rewrite; // Otherwise, use the text as is. is_processing_sub_expr = true; // clear the sub_node_translation sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(ctx->testlist_star_expr(1)); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { // Update RHS rhs = replacePythonConstants( replaceMeasureAssignment(sub_node_translation.str())); } if (xacc::container::contains(declared_var_names, lhs)) { ss << lhs << " = " << transformListAssignmentIfAny(rhs) << "; \n"; ss << lhs << " = " << rhs << "; \n"; } else { // New variable: need to add *auto* ss << "auto " << lhs << " = " << transformListAssignmentIfAny(rhs) << "; \n"; ss << "auto " << lhs << " = " << rhs << "; \n"; new_var = lhs; } } Loading Loading @@ -396,4 +534,34 @@ class pyxasm_visitor : public pyxasmBaseVisitor { return in_expr; } } // A helper to rewrite function argument by traversing the node to see // if there is a potential rewrite. // Use case: inline expressions // e.g. X(q[0:3]) // slicing of the qreg 'q' then call the broadcast X op. // i.e., we need to rewrite the arg to q.extract_range({0, 3}). std::string rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) { // Strategy: try to traverse the argument context to see if there is a // possible rewrite; i.e. it may be another atom_expression that we have a // handler for. Otherwise, use the text as is. // We need this flag to prevent parsing quantum instructions as sub-expressions. // e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be // processed as instructions. is_processing_sub_expr = true; // clear the sub_node_translation sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(&in_argContext); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { // Update RHS return sub_node_translation.str(); } // Returns the string as is return in_argContext.getText(); } }; No newline at end of file handlers/token_collector/pyxasm/tests/PyXASMTokenCollectorTester.cpp +185 −0 Original line number Diff line number Diff line Loading @@ -75,6 +75,191 @@ quantum::x(qb[i]); EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkPythonList) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( # inline initializer list apply_X_at_idx.ctrl([q[1], q[2]], q[0]) # array var assignement array_val = [q[1], q[2]] )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(apply_X_at_idx::ctrl(parent_kernel, {q[1], q[2]}, q[0]); auto array_val = {q[1], q[2]}; )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkStringLiteral) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( # Cpp style strings print("hello", 1, "world") # Python style print('howdy', 1, 'abc') )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(print("hello", 1, "world"); print("howdy", 1, "abc"); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkQregMethods) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( ctrl_qubits = q.head(q.size()-1) last_qubit = q.tail() Z.ctrl(ctrl_qubits, last_qubit) # inline X.ctrl(q.head(q.size()-1), q.tail()) # range: # API r = q.extract_range(0, bitPrecision) # Python style slice1 = q[0:3] # step size slice2 = q[0:5:2] )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(auto ctrl_qubits = q.head(q.size()-1); auto last_qubit = q.tail(); Z::ctrl(parent_kernel, ctrl_qubits, last_qubit); X::ctrl(parent_kernel, q.head(q.size()-1), q.tail()); auto r = q.extract_range(0,bitPrecision); auto slice1 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(3)}); auto slice2 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( X(q.head(q.size()-1)) X(q[0]) X(q) X(q[0:2]) X(q[0:5:2]) Measure(q.head(q.size()-1)) Measure(q[0]) Measure(q) Measure(q[0:2]) Measure(q[0:5:2]) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(quantum::x(q.head(q.size()-1)); quantum::x(q[0]); quantum::x(q); quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)})); quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)})); quantum::mz(q.head(q.size()-1)); quantum::mz(q[0]); quantum::mz(q); quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)})); quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)})); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkQcorOperators) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( exponent_op = X(0) * Y(1) - Y(0) * X(1) exp_i_theta(q, theta, exponent_op) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(auto exponent_op = X(0)*Y(1)-Y(0)*X(1); quantum::exp(q, theta, exponent_op); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkCommonMath) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( out_parity = oneCount - 2 * (oneCount / 2) # Power index = 2**n )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {""}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(auto out_parity = oneCount-2*(oneCount/2); auto index = std::pow(2, n); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } int main(int argc, char **argv) { std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR); std::string qcor_root = std::string(QCOR_INSTALL_DIR); Loading python/py-qcor.cpp +7 −1 Original line number Diff line number Diff line Loading @@ -596,7 +596,13 @@ PYBIND11_MODULE(_pyqcor, m) { } } return visitor.getMat(); }); }) .def( "get_kernel_function_ptr", [](qcor::QJIT &qjit, const std::string &kernel_name) { return qjit.get_kernel_function_ptr(kernel_name); }, ""); py::class_<qcor::ObjectiveFunction, std::shared_ptr<qcor::ObjectiveFunction>>( m, "ObjectiveFunction", "") Loading python/qcor.in.py +82 −13 Original line number Diff line number Diff line Loading @@ -16,6 +16,14 @@ from collections import defaultdict List = typing.List Tuple = typing.Tuple MethodType = types.MethodType Callable = typing.Callable # KernelSignature type annotation: # Usage: annotate an function argument as a KernelSignature by: # varName: KernelSignature(qreg, ...) # Kernel always returns void (None) def KernelSignature(*args): return Callable[list(args), None] # Static cache of all Python QJIT objects that have been created. # There seems to be a bug when a Python interpreter tried to create a new QJIT Loading Loading @@ -95,7 +103,7 @@ class KernelGraph(object): self.kernel_idx_dep_map = {} self.kernel_name_list = [] def addKernelDependency(self, kernelName, depList): def createKernelDependency(self, kernelName, depList): self.kernel_name_list.append(kernelName) self.kernel_idx_dep_map[self.V] = [] for dep_ker_name in depList: Loading @@ -103,6 +111,10 @@ class KernelGraph(object): self.kernel_name_list.index(dep_ker_name)) self.V += 1 def addKernelDependency(self, kernelName, newDep): self.kernel_idx_dep_map[self.kernel_name_list.index(kernelName)].append( self.kernel_name_list.index(newDep)) def addEdge(self, u, v): self.graph[u].append(v) Loading Loading @@ -396,6 +408,23 @@ class qjit(object): cpp_arg_str += ',' + \ 'int& ' + arg continue if str(_type).startswith('typing.Callable'): cpp_type_str = 'KernelSignature<' for i in range(len(_type.__args__) - 1): # print("input type:", _type.__args__[i]) arg_type = _type.__args__[i] if str(arg_type) not in self.allowed_type_cpp_map: print('Error, this quantum kernel arg type is not allowed: ', str(_type)) exit(1) cpp_type_str += self.allowed_type_cpp_map[str(arg_type)] cpp_type_str += ',' cpp_type_str = cpp_type_str[:-1] cpp_type_str += '>' # print("cpp type", cpp_type_str) cpp_arg_str += ',' + cpp_type_str + ' ' + arg continue if str(_type) not in self.allowed_type_cpp_map: print('Error, this quantum kernel arg type is not allowed: ', str(_type)) exit(1) Loading Loading @@ -460,7 +489,7 @@ class qjit(object): if re.search(r"\b" + re.escape(kernelCall) + '|' + re.escape(kernelAdjCall) + '|' + re.escape(kernelCtrlCall), self.src): dependency.append(kernelName) self.__kernels__graph.addKernelDependency( self.__kernels__graph.createKernelDependency( self.function.__name__, dependency) self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency( self.function.__name__) Loading Loading @@ -554,9 +583,7 @@ class qjit(object): """ assert len(args) == len(self.arg_names), "Cannot create CompositeInstruction, you did not provided the correct kernel arguments." # Create a dictionary for the function arguments args_dict = {} for i, arg_name in enumerate(self.arg_names): args_dict[arg_name] = list(args)[i] args_dict = self.construct_arg_dict(*args) return self._qjit.extract_composite(self.function.__name__, args_dict) def observe(self, observable, *args): Loading Loading @@ -590,9 +617,7 @@ class qjit(object): return self.extract_composite(*args).nInstructions() def as_unitary_matrix(self, *args): args_dict = {} for i, arg_name in enumerate(self.arg_names): args_dict[arg_name] = list(args)[i] args_dict = self.construct_arg_dict(*args) return self._qjit.internal_as_unitary(self.function.__name__, args_dict) def ctrl(self, *args): Loading Loading @@ -622,16 +647,60 @@ class qjit(object): def qir(self, *args, **kwargs): return llvm_ir(*args, **kwargs) def __call__(self, *args): """ Execute the decorated quantum kernel. This will directly invoke the corresponding LLVM JITed function pointer. """ # Helper to construct the arg_dict (HetMap) # e.g. perform any additional type conversion if required. def construct_arg_dict(self, *args): # Create a dictionary for the function arguments args_dict = {} for i, arg_name in enumerate(self.arg_names): args_dict[arg_name] = list(args)[i] arg_type_str = str(self.type_annotations[arg_name]) if arg_type_str.startswith('typing.Callable'): # print("callable:", arg_name) # print("arg:", type(args_dict[arg_name])) # the arg must be a qjit if not isinstance(args_dict[arg_name], qjit): print('Invalid argument type for {}. A quantum kernel (qjit) is expected.'.format(arg_name)) exit(1) callable_qjit = args_dict[arg_name] # Handle runtime dependency: # The QJIT arg. was not *known* until invocation, # hence, we recompile the this jit kernel taking into account # the KernelSignature argument. # TODO: perhaps an optimization that we can make is to # skip *eager* compilation for those kernels that have # KernelSignature arguments. if callable_qjit.kernel_name() not in self.sorted_kernel_dep: # print('New kernel:', callable_qjit.kernel_name()) # IMPORTANT: we cannot release a QJIT object till shut-down. QJIT_OBJ_CACHE.append(self._qjit) # Create a new QJIT self._qjit = QJIT() # Add a kernel dependency self.__kernels__graph.addKernelDependency(self.function.__name__, callable_qjit.kernel_name()) self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__) # Recompile: self._qjit.internal_python_jit_compile(self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers) # This should always be successful. fn_ptr = self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name()) if fn_ptr == 0: print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name())) exit(1) # Replace the argument (in the dict) with the function pointer # qjit is a pure-Python object, hence cannot be used by native QCOR. args_dict[arg_name] = hex(fn_ptr) return args_dict def __call__(self, *args): """ Execute the decorated quantum kernel. This will directly invoke the corresponding LLVM JITed function pointer. """ args_dict = self.construct_arg_dict(*args) # Invoke the JITed function self._qjit.invoke(self.function.__name__, args_dict) Loading Loading
handlers/qcor_syntax_handler.cpp +19 −0 Original line number Diff line number Diff line Loading @@ -399,6 +399,25 @@ void QCORSyntaxHandler::GetReplacement( // We just pass this copied var to the ctor // where it expects a reference type. arg_ctor_list.emplace_back(new_var_name); } else if (program_arg_types[i].rfind("KernelSignature", 0) == 0) { // This is a KernelSignature argument. // The one in HetMap is the function pointer represented as a hex string. const std::string new_var_name = "__temp_kernel_ptr_var__" + std::to_string(var_counter++); // Retrieve the function pointer from the HetMap // ref_type_copy_decl_ss << "std::cout << args.getString(\"" // << program_parameters[i] << "\").c_str() << std::endl;\n"; ref_type_copy_decl_ss << "void* " << new_var_name << " = " << "(void *) strtoull(args.getString(\"" << program_parameters[i] << "\").c_str(), nullptr, 16);\n"; // ref_type_copy_decl_ss << "std::cout << " << new_var_name << " << std::endl;\n"; // Construct the KernelSignature const std::string kernel_signature_var_name = "__temp_kernel_signature_var__" + std::to_string(var_counter++); ref_type_copy_decl_ss << program_arg_types[i] << " " << kernel_signature_var_name << "(" << new_var_name << ");\n"; arg_ctor_list.emplace_back(kernel_signature_var_name); } else { // Otherwise, just unpack the arg inline in the ctor call. std::stringstream ss; Loading
handlers/token_collector/pyxasm/pyxasm_visitor.hpp +194 −26 Original line number Diff line number Diff line Loading @@ -33,9 +33,136 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // New var declared (auto type) after visiting this node. std::string new_var; bool in_for_loop = false; // Var to keep track of sub-node rewrite: // e.g., traverse down the AST recursively. std::stringstream sub_node_translation; bool is_processing_sub_expr = false; antlrcpp::Any visitAtom_expr( pyxasmParser::Atom_exprContext *context) override { // std::cout << "Atom_exprContext: " << context->getText() << "\n"; // Strategy: // At the top level, we analyze the trailer to determine the // list of function call arguments. // Then, traverse down the arg. node to see if there is a potential rewrite rules // e.g. for arrays (as testlist_comp nodes) // Otherwise, just get the argument text as is. /* atom_expr: (AWAIT)? atom trailer*; atom: ('(' (yield_expr|testlist_comp)? ')' | '[' (testlist_comp)? ']' | '{' (dictorsetmaker)? '}' | NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False'); */ // Only processes these for sub-expressesions, // e.g. re-entries to this function if (is_processing_sub_expr) { if (context->atom() && context->atom()->OPEN_BRACK() && context->atom()->CLOSE_BRACK() && context->atom()->testlist_comp()) { // Array type expression: // std::cout << "Array atom expression: " // << context->atom()->testlist_comp()->getText() << "\n"; // Use braces sub_node_translation << "{"; bool firstElProcessed = false; for (auto &testNode : context->atom()->testlist_comp()->test()) { // std::cout << "Array elem: " << testNode->getText() << "\n"; // Add comma if needed (there is a previous element) if (firstElProcessed) { sub_node_translation << ", "; } sub_node_translation << testNode->getText(); firstElProcessed = true; } sub_node_translation << "}"; return 0; } // We don't have a re-write rule for this one (py::dict) if (context->atom() && context->atom()->OPEN_BRACE() && context->atom()->CLOSE_BRACE() && context->atom()->dictorsetmaker()) { // Dict: // std::cout << "Dict atom expression: " // << context->atom()->dictorsetmaker()->getText() << "\n"; // TODO: return 0; } if (context->atom() && !context->atom()->STRING().empty()) { // Strings: for (auto &strNode : context->atom()->STRING()) { std::string cppStrLiteral = strNode->getText(); // Handle Python single-quotes if (cppStrLiteral.front() == '\'' && cppStrLiteral.back() == '\'') { cppStrLiteral.front() = '"'; cppStrLiteral.back() = '"'; } sub_node_translation << cppStrLiteral; // std::cout << "String expression: " << strNode->getText() << " --> " // << cppStrLiteral << "\n"; } return 0; } const auto isSliceOp = [](pyxasmParser::Atom_exprContext *atom_expr_context) -> bool { if (atom_expr_context->trailer().size() == 1) { auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist(); if (subscriptlist && subscriptlist->subscript().size() == 1) { auto subscript = subscriptlist->subscript(0); const auto nbTestTerms = subscript->test().size(); // Multiple test terms (separated by ':') return (nbTestTerms > 1); } } return false; }; // Handle slicing operations (multiple array subscriptions separated by // ':') on a qreg. if (context->atom() && xacc::container::contains(bufferNames, context->atom()->getText()) && isSliceOp(context)) { // std::cout << "Slice op: " << context->getText() << "\n"; sub_node_translation << context->atom()->getText() << ".extract_range({"; auto subscripts = context->trailer(0)->subscriptlist()->subscript(0)->test(); assert(subscripts.size() > 1); std::vector<std::string> subscriptTerms; for (auto &test : subscripts) { subscriptTerms.emplace_back(test->getText()); } auto sliceOp = context->trailer(0)->subscriptlist()->subscript(0)->sliceop(); if (sliceOp && sliceOp->test()) { subscriptTerms.emplace_back(sliceOp->test()->getText()); } assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3); for (int i = 0; i < subscriptTerms.size(); ++i) { // Need to cast to prevent compiler errors, // e.g. when using q.size() which returns an int. sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")"; if (i != subscriptTerms.size() - 1) { sub_node_translation << ", "; } } sub_node_translation << "})"; // convert the slice op to initializer list: // std::cout << "Slice Convert: " << context->getText() << " --> " // << sub_node_translation.str() << "\n"; return 0; } return 0; } // Handle kernel::ctrl(...), kernel::adjoint(...) if (!context->trailer().empty() && (context->trailer()[0]->getText() == ".ctrl" || Loading @@ -54,7 +181,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ss << context->atom()->getText() << "::" << methodName << "(parent_kernel"; for (int i = 0; i < arg_list->argument().size(); i++) { ss << ", " << arg_list->argument(i)->getText(); ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i))); } ss << ");\n"; Loading Loading @@ -97,7 +224,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { std::vector<std::string> buffer_names; for (int i = 0; i < required_bits; i++) { auto bit_expr = context->trailer()[0]->arglist()->argument()[i]; auto bit_expr_str = bit_expr->getText(); auto bit_expr_str = rewriteFunctionArgument(*bit_expr); auto found_bracket = bit_expr_str.find_first_of("["); if (found_bracket != std::string::npos) { Loading Loading @@ -210,8 +337,24 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // A classical call-like expression: i.e. not a kernel call: // Just output it *as-is* to the C++ stream. // We can hook more sophisticated code-gen here if required. // std::cout << "Callable: " << context->getText() << "\n"; std::stringstream ss; if (context->trailer()[0]->arglist() && !context->trailer()[0]->arglist()->argument().empty()) { const auto &argList = context->trailer()[0]->arglist()->argument(); ss << inst_name << "("; for (size_t i = 0; i < argList.size(); ++i) { ss << rewriteFunctionArgument(*(argList[i])); if (i != argList.size() - 1) { ss << ", "; } } ss << ");\n"; } else { ss << context->getText() << ";\n"; } result.first = ss.str(); } } Loading Loading @@ -250,7 +393,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor { // Handle simple assignment: a = expr std::stringstream ss; const std::string lhs = ctx->testlist_star_expr(0)->getText(); const std::string rhs = replacePythonConstants( std::string rhs = replacePythonConstants( replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText())); if (lhs.find(",") != std::string::npos) { Loading @@ -269,32 +412,27 @@ class pyxasm_visitor : public pyxasmBaseVisitor { } } } else { // Handle Python list assignment: // i.e. lhs = [a, b, c] => { a, b, c } // NOTE: we only support simple lists, i.e. no nested. const auto transformListAssignmentIfAny = [](const std::string &in_expr) -> std::string { const auto whitespace = " "; // Trim leading and trailing spaces: const auto strBegin = in_expr.find_first_not_of(whitespace); const auto strEnd = in_expr.find_last_not_of(whitespace); const auto strRange = strEnd - strBegin + 1; const auto trim_expr = in_expr.substr(strBegin, strRange); if (trim_expr.front() == '[' && trim_expr.back() == ']') { return "{" + trim_expr.substr(1, trim_expr.size() - 2) + "}"; } // Returns the original expression: return in_expr; }; // Strategy: try to traverse the rhs to see if there is a possible rewrite; // Otherwise, use the text as is. is_processing_sub_expr = true; // clear the sub_node_translation sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(ctx->testlist_star_expr(1)); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { // Update RHS rhs = replacePythonConstants( replaceMeasureAssignment(sub_node_translation.str())); } if (xacc::container::contains(declared_var_names, lhs)) { ss << lhs << " = " << transformListAssignmentIfAny(rhs) << "; \n"; ss << lhs << " = " << rhs << "; \n"; } else { // New variable: need to add *auto* ss << "auto " << lhs << " = " << transformListAssignmentIfAny(rhs) << "; \n"; ss << "auto " << lhs << " = " << rhs << "; \n"; new_var = lhs; } } Loading Loading @@ -396,4 +534,34 @@ class pyxasm_visitor : public pyxasmBaseVisitor { return in_expr; } } // A helper to rewrite function argument by traversing the node to see // if there is a potential rewrite. // Use case: inline expressions // e.g. X(q[0:3]) // slicing of the qreg 'q' then call the broadcast X op. // i.e., we need to rewrite the arg to q.extract_range({0, 3}). std::string rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) { // Strategy: try to traverse the argument context to see if there is a // possible rewrite; i.e. it may be another atom_expression that we have a // handler for. Otherwise, use the text as is. // We need this flag to prevent parsing quantum instructions as sub-expressions. // e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be // processed as instructions. is_processing_sub_expr = true; // clear the sub_node_translation sub_node_translation.str(std::string()); // visit arg sub-node: visitChildren(&in_argContext); // Check if there is a rewrite: if (!sub_node_translation.str().empty()) { // Update RHS return sub_node_translation.str(); } // Returns the string as is return in_argContext.getText(); } }; No newline at end of file
handlers/token_collector/pyxasm/tests/PyXASMTokenCollectorTester.cpp +185 −0 Original line number Diff line number Diff line Loading @@ -75,6 +75,191 @@ quantum::x(qb[i]); EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkPythonList) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( # inline initializer list apply_X_at_idx.ctrl([q[1], q[2]], q[0]) # array var assignement array_val = [q[1], q[2]] )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(apply_X_at_idx::ctrl(parent_kernel, {q[1], q[2]}, q[0]); auto array_val = {q[1], q[2]}; )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkStringLiteral) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( # Cpp style strings print("hello", 1, "world") # Python style print('howdy', 1, 'abc') )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(print("hello", 1, "world"); print("howdy", 1, "abc"); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkQregMethods) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( ctrl_qubits = q.head(q.size()-1) last_qubit = q.tail() Z.ctrl(ctrl_qubits, last_qubit) # inline X.ctrl(q.head(q.size()-1), q.tail()) # range: # API r = q.extract_range(0, bitPrecision) # Python style slice1 = q[0:3] # step size slice2 = q[0:5:2] )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(auto ctrl_qubits = q.head(q.size()-1); auto last_qubit = q.tail(); Z::ctrl(parent_kernel, ctrl_qubits, last_qubit); X::ctrl(parent_kernel, q.head(q.size()-1), q.tail()); auto r = q.extract_range(0,bitPrecision); auto slice1 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(3)}); auto slice2 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( X(q.head(q.size()-1)) X(q[0]) X(q) X(q[0:2]) X(q[0:5:2]) Measure(q.head(q.size()-1)) Measure(q[0]) Measure(q) Measure(q[0:2]) Measure(q[0:5:2]) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(quantum::x(q.head(q.size()-1)); quantum::x(q[0]); quantum::x(q); quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)})); quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)})); quantum::mz(q.head(q.size()-1)); quantum::mz(q[0]); quantum::mz(q); quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)})); quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)})); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkQcorOperators) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( exponent_op = X(0) * Y(1) - Y(0) * X(1) exp_i_theta(q, theta, exponent_op) )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {"q"}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(auto exponent_op = X(0)*Y(1)-Y(0)*X(1); quantum::exp(q, theta, exponent_op); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } TEST(PyXASMTokenCollectorTester, checkCommonMath) { LexerHelper helper; auto [tokens, PP] = helper.Lex(R"( out_parity = oneCount - 2 * (oneCount / 2) # Power index = 2**n )"); clang::CachedTokens cached; for (auto &t : tokens) { cached.push_back(t); } std::stringstream ss; auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm"); xasm_tc->collect(*PP.get(), cached, {""}, ss); std::cout << "heres the test\n"; std::cout << ss.str() << "\n"; const std::string expectedCodeGen = R"#(auto out_parity = oneCount-2*(oneCount/2); auto index = std::pow(2, n); )#"; EXPECT_EQ(expectedCodeGen, ss.str()); } int main(int argc, char **argv) { std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR); std::string qcor_root = std::string(QCOR_INSTALL_DIR); Loading
python/py-qcor.cpp +7 −1 Original line number Diff line number Diff line Loading @@ -596,7 +596,13 @@ PYBIND11_MODULE(_pyqcor, m) { } } return visitor.getMat(); }); }) .def( "get_kernel_function_ptr", [](qcor::QJIT &qjit, const std::string &kernel_name) { return qjit.get_kernel_function_ptr(kernel_name); }, ""); py::class_<qcor::ObjectiveFunction, std::shared_ptr<qcor::ObjectiveFunction>>( m, "ObjectiveFunction", "") Loading
python/qcor.in.py +82 −13 Original line number Diff line number Diff line Loading @@ -16,6 +16,14 @@ from collections import defaultdict List = typing.List Tuple = typing.Tuple MethodType = types.MethodType Callable = typing.Callable # KernelSignature type annotation: # Usage: annotate an function argument as a KernelSignature by: # varName: KernelSignature(qreg, ...) # Kernel always returns void (None) def KernelSignature(*args): return Callable[list(args), None] # Static cache of all Python QJIT objects that have been created. # There seems to be a bug when a Python interpreter tried to create a new QJIT Loading Loading @@ -95,7 +103,7 @@ class KernelGraph(object): self.kernel_idx_dep_map = {} self.kernel_name_list = [] def addKernelDependency(self, kernelName, depList): def createKernelDependency(self, kernelName, depList): self.kernel_name_list.append(kernelName) self.kernel_idx_dep_map[self.V] = [] for dep_ker_name in depList: Loading @@ -103,6 +111,10 @@ class KernelGraph(object): self.kernel_name_list.index(dep_ker_name)) self.V += 1 def addKernelDependency(self, kernelName, newDep): self.kernel_idx_dep_map[self.kernel_name_list.index(kernelName)].append( self.kernel_name_list.index(newDep)) def addEdge(self, u, v): self.graph[u].append(v) Loading Loading @@ -396,6 +408,23 @@ class qjit(object): cpp_arg_str += ',' + \ 'int& ' + arg continue if str(_type).startswith('typing.Callable'): cpp_type_str = 'KernelSignature<' for i in range(len(_type.__args__) - 1): # print("input type:", _type.__args__[i]) arg_type = _type.__args__[i] if str(arg_type) not in self.allowed_type_cpp_map: print('Error, this quantum kernel arg type is not allowed: ', str(_type)) exit(1) cpp_type_str += self.allowed_type_cpp_map[str(arg_type)] cpp_type_str += ',' cpp_type_str = cpp_type_str[:-1] cpp_type_str += '>' # print("cpp type", cpp_type_str) cpp_arg_str += ',' + cpp_type_str + ' ' + arg continue if str(_type) not in self.allowed_type_cpp_map: print('Error, this quantum kernel arg type is not allowed: ', str(_type)) exit(1) Loading Loading @@ -460,7 +489,7 @@ class qjit(object): if re.search(r"\b" + re.escape(kernelCall) + '|' + re.escape(kernelAdjCall) + '|' + re.escape(kernelCtrlCall), self.src): dependency.append(kernelName) self.__kernels__graph.addKernelDependency( self.__kernels__graph.createKernelDependency( self.function.__name__, dependency) self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency( self.function.__name__) Loading Loading @@ -554,9 +583,7 @@ class qjit(object): """ assert len(args) == len(self.arg_names), "Cannot create CompositeInstruction, you did not provided the correct kernel arguments." # Create a dictionary for the function arguments args_dict = {} for i, arg_name in enumerate(self.arg_names): args_dict[arg_name] = list(args)[i] args_dict = self.construct_arg_dict(*args) return self._qjit.extract_composite(self.function.__name__, args_dict) def observe(self, observable, *args): Loading Loading @@ -590,9 +617,7 @@ class qjit(object): return self.extract_composite(*args).nInstructions() def as_unitary_matrix(self, *args): args_dict = {} for i, arg_name in enumerate(self.arg_names): args_dict[arg_name] = list(args)[i] args_dict = self.construct_arg_dict(*args) return self._qjit.internal_as_unitary(self.function.__name__, args_dict) def ctrl(self, *args): Loading Loading @@ -622,16 +647,60 @@ class qjit(object): def qir(self, *args, **kwargs): return llvm_ir(*args, **kwargs) def __call__(self, *args): """ Execute the decorated quantum kernel. This will directly invoke the corresponding LLVM JITed function pointer. """ # Helper to construct the arg_dict (HetMap) # e.g. perform any additional type conversion if required. def construct_arg_dict(self, *args): # Create a dictionary for the function arguments args_dict = {} for i, arg_name in enumerate(self.arg_names): args_dict[arg_name] = list(args)[i] arg_type_str = str(self.type_annotations[arg_name]) if arg_type_str.startswith('typing.Callable'): # print("callable:", arg_name) # print("arg:", type(args_dict[arg_name])) # the arg must be a qjit if not isinstance(args_dict[arg_name], qjit): print('Invalid argument type for {}. A quantum kernel (qjit) is expected.'.format(arg_name)) exit(1) callable_qjit = args_dict[arg_name] # Handle runtime dependency: # The QJIT arg. was not *known* until invocation, # hence, we recompile the this jit kernel taking into account # the KernelSignature argument. # TODO: perhaps an optimization that we can make is to # skip *eager* compilation for those kernels that have # KernelSignature arguments. if callable_qjit.kernel_name() not in self.sorted_kernel_dep: # print('New kernel:', callable_qjit.kernel_name()) # IMPORTANT: we cannot release a QJIT object till shut-down. QJIT_OBJ_CACHE.append(self._qjit) # Create a new QJIT self._qjit = QJIT() # Add a kernel dependency self.__kernels__graph.addKernelDependency(self.function.__name__, callable_qjit.kernel_name()) self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__) # Recompile: self._qjit.internal_python_jit_compile(self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers) # This should always be successful. fn_ptr = self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name()) if fn_ptr == 0: print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name())) exit(1) # Replace the argument (in the dict) with the function pointer # qjit is a pure-Python object, hence cannot be used by native QCOR. args_dict[arg_name] = hex(fn_ptr) return args_dict def __call__(self, *args): """ Execute the decorated quantum kernel. This will directly invoke the corresponding LLVM JITed function pointer. """ args_dict = self.construct_arg_dict(*args) # Invoke the JITed function self._qjit.invoke(self.function.__name__, args_dict) Loading