Commit 1ccdb172 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Start working on support for List<KernelSignature> argument



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 51478558
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ using AllowedKernelArgTypes =
                  xacc::internal_compiler::qubit, std::vector<double>,
                  std::vector<int>, qcor::PauliOperator, qcor::FermionOperator,
                  qcor::PairList<int>, std::vector<qcor::PauliOperator>,
                  std::vector<qcor::FermionOperator>>;
                  std::vector<qcor::FermionOperator>, std::vector<std::string>>;

// We will take as input a mapping of arg variable names to the argument itself.
using KernelArgDict = std::map<std::string, AllowedKernelArgTypes>;
+57 −10
Original line number Diff line number Diff line
@@ -408,22 +408,36 @@ class qjit(object):
                cpp_arg_str += ',' + \
                    'int& ' + arg
                continue
            if str(_type).startswith('typing.Callable'):
                cpp_type_str = 'KernelSignature<'
                for i in range(len(_type.__args__) - 1):
                
            # Helper to parse Python KernelSignature type annotation     
            def construct_callable_signature(clb_type):
                result_type_str = 'KernelSignature<'
                for i in range(len(clb_type.__args__) - 1):
                    # print("input type:", _type.__args__[i])
                    arg_type = _type.__args__[i]
                    arg_type = clb_type.__args__[i]
                    if str(arg_type) not in self.allowed_type_cpp_map:
                        print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                        print('Error, this quantum kernel arg type is not allowed: ', str(clb_type))
                        exit(1)
                    cpp_type_str += self.allowed_type_cpp_map[str(arg_type)]
                    cpp_type_str += ','
                    result_type_str += self.allowed_type_cpp_map[str(arg_type)]
                    result_type_str += ','
                
                cpp_type_str = cpp_type_str[:-1]
                cpp_type_str += '>'
                result_type_str = result_type_str[:-1]
                result_type_str += '>'
                return result_type_str

            # Single Callable argument
            if str(_type).startswith('typing.Callable'):
                cpp_type_str = construct_callable_signature(_type)
                # print("cpp type", cpp_type_str)
                cpp_arg_str += ',' + cpp_type_str + ' ' + arg
                continue
            # List of KernelSignature
            if str(_type).startswith('typing.List[typing.Callable'):
                # Note: All the Callables in the list must have the same signature.
                # (and they should to be considered equivalent for grouping into a List)
                cpp_type_str = 'std::vector<' + construct_callable_signature(_type.__args__[0]) + '>'
                cpp_arg_str += ',' + cpp_type_str + ' ' + arg
                continue

            if str(_type) not in self.allowed_type_cpp_map:
                print('Error, this quantum kernel arg type is not allowed: ', str(_type))
@@ -718,6 +732,39 @@ class qjit(object):
                # Replace the argument (in the dict) with the function pointer
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = hex(fn_ptr)
            # List of callables:
            if arg_type_str.startswith('typing.List[typing.Callable['):
                callable_qjit_list = args_dict[arg_name]
                need_recompile = False
                for clb in callable_qjit_list:
                    if clb.kernel_name() not in self.sorted_kernel_dep:
                        print('New kernel:', clb.kernel_name())
                        # Add a kernel dependency
                        self.__kernels__graph.addKernelDependency(self.function.__name__, clb.kernel_name())
                        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
                        need_recompile = True
                
                if need_recompile:
                    # Create a new QJIT
                    self._qjit = QJIT()
                    self._qjit.internal_python_jit_compile(self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers)
                
                clb_fn_ptrs = []
                for clb in callable_qjit_list:
                    if not isinstance(clb, qjit):
                        print('Invalid argument type for {}. A list of quantum kernels (qjit) is expected.'.format(arg_name))
                        exit(1)
                    print("Kernel name:", clb.kernel_name())
                    # This should always be successful.
                    fn_ptr = self._qjit.get_kernel_function_ptr(clb.kernel_name())
                    if fn_ptr == 0:
                        print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(clb.kernel_name()))
                        exit(1)
                    clb_fn_ptrs.append(hex(fn_ptr))
                    
                # Replace the argument (in the dict) with the list of function pointers
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = clb_fn_ptrs
        
        return args_dict