Commit 11c18dce authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

More fixes for KernelSignature support in PyXASM:



- QJIT needs a more robust function signature parsing: in addition to commas, it needs to recognize template argument type, like Type<a,b,c>

- Python args_dict construction to act for invocation (__call__) as well as other helper methods.

- Comment out debug logging

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 119c96a5
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -401,7 +401,7 @@ void QCORSyntaxHandler::GetReplacement(
        arg_ctor_list.emplace_back(new_var_name);
      } else if (program_arg_types[i].rfind("KernelSignature", 0) == 0) {
        // This is a KernelSignature argument.
        // The one in HetMap is the function pointer represented as an integer.
        // The one in HetMap is the function pointer represented as a hex string.
        const std::string new_var_name =
            "__temp_kernel_ptr_var__" + std::to_string(var_counter++);
        // Retrieve the function pointer from the HetMap
+12 −12
Original line number Diff line number Diff line
@@ -40,7 +40,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {

  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
    std::cout << "Atom_exprContext: " << context->getText() << "\n";
    // std::cout << "Atom_exprContext: " << context->getText() << "\n";
    // Strategy:
    // At the top level, we analyze the trailer to determine the 
    // list of function call arguments.
@@ -60,13 +60,13 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      if (context->atom() && context->atom()->OPEN_BRACK() &&
          context->atom()->CLOSE_BRACK() && context->atom()->testlist_comp()) {
        // Array type expression:
        std::cout << "Array atom expression: "
                  << context->atom()->testlist_comp()->getText() << "\n";
        // std::cout << "Array atom expression: "
        //           << context->atom()->testlist_comp()->getText() << "\n";
        // Use braces
        sub_node_translation << "{";
        bool firstElProcessed = false;
        for (auto &testNode : context->atom()->testlist_comp()->test()) {
          std::cout << "Array elem: " << testNode->getText() << "\n";
          // std::cout << "Array elem: " << testNode->getText() << "\n";
          // Add comma if needed (there is a previous element)
          if (firstElProcessed) {
            sub_node_translation << ", ";
@@ -82,8 +82,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      if (context->atom() && context->atom()->OPEN_BRACE() &&
          context->atom()->CLOSE_BRACE() && context->atom()->dictorsetmaker()) {
        // Dict:
        std::cout << "Dict atom expression: "
                  << context->atom()->dictorsetmaker()->getText() << "\n";
        // std::cout << "Dict atom expression: "
        //           << context->atom()->dictorsetmaker()->getText() << "\n";
        // TODO:
        return 0;
      }
@@ -98,8 +98,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
            cppStrLiteral.back() = '"';
          }
          sub_node_translation << cppStrLiteral;
          std::cout << "String expression: " << strNode->getText() << " --> "
                    << cppStrLiteral << "\n";
          // std::cout << "String expression: " << strNode->getText() << " --> "
          //           << cppStrLiteral << "\n";
        }
        return 0;
      }
@@ -124,7 +124,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      if (context->atom() &&
          xacc::container::contains(bufferNames, context->atom()->getText()) &&
          isSliceOp(context)) {
        std::cout << "Slice op: " << context->getText() << "\n";
        // std::cout << "Slice op: " << context->getText() << "\n";
        sub_node_translation << context->atom()->getText()
                             << ".extract_range({";
        auto subscripts =
@@ -155,8 +155,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
        sub_node_translation << "})";

        // convert the slice op to initializer list:
        std::cout << "Slice Convert: " << context->getText() << " --> "
                  << sub_node_translation.str() << "\n";
        // std::cout << "Slice Convert: " << context->getText() << " --> "
        //           << sub_node_translation.str() << "\n";
        return 0;
      }

@@ -337,7 +337,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
            // A classical call-like expression: i.e. not a kernel call:
            // Just output it *as-is* to the C++ stream.
            // We can hook more sophisticated code-gen here if required.
            std::cout << "Callable: " << context->getText() << "\n";
            // std::cout << "Callable: " << context->getText() << "\n";
            std::stringstream ss;

            if (context->trailer()[0]->arglist() &&
+17 −15
Original line number Diff line number Diff line
@@ -579,9 +579,7 @@ class qjit(object):
        """
        assert len(args) == len(self.arg_names), "Cannot create CompositeInstruction, you did not provided the correct kernel arguments."
        # Create a dictionary for the function arguments
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
            args_dict[arg_name] = list(args)[i]
        args_dict = self.construct_arg_dict(*args)
        return self._qjit.extract_composite(self.function.__name__, args_dict)

    def observe(self, observable, *args):
@@ -615,9 +613,7 @@ class qjit(object):
        return self.extract_composite(*args).nInstructions()
    
    def as_unitary_matrix(self, *args):
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
            args_dict[arg_name] = list(args)[i]
        args_dict = self.construct_arg_dict(*args)
        return self._qjit.internal_as_unitary(self.function.__name__, args_dict)
    
    def ctrl(self, *args):
@@ -647,11 +643,9 @@ class qjit(object):
    def qir(self, *args, **kwargs):
        return llvm_ir(*args, **kwargs)

    def __call__(self, *args):
        """
        Execute the decorated quantum kernel. This will directly 
        invoke the corresponding LLVM JITed function pointer. 
        """
    # Helper to construct the arg_dict (HetMap)
    # e.g. perform any additional type conversion if required.
    def construct_arg_dict(self, *args):
        # Create a dictionary for the function arguments
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
@@ -666,16 +660,24 @@ class qjit(object):
                    exit(1)
                
                callable_qjit = args_dict[arg_name]
                fn_ptr = hex(self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name()))
                fn_ptr = self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name())
                if fn_ptr == 0:
                    print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name()))
                    exit(1)
                # Replace the argument (in the dict) with the function pointer
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = fn_ptr
                args_dict[arg_name] = hex(fn_ptr)
            
        return args_dict

    def __call__(self, *args):
        """
        Execute the decorated quantum kernel. This will directly 
        invoke the corresponding LLVM JITed function pointer. 
        """
        arg_dict = self.construct_arg_dict(*args)
        # Invoke the JITed function
        self._qjit.invoke(self.function.__name__, args_dict)
        self._qjit.invoke(self.function.__name__, arg_dict)

        # Update any *by-ref* arguments: annotated with the custom type: FLOAT_REF, INT_REF, etc.
        # If there are *pass-by-ref* variables:
+7 −1
Original line number Diff line number Diff line
@@ -67,4 +67,10 @@ add_test (NAME qcor_python_jit_grover
set_tests_properties(qcor_python_jit_grover
  PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_INSTALL_PREFIX}:$ENV{PYTHONPATH}")

add_test (NAME qcor_python_jit_kernel_signature
  COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_jit_kernel_signature.py
  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
)
set_tests_properties(qcor_python_jit_kernel_signature
  PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_INSTALL_PREFIX}:$ENV{PYTHONPATH}")
 
 No newline at end of file
+34 −0
Original line number Diff line number Diff line
import faulthandler
faulthandler.enable()

import unittest
from qcor import *

class TestKernelJIT(unittest.TestCase):
    def test_grover(self):
        set_qpu('qpp', {'shots':1024})
        
        @qjit
        def rx_kernel(q: qreg, idx: int, theta: float):
            Rx(q[idx], theta)
        
        @qjit
        def test_kernel(q: qreg, call_var: KernelSignature(qreg, int, float)):
            call_var(q, 0, 1.0)
            call_var(q, 1, 2.0)
            # TODO: currently, we don't have the ability to inject
            # new dependency, hence must use rx_kernel here to 
            # pull rx_kernel in.
            rx_kernel(q, 2, 3.0)

        q = qalloc(3)
        test_kernel(q, rx_kernel)
        comp = test_kernel.extract_composite(q, rx_kernel)
        print(comp)
        self.assertEqual(comp.nInstructions(), 3)   
        for i in range(3):
            self.assertEqual(comp.getInstruction(i).name(), "Rx") 
            self.assertAlmostEqual((float)(comp.getInstruction(i).getParameter(0)), i + 1.0)

if __name__ == '__main__':
  unittest.main()
 No newline at end of file
Loading