Commit fb6c8fb8 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

A minimal fix to make qpu_lambda type casting more robust



Since we ultimately need to use the template parameter pack to cast the raw function pointer from JIT, we need to regularize the QJIT kernel rewritten from the lambda body.

The strategy here is to convert by-value type of simple types (int, double) to reference types so that type-forwarding in JIT with Args&& will work consistently.

Tested by: unit test cases: both by-value argument passing (vqe ansatz) and capture by reference (lambda captures other lambda, etc.)

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent ba747ca5
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -43,9 +43,9 @@ add_test(NAME multi_ctrl_test COMMAND ${CMAKE_BINARY_DIR}/qcor ${CMAKE_CURRENT_S
add_qcor_compile_and_exe_test(qrt_bell_ctrl bell/bell_control.cpp)

# Lambda tests
#add_qcor_compile_and_exe_test(qrt_qpu_lambda_simple qpu_lambda/lambda_test.cpp)
#add_qcor_compile_and_exe_test(qrt_qpu_lambda_bell qpu_lambda/lambda_test_bell.cpp)
#add_qcor_compile_and_exe_test(qrt_qpu_lambda_grover qpu_lambda/grover_lambda_oracle.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambda_simple qpu_lambda/lambda_test.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambda_bell qpu_lambda/lambda_test_bell.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambda_grover qpu_lambda/grover_lambda_oracle.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambdas_in_loop qpu_lambda/deuteron_lambda.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambda_deuteron qpu_lambda/deuteron_vqe.cpp)

+19 −0
Original line number Diff line number Diff line
@@ -6,17 +6,36 @@ int main() {
           6.125 * Z(1) + 5.907;

  auto ansatz = qpu_lambda([](qreg q, double x) {
    print("x = ", x);
    X(q[0]);
    Ry(q[1], x);
    CX(q[1], q[0]);
  });

  auto ansatz_take_vec = qpu_lambda([](qreg q, std::vector<double> x) {
    print("x = ", x[0]);
    X(q[0]);
    Ry(q[1], x[0]);
    CX(q[1], q[0]);
  });

  OptFunction opt_function(
      [&](std::vector<double> x) { return ansatz.observe(H, qalloc(2), x[0]); },
      1);

  OptFunction opt_function_vec(
      [&](std::vector<double> x) {
        return ansatz_take_vec.observe(H, qalloc(2), x);
      },
      1);

  auto optimizer = createOptimizer("nlopt");
  auto [ground_energy, opt_params] = optimizer->optimize(opt_function);
  print("Energy: ", ground_energy);
  qcor_expect(std::abs(ground_energy + 1.74886) < 0.1);

  auto [ground_energy_vec, opt_params_vec] =
      optimizer->optimize(opt_function_vec);
  print("Energy: ", ground_energy_vec);
  qcor_expect(std::abs(ground_energy_vec + 1.74886) < 0.1);
}
 No newline at end of file
+32 −0
Original line number Diff line number Diff line
@@ -56,4 +56,36 @@ int main(int argc, char** argv) {
  qcor_expect(rb.counts()["0"] > 400);
  qcor_expect(rb.counts()["1"] > 400);
  qcor_expect(rb.counts()["0"] + rb.counts()["1"] == 1024);

  // Test passing an r-val to lambda
  auto ansatz_X0X1 = qpu_lambda([](qreg q, double x) {
    print("ansatz: x = ", x);
    X(q[0]);
    Ry(q[1], x);
    CX(q[1], q[0]);
    H(q);
    Measure(q);
  });

  auto qtest = qalloc(2);
  // Pass an rval...
  ansatz_X0X1(qtest, 1.2334);
  auto exp = qtest.exp_val_z();
  print("<X0X1> = ", exp);

  // Test a loop:
  const std::vector<double> expectedResults{
      0.0,       -0.324699, -0.614213, -0.837166, -0.9694,
      -0.996584, -0.915773, -0.735724, -0.475947, -0.164595,
      0.164595,  0.475947,  0.735724,  0.915773,  0.996584,
      0.9694,    0.837166,  0.614213,  0.324699,  0.0};

  const auto angles = linspace(-M_PI, M_PI, 20);
  for (size_t i = 0; i < angles.size(); ++i) {
    auto buffer = qalloc(2);
    ansatz_X0X1(buffer, angles[i]);
    auto exp = buffer.exp_val_z();
    print("<X0X1>(", angles[i], ") = ", exp, "; expected:", expectedResults[i]);
    qcor_expect(std::abs(expectedResults[i] - exp) < 0.1);
  }
}
+26 −0
Original line number Diff line number Diff line
@@ -59,6 +59,8 @@ class QJIT {

  template <typename... Args>
  void invoke(const std::string &kernel_name, Args... args) {
    // Debug: print the Args... type
    // std::cout << "QJIT Invoke: " << __PRETTY_FUNCTION__ << "\n";
    auto f_ptr = kernel_name_to_f_ptr[kernel_name];
    void (*kernel_functor)(Args...) = (void (*)(Args...))f_ptr;
    kernel_functor(std::forward<Args>(args)...);
@@ -68,6 +70,30 @@ class QJIT {
  void invoke_with_parent(const std::string &kernel_name,
                          std::shared_ptr<xacc::CompositeInstruction> parent,
                          Args... args) {
    // Debug: print the Args... type
    // std::cout << "QJIT Invoke with Parent: " << __PRETTY_FUNCTION__ << "\n";
    auto f_ptr = kernel_name_to_f_ptr_with_parent[kernel_name];
    void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>,
                           Args...) =
        (void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...))f_ptr;
    kernel_functor(parent, std::forward<Args>(args)...);
  }

  // Invoke with type forwarding: Args &&
  template <typename... Args>
  void invoke_forwarding(const std::string &kernel_name, Args &&... args) {
    // std::cout << "QJIT Invoke: " << __PRETTY_FUNCTION__ << "\n";
    auto f_ptr = kernel_name_to_f_ptr[kernel_name];
    void (*kernel_functor)(Args...) = (void (*)(Args...))f_ptr;
    kernel_functor(std::forward<Args>(args)...);
  }

  // Invoke with type forwarding: Args &&
  template <typename... Args>
  void invoke_with_parent_forwarding(
      const std::string &kernel_name,
      std::shared_ptr<xacc::CompositeInstruction> parent, Args &&... args) {
    // std::cout << "QJIT Invoke with Parent: " << __PRETTY_FUNCTION__ << "\n";
    auto f_ptr = kernel_name_to_f_ptr_with_parent[kernel_name];
    void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>,
                           Args...) =
+105 −18
Original line number Diff line number Diff line
#pragma once
#include <optional>
#include "qcor_jit.hpp"
#include "qcor_observable.hpp"
#include "qcor_utils.hpp"
#include "qrt.hpp"
#include <optional>

namespace qcor {
enum class QrtType { NISQ, FTQC };
@@ -191,8 +191,7 @@ public:

  virtual ~QuantumKernel() {}

  template<typename... ArgTypes>
  friend class KernelSignature;
  template <typename... ArgTypes> friend class KernelSignature;
};

// We use the following to enable ctrl operations on our single
@@ -311,6 +310,89 @@ public:
    auto first = src_str.find_first_of("(");
    auto last = src_str.find_first_of(")");
    auto tt = src_str.substr(first, last - first + 1);
    // Parse the argument list
    const auto arg_type_and_names = [](const std::string &arg_string_decl)
        -> std::vector<std::pair<std::string, std::string>> {
      // std::cout << "HOWDY:" << arg_string_decl << "\n";
      std::vector<std::pair<std::string, std::string>> result;
      const auto args_string =
          arg_string_decl.substr(1, arg_string_decl.size() - 2);
      std::stack<char> grouping_chars;
      std::string type_name;
      std::string var_name;
      std::string temp;
      // std::cout << args_string << "\n";
      for (int i = 0; i < args_string.size(); ++i) {
        if (isspace(args_string[i]) && grouping_chars.empty()) {
          type_name = temp;
          temp.clear();
        } else if (args_string[i] == ',') {
          var_name = temp;
          if (var_name[0] == '&') {
            type_name += "&";
            var_name = var_name.substr(1);
          }
          result.emplace_back(std::make_pair(type_name, var_name));
          type_name.clear();
          var_name.clear();
          temp.clear();
        } else {
          temp.push_back(args_string[i]);
        }

        if (args_string[i] == '<') {
          grouping_chars.push(args_string[i]);
        }
        if (args_string[i] == '>') {
          assert(grouping_chars.top() == '<');
          grouping_chars.pop();
        }
      }

      // Last one:
      var_name = temp;
      if (var_name[0] == '&') {
        type_name += "&";
        var_name = var_name.substr(1);
      }
      result.emplace_back(std::make_pair(type_name, var_name));
      return result;
    }(tt);

    // Map simple type to its reference type so that the
    // we can use consistent type-forwarding
    // when casting the JIT raw function pointer.
    // Currently, looks like only these simple types are having problem
    // with perfect type forwarding.
    // i.e. by-value arguments of these types are incompatible with a by-ref
    // casted function.
    static const std::unordered_map<std::string, std::string>
        FORWARD_TYPE_CONVERSION_MAP{{"int", "const int&"},
                                    {"double", "const double&"}};
    std::vector<std::pair<std::string, std::string>> forward_types;
    for (const auto &[type, name] : arg_type_and_names) {
      // std::cout << type << " --> " << name << "\n";
      if (FORWARD_TYPE_CONVERSION_MAP.find(type) !=
          FORWARD_TYPE_CONVERSION_MAP.end()) {
        auto iter = FORWARD_TYPE_CONVERSION_MAP.find(type);
        forward_types.emplace_back(std::make_pair(iter->second, name));
      } else {
        forward_types.emplace_back(std::make_pair(type, name));
      }
    }

    // std::cout << "After\n";
    // Construct the new arg signature clause:
    std::string arg_clause_new;
    arg_clause_new.push_back('(');
    for (const auto &[type, name] : forward_types) {
      arg_clause_new.append(type);
      arg_clause_new.push_back(' ');
      arg_clause_new.append(name);
      arg_clause_new.push_back(',');
      // std::cout << type << " --> " << name << "\n";
    }
    arg_clause_new.pop_back();

    // Get the capture type:
    // By default "[]", pass by reference.
@@ -363,17 +445,19 @@ public:
      }

      tt.insert(last - capture_type.size(), args_string);
      arg_clause_new.append(args_string);
    }

    // Extract the function body
    first = src_str.find_first_of("{");
    last = src_str.find_last_of("}");
    auto rr = src_str.substr(first, last - first + 1);

    arg_clause_new.push_back(')');
    // std::cout << "New signature: " << arg_clause_new << "\n";
    // Reconstruct with new args signature and
    // existing function body
    std::stringstream ss;
    ss << "__qpu__ void foo" << tt << rr;
    ss << "__qpu__ void foo" << arg_clause_new << rr;

    // Get as a string, and insert capture
    // preamble if necessary
@@ -405,19 +489,19 @@ public:
      auto final_args_tuple = std::tuple_cat(kernel_args_tuple, capture_vars);
      std::apply(
          [&](auto &&... args) {
            qjit.invoke_with_parent("foo", parent, args...);
            qjit.invoke_with_parent_forwarding("foo", parent, args...);
          },
          final_args_tuple);
    } else if constexpr (std::conjunction_v<
                             std::is_copy_assignable<CaptureArgs>...>) {
      // constexpr compile-time check to prevent compiler from looking at this code path
      // if the capture variable is non-copyable, e.g. qpu_lambda.
      // constexpr compile-time check to prevent compiler from looking at this
      // code path if the capture variable is non-copyable, e.g. qpu_lambda.
      // By-value:
      auto final_args_tuple =
          std::tuple_cat(kernel_args_tuple, optional_copy_capture_vars.value());
      std::apply(
          [&](auto &&... args) {
            qjit.invoke_with_parent("foo", parent, args...);
            qjit.invoke_with_parent_forwarding("foo", parent, args...);
          },
          final_args_tuple);
    }
