Commit 119c96a5 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Syntactic sugar for PyXASM Kernel Signature and code clean-up



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 03fbdf0d
Loading
Loading
Loading
Loading
+9 −6
Original line number Diff line number Diff line
@@ -18,6 +18,13 @@ Tuple = typing.Tuple
MethodType = types.MethodType
Callable = typing.Callable

# KernelSignature type annotation:
# Usage: annotate an function argument as a KernelSignature by:
# varName: KernelSignature(qreg, ...)
# Kernel always returns void (None)
def KernelSignature(*args):
    return Callable[list(args), None]

# Static cache of all Python QJIT objects that have been created.
# There seems to be a bug when a Python interpreter tried to create a new QJIT
# *after* a previous QJIT is destroyed.
@@ -649,12 +656,10 @@ class qjit(object):
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
            args_dict[arg_name] = list(args)[i]
            print(arg_name)
            print(self.type_annotations[arg_name])
            arg_type_str = str(self.type_annotations[arg_name])
            if arg_type_str.startswith('typing.Callable'):
                print("callable:", arg_name)
                print("arg:", type(args_dict[arg_name]))
                # print("callable:", arg_name)
                # print("arg:", type(args_dict[arg_name]))
                # the arg must be a qjit
                if not isinstance(args_dict[arg_name], qjit):
                    print('Invalid argument type for {}. A quantum kernel (qjit) is expected.'.format(arg_name))
@@ -662,14 +667,12 @@ class qjit(object):
                
                callable_qjit = args_dict[arg_name]
                fn_ptr = hex(self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name()))
                print("Fn ptr:", fn_ptr)
                if fn_ptr == 0:
                    print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name()))
                    exit(1)
                # Replace the argument (in the dict) with the function pointer
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = fn_ptr
                print(type(args_dict[arg_name]))
        
        # Invoke the JITed function
        self._qjit.invoke(self.function.__name__, args_dict)
+1 −1
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ class TestKernelJIT(unittest.TestCase):
                Z.ctrl(q[0: q.size() - 1], q[q.size() - 1])
            
        @qjit
        def run_grover(q: qreg, oracle_var: Callable[[qreg], None], iterations: int):
        def run_grover(q: qreg, oracle_var: KernelSignature(qreg), iterations: int):
            H(q)
            #Iteratively apply the oracle then reflect
            for i in range(iterations):
+4 −9
Original line number Diff line number Diff line
@@ -346,13 +346,6 @@ template <typename... Args>
using callable_function_ptr =
    void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...);

template <typename... Args>
callable_function_ptr<Args...> callable_function_ptr_from_raw_ptr(void *f_ptr) {
  void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>, Args...) =
      (callable_function_ptr<Args...>)f_ptr;
  return kernel_functor;
}

template <typename... Args>
class KernelSignature {
 protected:
@@ -361,8 +354,10 @@ class KernelSignature {
 public:
  KernelSignature(callable_function_ptr<Args...> &&f) : function_pointer(f) {}
  // Ctor from raw void* funtion pointer.
  KernelSignature(void *f_ptr)
      : KernelSignature(callable_function_ptr_from_raw_ptr<Args...>(f_ptr)) {}
  // IMPORTANT: since function_pointer is kept as a *reference*,
  // we must keep a reference to the original f_ptr void* as well.
  KernelSignature(void *&f_ptr)
      : function_pointer((callable_function_ptr<Args...> &)f_ptr) {}

  void operator()(std::shared_ptr<xacc::CompositeInstruction> ir,
                  Args... args) {