Loading python/qcor.in.py +9 −6 Original line number Diff line number Diff line Loading @@ -18,6 +18,13 @@ Tuple = typing.Tuple MethodType = types.MethodType Callable = typing.Callable # KernelSignature type annotation: # Usage: annotate an function argument as a KernelSignature by: # varName: KernelSignature(qreg, ...) # Kernel always returns void (None) def KernelSignature(*args): return Callable[list(args), None] # Static cache of all Python QJIT objects that have been created. # There seems to be a bug when a Python interpreter tried to create a new QJIT # *after* a previous QJIT is destroyed. Loading Loading @@ -649,12 +656,10 @@ class qjit(object): args_dict = {} for i, arg_name in enumerate(self.arg_names): args_dict[arg_name] = list(args)[i] print(arg_name) print(self.type_annotations[arg_name]) arg_type_str = str(self.type_annotations[arg_name]) if arg_type_str.startswith('typing.Callable'): print("callable:", arg_name) print("arg:", type(args_dict[arg_name])) # print("callable:", arg_name) # print("arg:", type(args_dict[arg_name])) # the arg must be a qjit if not isinstance(args_dict[arg_name], qjit): print('Invalid argument type for {}. A quantum kernel (qjit) is expected.'.format(arg_name)) Loading @@ -662,14 +667,12 @@ class qjit(object): callable_qjit = args_dict[arg_name] fn_ptr = hex(self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name())) print("Fn ptr:", fn_ptr) if fn_ptr == 0: print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name())) exit(1) # Replace the argument (in the dict) with the function pointer # qjit is a pure-Python object, hence cannot be used by native QCOR. args_dict[arg_name] = fn_ptr print(type(args_dict[arg_name])) # Invoke the JITed function self._qjit.invoke(self.function.__name__, args_dict) Loading python/tests/test_jit_grover.py +1 −1 Original line number Diff line number Diff line Loading @@ -22,7 +22,7 @@ class TestKernelJIT(unittest.TestCase): Z.ctrl(q[0: q.size() - 1], q[q.size() - 1]) @qjit def run_grover(q: qreg, oracle_var: Callable[[qreg], None], iterations: int): def run_grover(q: qreg, oracle_var: KernelSignature(qreg), iterations: int): H(q) #Iteratively apply the oracle then reflect for i in range(iterations): Loading runtime/kernel/quantum_kernel.hpp +4 −9 Original line number Diff line number Diff line Loading @@ -346,13 +346,6 @@ template <typename... Args> using callable_function_ptr = void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...); template <typename... Args> callable_function_ptr<Args...> callable_function_ptr_from_raw_ptr(void *f_ptr) { void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>, Args...) = (callable_function_ptr<Args...>)f_ptr; return kernel_functor; } template <typename... Args> class KernelSignature { protected: Loading @@ -361,8 +354,10 @@ class KernelSignature { public: KernelSignature(callable_function_ptr<Args...> &&f) : function_pointer(f) {} // Ctor from raw void* funtion pointer. KernelSignature(void *f_ptr) : KernelSignature(callable_function_ptr_from_raw_ptr<Args...>(f_ptr)) {} // IMPORTANT: since function_pointer is kept as a *reference*, // we must keep a reference to the original f_ptr void* as well. KernelSignature(void *&f_ptr) : function_pointer((callable_function_ptr<Args...> &)f_ptr) {} void operator()(std::shared_ptr<xacc::CompositeInstruction> ir, Args... args) { Loading Loading
python/qcor.in.py +9 −6 Original line number Diff line number Diff line Loading @@ -18,6 +18,13 @@ Tuple = typing.Tuple MethodType = types.MethodType Callable = typing.Callable # KernelSignature type annotation: # Usage: annotate an function argument as a KernelSignature by: # varName: KernelSignature(qreg, ...) # Kernel always returns void (None) def KernelSignature(*args): return Callable[list(args), None] # Static cache of all Python QJIT objects that have been created. # There seems to be a bug when a Python interpreter tried to create a new QJIT # *after* a previous QJIT is destroyed. Loading Loading @@ -649,12 +656,10 @@ class qjit(object): args_dict = {} for i, arg_name in enumerate(self.arg_names): args_dict[arg_name] = list(args)[i] print(arg_name) print(self.type_annotations[arg_name]) arg_type_str = str(self.type_annotations[arg_name]) if arg_type_str.startswith('typing.Callable'): print("callable:", arg_name) print("arg:", type(args_dict[arg_name])) # print("callable:", arg_name) # print("arg:", type(args_dict[arg_name])) # the arg must be a qjit if not isinstance(args_dict[arg_name], qjit): print('Invalid argument type for {}. A quantum kernel (qjit) is expected.'.format(arg_name)) Loading @@ -662,14 +667,12 @@ class qjit(object): callable_qjit = args_dict[arg_name] fn_ptr = hex(self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name())) print("Fn ptr:", fn_ptr) if fn_ptr == 0: print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name())) exit(1) # Replace the argument (in the dict) with the function pointer # qjit is a pure-Python object, hence cannot be used by native QCOR. args_dict[arg_name] = fn_ptr print(type(args_dict[arg_name])) # Invoke the JITed function self._qjit.invoke(self.function.__name__, args_dict) Loading
python/tests/test_jit_grover.py +1 −1 Original line number Diff line number Diff line Loading @@ -22,7 +22,7 @@ class TestKernelJIT(unittest.TestCase): Z.ctrl(q[0: q.size() - 1], q[q.size() - 1]) @qjit def run_grover(q: qreg, oracle_var: Callable[[qreg], None], iterations: int): def run_grover(q: qreg, oracle_var: KernelSignature(qreg), iterations: int): H(q) #Iteratively apply the oracle then reflect for i in range(iterations): Loading
runtime/kernel/quantum_kernel.hpp +4 −9 Original line number Diff line number Diff line Loading @@ -346,13 +346,6 @@ template <typename... Args> using callable_function_ptr = void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...); template <typename... Args> callable_function_ptr<Args...> callable_function_ptr_from_raw_ptr(void *f_ptr) { void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>, Args...) = (callable_function_ptr<Args...>)f_ptr; return kernel_functor; } template <typename... Args> class KernelSignature { protected: Loading @@ -361,8 +354,10 @@ class KernelSignature { public: KernelSignature(callable_function_ptr<Args...> &&f) : function_pointer(f) {} // Ctor from raw void* funtion pointer. KernelSignature(void *f_ptr) : KernelSignature(callable_function_ptr_from_raw_ptr<Args...>(f_ptr)) {} // IMPORTANT: since function_pointer is kept as a *reference*, // we must keep a reference to the original f_ptr void* as well. KernelSignature(void *&f_ptr) : function_pointer((callable_function_ptr<Args...> &)f_ptr) {} void operator()(std::shared_ptr<xacc::CompositeInstruction> ir, Args... args) { Loading