@@ -430,14 +514,14 @@ public:
      // By-ref
      // Merge the function args and the capture vars and execute
      auto final_args_tuple = std::tuple_cat(kernel_args_tuple, capture_vars);
      std::apply([&](auto &&... args) { qjit.invoke("foo", args...); },
      std::apply([&](auto &&... args) { qjit.invoke_forwarding("foo", args...); },
                 final_args_tuple);
    } else if constexpr (std::conjunction_v<
                             std::is_copy_assignable<CaptureArgs>...>) {
      // By-value
      auto final_args_tuple =
          std::tuple_cat(kernel_args_tuple, optional_copy_capture_vars.value());
      std::apply([&](auto &&... args) { qjit.invoke("foo", args...); },
      std::apply([&](auto &&... args) { qjit.invoke_forwarding("foo", args...); },
                 final_args_tuple);
    }
  }
@@ -465,7 +549,8 @@ public:
            const std::vector<int> &ctrl_idxs, FunctionArgs... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_idxs.size(); i++) {
      ctrl_qubit_vec.push_back({"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
      ctrl_qubit_vec.push_back(
          {"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
    }
    ctrl(ir, ctrl_qubit_vec, args...);
  }
@@ -505,8 +590,9 @@ public:
    return internal::print_kernel(callable, os, args...);
  }

  template <typename... FunctionArgs>
  void print_kernel(FunctionArgs... args) { print_kernel(std::cout, args...); }
  template <typename... FunctionArgs> void print_kernel(FunctionArgs... args) {
    print_kernel(std::cout, args...);
  }

  template <typename... FunctionArgs>
  std::size_t n_instructions(FunctionArgs... args) {
@@ -594,7 +680,8 @@ public:
            const std::vector<int> ctrl_idxs, Args... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_idxs.size(); i++) {
      ctrl_qubit_vec.push_back({"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
      ctrl_qubit_vec.push_back(
          {"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
    }
    ctrl(ir, ctrl_qubit_vec, args...);
  }