Commit 7ca241e8 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Work on support KernelSignature arguments in QJIT



- Using Python Callable type annotation to declare KernelSignature

- JIT function pointer is added in place of KernelSignature argument in the HetMap.

- Code-gen to unpack the function ptr and construct the KernelSignature.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 6595a0be
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
@@ -399,6 +399,25 @@ void QCORSyntaxHandler::GetReplacement(
        // We just pass this copied var to the ctor
        // where it expects a reference type.
        arg_ctor_list.emplace_back(new_var_name);
      } else if (program_arg_types[i].rfind("KernelSignature", 0) == 0) {
        // This is a KernelSignature argument.
        // The one in HetMap is the function pointer represented as an integer.
        const std::string new_var_name =
            "__temp_kernel_ptr_var__" + std::to_string(var_counter++);
        // Retrieve the function pointer from the HetMap
        // ref_type_copy_decl_ss << "std::cout << args.getString(\""
        //                       << program_parameters[i] << "\").c_str() << std::endl;\n";
        ref_type_copy_decl_ss << "void* " << new_var_name << " = "
                              << "(void *) strtoull(args.getString(\""
                              << program_parameters[i] << "\").c_str(), nullptr, 16);\n";
        // ref_type_copy_decl_ss << "std::cout << " << new_var_name << " << std::endl;\n";
        // Construct the KernelSignature
        const std::string kernel_signature_var_name =
            "__temp_kernel_signature_var__" + std::to_string(var_counter++);
        ref_type_copy_decl_ss << program_arg_types[i] << " "
                              << kernel_signature_var_name << "("
                              << new_var_name << ");\n";
        arg_ctor_list.emplace_back(kernel_signature_var_name);
      } else {
        // Otherwise, just unpack the arg inline in the ctor call.
        std::stringstream ss;
+7 −1
Original line number Diff line number Diff line
@@ -596,7 +596,13 @@ PYBIND11_MODULE(_pyqcor, m) {
               }
             }
             return visitor.getMat();
           });
           })
      .def(
          "get_kernel_function_ptr",
          [](qcor::QJIT &qjit, const std::string &kernel_name) {
            return qjit.get_kernel_function_ptr(kernel_name);
          },
          "");

  py::class_<qcor::ObjectiveFunction, std::shared_ptr<qcor::ObjectiveFunction>>(
      m, "ObjectiveFunction", "")
+40 −1
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ from collections import defaultdict
List = typing.List
Tuple = typing.Tuple
MethodType = types.MethodType
Callable = typing.Callable

# Static cache of all Python QJIT objects that have been created.
# There seems to be a bug when a Python interpreter tried to create a new QJIT
@@ -396,6 +397,23 @@ class qjit(object):
                cpp_arg_str += ',' + \
                    'int& ' + arg
                continue
            if str(_type).startswith('typing.Callable'):
                cpp_type_str = 'KernelSignature<'
                for i in range(len(_type.__args__) - 1):
                    # print("input type:", _type.__args__[i])
                    arg_type = _type.__args__[i]
                    if str(arg_type) not in self.allowed_type_cpp_map:
                        print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                        exit(1)
                    cpp_type_str += self.allowed_type_cpp_map[str(arg_type)]
                    cpp_type_str += ','
                
                cpp_type_str = cpp_type_str[:-1]
                cpp_type_str += '>'
                # print("cpp type", cpp_type_str)
                cpp_arg_str += ',' + cpp_type_str + ' ' + arg
                continue

            if str(_type) not in self.allowed_type_cpp_map:
                print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                exit(1)
@@ -631,6 +649,27 @@ class qjit(object):
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
            args_dict[arg_name] = list(args)[i]
            print(arg_name)
            print(self.type_annotations[arg_name])
            arg_type_str = str(self.type_annotations[arg_name])
            if arg_type_str.startswith('typing.Callable'):
                print("callable:", arg_name)
                print("arg:", type(args_dict[arg_name]))
                # the arg must be a qjit
                if not isinstance(args_dict[arg_name], qjit):
                    print('Invalid argument type for {}. A quantum kernel (qjit) is expected.'.format(arg_name))
                    exit(1)
                
                callable_qjit = args_dict[arg_name]
                fn_ptr = hex(self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name()))
                print("Fn ptr:", fn_ptr)
                if fn_ptr == 0:
                    print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name()))
                    exit(1)
                # Replace the argument (in the dict) with the function pointer
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = fn_ptr
                print(type(args_dict[arg_name]))
        
        # Invoke the JITed function
        self._qjit.invoke(self.function.__name__, args_dict)
+4 −6
Original line number Diff line number Diff line
@@ -22,19 +22,17 @@ class TestKernelJIT(unittest.TestCase):
                Z.ctrl(q[0: q.size() - 1], q[q.size() - 1])
            
        @qjit
        def run_grover(q: qreg, iterations: int):
        def run_grover(q: qreg, oracle_var: Callable[[qreg], None], iterations: int):
            H(q)
            #Iteratively apply the oracle then reflect
            for i in range(iterations):
                oracle_fn(q)
                oracle_var(q)
                reflect_about_uniform(q)
            # Measure all qubits
            Measure(q)

        q = qalloc(3)
        # comp = run_grover.extract_composite(q, 1)
        # print(comp)
        run_grover(q, 1)
        run_grover(q, oracle_fn, 1)
        q.print()
        counts = q.counts()
        print(counts)
+8 −0
Original line number Diff line number Diff line
@@ -80,6 +80,14 @@ class QJIT {
    void (*kernel_functor)(Args...) = (void (*)(Args...))f_ptr;
    return kernel_functor;
  }

  // The type of kernel functions: 
  enum class KernelType { Regular, HetMapArg, HetMapArgWithParent };
  // Return kernel function pointer (as an integer)
  // Returns 0 if the kernel doesn't exist.
  std::uint64_t
  get_kernel_function_ptr(const std::string &kernelName,
                       KernelType subType = KernelType::Regular) const;
};

}  // namespace qcor
 No newline at end of file
Loading