Commit 03fbdf0d authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

More robust JIT symbol lookup



Turns out that we may look up the wrong function symbol for the base kernel.

The reason is that the base kernel function has the same name as the QuantumKernel class.
i.e. the constructor has the same name as the free-standing function.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 7ca241e8
Loading
Loading
Loading
Loading
+34 −4
Original line number Diff line number Diff line
@@ -500,12 +500,42 @@ void QJIT::jit_compile(const std::string &code,
  // Insert dependency kernels as well:
  std::unordered_map<std::string, std::string> mangled_kernel_dep_map;
  for (const auto &dep : kernel_dependency) {
    std::vector<std::string> matches;
    for (Function &f : *module) {
      auto name = f.getName().str();
      if (demangle(name.c_str()).find(dep) != std::string::npos) {
        mangled_kernel_dep_map[dep] = name;
        break;
      }
      const auto name = f.getName().str();
      const auto demangledName = demangle(name.c_str());
      // We look for the function with the signature:
      // KernelName(shared_ptr<CompositeInstruction, Args...)
      // The problem is that there is a class named KernelName as well
      // which has a constructor with the same signature.
      // The ctor one will have a demangled name of KernelName::KernelName(...)
      // hence, we tie break them by the length.
      // Looks for a call-like symbol
      const std::string pattern = dep + "(";
      // Looks for the one that has parent Composite in the arg
      const std::string subPattern = "CompositeInstruction";
      if (demangle(name.c_str()).find(pattern) != std::string::npos &&
          demangle(name.c_str()).find(subPattern) != std::string::npos) {
        // std::cout << dep << " --> " << name << "\n";
        // std::cout << "Demangle: " << demangle(name.c_str()) << "\n";
        matches.emplace_back(name);
      }
    }
    if (matches.size() > 0) {
      // std::cout << "Matches for " << dep << ":\n";
      // for (const auto &match : matches) {
      //   std::cout << match << "\n";
      // }

      const auto chosenMatch =
          *std::min_element(matches.begin(), matches.end(),
                            [&](const std::string &s1, const std::string &s2) {
                              return demangle(s1.c_str()).length() <
                                     demangle(s2.c_str()).length();
                            });
      // std::cout << "Select match: " << chosenMatch << ": "
      //           << demangle(chosenMatch.c_str()) << "\n";
      mangled_kernel_dep_map[dep] = chosenMatch;
    }
  }

+8 −4
Original line number Diff line number Diff line
@@ -346,6 +346,13 @@ template <typename... Args>
using callable_function_ptr =
    void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...);

template <typename... Args>
callable_function_ptr<Args...> callable_function_ptr_from_raw_ptr(void *f_ptr) {
  void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>, Args...) =
      (callable_function_ptr<Args...>)f_ptr;
  return kernel_functor;
}

template <typename... Args>
class KernelSignature {
 protected:
@@ -355,10 +362,7 @@ class KernelSignature {
  KernelSignature(callable_function_ptr<Args...> &&f) : function_pointer(f) {}
  // Ctor from raw void* funtion pointer.
  KernelSignature(void *f_ptr)
      : KernelSignature((callable_function_ptr<Args...>)f_ptr) {
    std::cout << "Contruct KernelSignature from function pointer: " << f_ptr
              << "\n";
  }
      : KernelSignature(callable_function_ptr_from_raw_ptr<Args...>(f_ptr)) {}

  void operator()(std::shared_ptr<xacc::CompositeInstruction> ir,
                  Args... args) {