Commit c2a70e57 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Support KernelSignature in vector



With KernelSignature in a vector, we need to be able to reliably add the parent_kernel to the operator() call.
It's error-prone to try doing that at the token collector level, hence using a type-safe helper to iterate over the kernel arguments.
We'll pick up any KernelSignature vars and attach the proper parent_kernel => i.e. its operator() invocation w/o the CompositeInstruction arg will be using the correct parent composite.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 7ba86767
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -193,6 +193,16 @@ void QCORSyntaxHandler::GetReplacement(
  // We only support one buffer in FTQC mode atm.
  OS << "quantum::set_current_buffer(" << bufferNames[0] << ".results());\n";
  OS << "}\n";
  // Set the parent_kernel of this kernel to any KernelSignature
  // (even nested in container) so that even w/o parent_kernel added by the 
  // token collector, the KernelSignature will still be referring to the correct parent_kernel.
  // i.e. we can track KernelSignature coming from complex data container such as std::vector
  // rather than relying on the list of kernels in translation unit.
  OS << "init_kernel_signature_args(parent_kernel, " << program_parameters[0];
  for (int i = 1; i < program_arg_types.size(); i++) {
    OS << ", " << program_parameters[i];
  }
  OS << ");\n";
  OS << new_src << "\n";
  OS << "}\n";

+2 −2
Original line number Diff line number Diff line
@@ -738,7 +738,7 @@ class qjit(object):
                need_recompile = False
                for clb in callable_qjit_list:
                    if clb.kernel_name() not in self.sorted_kernel_dep:
                        print('New kernel:', clb.kernel_name())
                        # print('New kernel:', clb.kernel_name())
                        # Add a kernel dependency
                        self.__kernels__graph.addKernelDependency(self.function.__name__, clb.kernel_name())
                        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
@@ -754,7 +754,7 @@ class qjit(object):
                    if not isinstance(clb, qjit):
                        print('Invalid argument type for {}. A list of quantum kernels (qjit) is expected.'.format(arg_name))
                        exit(1)
                    print("Kernel name:", clb.kernel_name())
                    # print("Kernel name:", clb.kernel_name())
                    # This should always be successful.
                    fn_ptr = self._qjit.get_kernel_function_ptr(clb.kernel_name())
                    if fn_ptr == 0:
+30 −0
Original line number Diff line number Diff line
@@ -92,5 +92,35 @@ class TestKernelJIT(unittest.TestCase):
        self.assertEqual(comp.getInstruction(2).name(), "X")
        self.assertEqual(comp.getInstruction(2).bits()[0], 2)

    def test_list_kernel_signature(self):
        set_qpu('qpp', {'shots':1024})
        @qjit
        def kernel_take_list(q: qreg, kernels_to_calls: List[KernelSignature(qubit)]):
            for f in kernels_to_calls:
                f(q[0])
        
        @qjit
        def x_gate_kernel(q: qubit):
            print("Call X")
            X(q)

        @qjit
        def y_gate_kernel(q: qubit):
            print("Call Y")
            Y(q)

        @qjit
        def z_gate_kernel(q: qubit):
            print("Call Z")
            Z(q)

        q = qalloc(1)
        comp = kernel_take_list.extract_composite(q, [x_gate_kernel, y_gate_kernel, z_gate_kernel])
        print(comp)
        self.assertEqual(comp.nInstructions(), 3)
        self.assertEqual(comp.getInstruction(0).name(), "X")
        self.assertEqual(comp.getInstruction(1).name(), "Y")
        self.assertEqual(comp.getInstruction(2).name(), "Z")

if __name__ == '__main__':
  unittest.main()
 No newline at end of file
+64 −0
Original line number Diff line number Diff line
@@ -757,6 +757,7 @@ class KernelSignature {
  callable_function_ptr<Args...> &function_pointer;
  std::function<void(std::shared_ptr<xacc::CompositeInstruction>, Args...)>
      lambda_func;
  std::shared_ptr<xacc::CompositeInstruction> parent_kernel;

 public:
  // Here we set function_pointer to null and instead
@@ -803,6 +804,14 @@ class KernelSignature {
    function_pointer(ir, args...);
  }

  void operator()(Args... args) {
    operator()(parent_kernel, args...);
  }

  void set_parent_kernel(std::shared_ptr<xacc::CompositeInstruction> ir) {
    parent_kernel = ir;
  }

  void ctrl(std::shared_ptr<xacc::CompositeInstruction> ir,
            const std::vector<qubit> &ctrl_qbits, Args... args) {
    internal::apply_control(ir, ctrl_qbits, *this, args...);
@@ -867,6 +876,61 @@ class KernelSignature {
  }
};

// Templated helper to attach parent_kernel to any
// KernelSignature arguments even nested in a std::vector<KernelSignature>
// The reason is that the Token Collector relies on a list of kernel names
// in the translation unit to attach parent_kernel to the operator() call.
// For KernelSignature provided in a container, tracking these at the
// TokenCollector level is error-prone (e.g. need to track any array accesses).
// Hence, we iterate over all kernel arguments and attach the parent_kernel
// to any KernelSignature argument at the top of each kernel's operator() call
// in a type-safe manner.

// Last arg
inline void init_kernel_signature_args_impl(
    std::shared_ptr<xacc::CompositeInstruction> ir) {}
template <typename T, typename... ArgsType>
void init_kernel_signature_args_impl(
    std::shared_ptr<xacc::CompositeInstruction> ir, T &t, ArgsType &... Args);

// Main function: to be added by the token collector at the beginning
// of each kernel operator().
template <typename... T>
void init_kernel_signature_args(std::shared_ptr<xacc::CompositeInstruction> ir,
                                T &... multi_inputs) {
  init_kernel_signature_args_impl(ir, multi_inputs...);
}

// Base case: generic type T,
// just ignore, proceed to the next arg.
template <typename T, typename... ArgsType>
void init_kernel_signature_args_impl(
    std::shared_ptr<xacc::CompositeInstruction> ir, T &t, ArgsType &... Args) {
  init_kernel_signature_args(ir, Args...);
}

// Special case: this is a vector:
// iterate over all elements.
template <typename T, typename... ArgsType>
void init_kernel_signature_args_impl(
    std::shared_ptr<xacc::CompositeInstruction> ir, std::vector<T> &vec_arg,
    ArgsType... Args) {
  for (auto &el : vec_arg) {
    // Iterate the vector elements.
    init_kernel_signature_args_impl(ir, el);
  }
  // Proceed with the rest.
  init_kernel_signature_args(ir, Args...);
}

// Handle KernelSignature arg => set the parent kernel.
template <typename... ArgsType>
void init_kernel_signature_args_impl(
    std::shared_ptr<xacc::CompositeInstruction> ir,
    KernelSignature<ArgsType...> &kernel_signature) {
  kernel_signature.set_parent_kernel(ir);
}

namespace internal {
// KernelSignature is the base of all kernel-like objects
// and we use it to implement kernel modifiers && utilities.