Unverified Commit 1347c898 authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #166 from tnguyen-ornl/tnguyen/pyxasm-list-callables

Support List<KernelSignature>
parents c604bfa5 4f44c954
Loading
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
@@ -193,6 +193,16 @@ void QCORSyntaxHandler::GetReplacement(
  // We only support one buffer in FTQC mode atm.
  OS << "quantum::set_current_buffer(" << bufferNames[0] << ".results());\n";
  OS << "}\n";
  // Set the parent_kernel of this kernel to any KernelSignature
  // (even nested in container) so that even w/o parent_kernel added by the 
  // token collector, the KernelSignature will still be referring to the correct parent_kernel.
  // i.e. we can track KernelSignature coming from complex data container such as std::vector
  // rather than relying on the list of kernels in translation unit.
  OS << "init_kernel_signature_args(parent_kernel, " << program_parameters[0];
  for (int i = 1; i < program_arg_types.size(); i++) {
    OS << ", " << program_parameters[i];
  }
  OS << ");\n";
  OS << new_src << "\n";
  OS << "}\n";

@@ -437,6 +447,31 @@ void QCORSyntaxHandler::GetReplacement(
                              << kernel_signature_var_name << "("
                              << new_var_name << ");\n";
        arg_ctor_list.emplace_back(kernel_signature_var_name);
      } else if (program_arg_types[i].rfind("std::vector<KernelSignature<",
                                            0) == 0) {
        // This is a list of KernelSignatures argument.
        // The one in HetMap is the vector of function pointers represented as a
        // hex string.
        const std::string new_var_name =
            "__temp_kernel_ptr_vector_var__" + std::to_string(var_counter++);
        // Retrieve the list of function pointer from the HetMap
        ref_type_copy_decl_ss << "auto " << new_var_name
                              << " = args.get<std::vector<std::string>>(\""
                              << program_parameters[i] << "\");\n";

        const std::string list_kernel_signature_var_name =
            "__temp_kernel_signature_var__" + std::to_string(var_counter++);
        // Declare the vector of kernel signatures
        ref_type_copy_decl_ss << program_arg_types[i] << " "
                              << list_kernel_signature_var_name << ";\n";
        // Construct the list from function pointers
        ref_type_copy_decl_ss << "std::vector<void*> temp_fn_ptrs(" << new_var_name << ".size());\n";
        ref_type_copy_decl_ss << "int fn_idx = 0;\n";
        ref_type_copy_decl_ss << "for (const auto& ptr_str: " << new_var_name << ") {\n";
        ref_type_copy_decl_ss << "temp_fn_ptrs[fn_idx] = (void *) strtoull(ptr_str.c_str(), nullptr, 16);\n";
        ref_type_copy_decl_ss <<  list_kernel_signature_var_name << ".emplace_back(temp_fn_ptrs[fn_idx++]);\n";
        ref_type_copy_decl_ss << "}\n";
        arg_ctor_list.emplace_back(list_kernel_signature_var_name);
      } else {
        // Otherwise, just unpack the arg inline in the ctor call.
        std::stringstream ss;
+2 −0
Original line number Diff line number Diff line
@@ -383,6 +383,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
    // C++: for (auto [idx, var] : enumerate(listvar))
    auto iter_container = context->testlist()->test()[0]->getText();
    std::string counter_expr = context->exprlist()->expr()[0]->getText();
    // Add the for loop variable to the tracking list as well.
    new_var = counter_expr;
    if (context->exprlist()->expr().size() > 1) {
      counter_expr = "[" + counter_expr;
      for (int i = 1; i < context->exprlist()->expr().size(); i++) {
+1 −1
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ using AllowedKernelArgTypes =
                  xacc::internal_compiler::qubit, std::vector<double>,
                  std::vector<int>, qcor::PauliOperator, qcor::FermionOperator,
                  qcor::PairList<int>, std::vector<qcor::PauliOperator>,
                  std::vector<qcor::FermionOperator>>;
                  std::vector<qcor::FermionOperator>, std::vector<std::string>>;

// We will take as input a mapping of arg variable names to the argument itself.
using KernelArgDict = std::map<std::string, AllowedKernelArgTypes>;
+57 −10
Original line number Diff line number Diff line
@@ -408,22 +408,36 @@ class qjit(object):
                cpp_arg_str += ',' + \
                    'int& ' + arg
                continue
            if str(_type).startswith('typing.Callable'):
                cpp_type_str = 'KernelSignature<'
                for i in range(len(_type.__args__) - 1):
                
            # Helper to parse Python KernelSignature type annotation     
            def construct_callable_signature(clb_type):
                result_type_str = 'KernelSignature<'
                for i in range(len(clb_type.__args__) - 1):
                    # print("input type:", _type.__args__[i])
                    arg_type = _type.__args__[i]
                    arg_type = clb_type.__args__[i]
                    if str(arg_type) not in self.allowed_type_cpp_map:
                        print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                        print('Error, this quantum kernel arg type is not allowed: ', str(clb_type))
                        exit(1)
                    cpp_type_str += self.allowed_type_cpp_map[str(arg_type)]
                    cpp_type_str += ','
                    result_type_str += self.allowed_type_cpp_map[str(arg_type)]
                    result_type_str += ','
                
                cpp_type_str = cpp_type_str[:-1]
                cpp_type_str += '>'
                result_type_str = result_type_str[:-1]
                result_type_str += '>'
                return result_type_str

            # Single Callable argument
            if str(_type).startswith('typing.Callable'):
                cpp_type_str = construct_callable_signature(_type)
                # print("cpp type", cpp_type_str)
                cpp_arg_str += ',' + cpp_type_str + ' ' + arg
                continue
            # List of KernelSignature
            if str(_type).startswith('typing.List[typing.Callable'):
                # Note: All the Callables in the list must have the same signature.
                # (and they should to be considered equivalent for grouping into a List)
                cpp_type_str = 'std::vector<' + construct_callable_signature(_type.__args__[0]) + '>'
                cpp_arg_str += ',' + cpp_type_str + ' ' + arg
                continue

            if str(_type) not in self.allowed_type_cpp_map:
                print('Error, this quantum kernel arg type is not allowed: ', str(_type))
@@ -718,6 +732,39 @@ class qjit(object):
                # Replace the argument (in the dict) with the function pointer
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = hex(fn_ptr)
            # List of callables:
            if arg_type_str.startswith('typing.List[typing.Callable['):
                callable_qjit_list = args_dict[arg_name]
                need_recompile = False
                for clb in callable_qjit_list:
                    if clb.kernel_name() not in self.sorted_kernel_dep:
                        # print('New kernel:', clb.kernel_name())
                        # Add a kernel dependency
                        self.__kernels__graph.addKernelDependency(self.function.__name__, clb.kernel_name())
                        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
                        need_recompile = True
                
                if need_recompile:
                    # Create a new QJIT
                    self._qjit = QJIT()
                    self._qjit.internal_python_jit_compile(self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers)
                
                clb_fn_ptrs = []
                for clb in callable_qjit_list:
                    if not isinstance(clb, qjit):
                        print('Invalid argument type for {}. A list of quantum kernels (qjit) is expected.'.format(arg_name))
                        exit(1)
                    # print("Kernel name:", clb.kernel_name())
                    # This should always be successful.
                    fn_ptr = self._qjit.get_kernel_function_ptr(clb.kernel_name())
                    if fn_ptr == 0:
                        print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(clb.kernel_name()))
                        exit(1)
                    clb_fn_ptrs.append(hex(fn_ptr))
                    
                # Replace the argument (in the dict) with the list of function pointers
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = clb_fn_ptrs
        
        return args_dict

+40 −0
Original line number Diff line number Diff line
@@ -92,5 +92,45 @@ class TestKernelJIT(unittest.TestCase):
        self.assertEqual(comp.getInstruction(2).name(), "X")
        self.assertEqual(comp.getInstruction(2).bits()[0], 2)

    def test_list_kernel_signature(self):
        set_qpu('qpp', {'shots':1024})
        @qjit
        def kernel_take_list(q: qreg, kernels_to_calls: List[KernelSignature(qubit)]):
            for f in kernels_to_calls:
                f(q[0])
        
        @qjit
        def kernel_take_list_ctrl(q: qreg, kernels_to_calls: List[KernelSignature(qubit)]):
            for f in kernels_to_calls:
                f.ctrl(q[1], q[0])

        @qjit
        def x_gate_kernel(q: qubit):
            X(q)

        @qjit
        def y_gate_kernel(q: qubit):
            Y(q)

        @qjit
        def z_gate_kernel(q: qubit):
            Z(q)

        q = qalloc(1)
        comp = kernel_take_list.extract_composite(q, [x_gate_kernel, y_gate_kernel, z_gate_kernel])
        print(comp)
        self.assertEqual(comp.nInstructions(), 3)
        self.assertEqual(comp.getInstruction(0).name(), "X")
        self.assertEqual(comp.getInstruction(1).name(), "Y")
        self.assertEqual(comp.getInstruction(2).name(), "Z")

        q2 = qalloc(2)
        comp1 = kernel_take_list_ctrl.extract_composite(q2, [x_gate_kernel, y_gate_kernel, z_gate_kernel])
        print(comp1)
        self.assertEqual(comp1.nInstructions(), 3)
        self.assertEqual(comp1.getInstruction(0).name(), "CNOT")
        self.assertEqual(comp1.getInstruction(1).name(), "CY")
        self.assertEqual(comp1.getInstruction(2).name(), "CZ")

if __name__ == '__main__':
  unittest.main()
 No newline at end of file
Loading