Unverified Commit 3a6b5f6c authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #146 from tnguyen-ornl/tnguyen/lambda-type-forwarding

Consistent type forwarding for qpu_lambda
parents 5caf8879 f3d19512
Loading
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -43,9 +43,9 @@ add_test(NAME multi_ctrl_test COMMAND ${CMAKE_BINARY_DIR}/qcor ${CMAKE_CURRENT_S
add_qcor_compile_and_exe_test(qrt_bell_ctrl bell/bell_control.cpp)

# Lambda tests
#add_qcor_compile_and_exe_test(qrt_qpu_lambda_simple qpu_lambda/lambda_test.cpp)
#add_qcor_compile_and_exe_test(qrt_qpu_lambda_bell qpu_lambda/lambda_test_bell.cpp)
#add_qcor_compile_and_exe_test(qrt_qpu_lambda_grover qpu_lambda/grover_lambda_oracle.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambda_simple qpu_lambda/lambda_test.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambda_bell qpu_lambda/lambda_test_bell.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambda_grover qpu_lambda/grover_lambda_oracle.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambdas_in_loop qpu_lambda/deuteron_lambda.cpp)
add_qcor_compile_and_exe_test(qrt_qpu_lambda_deuteron qpu_lambda/deuteron_vqe.cpp)

+19 −0
Original line number Diff line number Diff line
@@ -6,17 +6,36 @@ int main() {
           6.125 * Z(1) + 5.907;

  auto ansatz = qpu_lambda([](qreg q, double x) {
    print("x = ", x);
    X(q[0]);
    Ry(q[1], x);
    CX(q[1], q[0]);
  });

  auto ansatz_take_vec = qpu_lambda([](qreg q, std::vector<double> x) {
    print("x = ", x[0]);
    X(q[0]);
    Ry(q[1], x[0]);
    CX(q[1], q[0]);
  });

  OptFunction opt_function(
      [&](std::vector<double> x) { return ansatz.observe(H, qalloc(2), x[0]); },
      1);

  OptFunction opt_function_vec(
      [&](std::vector<double> x) {
        return ansatz_take_vec.observe(H, qalloc(2), x);
      },
      1);

  auto optimizer = createOptimizer("nlopt");
  auto [ground_energy, opt_params] = optimizer->optimize(opt_function);
  print("Energy: ", ground_energy);
  qcor_expect(std::abs(ground_energy + 1.74886) < 0.1);

  auto [ground_energy_vec, opt_params_vec] =
      optimizer->optimize(opt_function_vec);
  print("Energy: ", ground_energy_vec);
  qcor_expect(std::abs(ground_energy_vec + 1.74886) < 0.1);
}
 No newline at end of file
+93 −0
Original line number Diff line number Diff line
@@ -56,4 +56,97 @@ int main(int argc, char** argv) {
  qcor_expect(rb.counts()["0"] > 400);
  qcor_expect(rb.counts()["1"] > 400);
  qcor_expect(rb.counts()["0"] + rb.counts()["1"] == 1024);

  // Test passing an r-val to lambda
  auto ansatz_X0X1 = qpu_lambda([](qreg q, double x) {
    print("ansatz: x = ", x);
    X(q[0]);
    Ry(q[1], x);
    CX(q[1], q[0]);
    H(q);
    Measure(q);
  });

  auto qtest = qalloc(2);
  // Pass an rval...
  ansatz_X0X1(qtest, 1.2334);
  auto exp = qtest.exp_val_z();
  print("<X0X1> = ", exp);

  // Test a loop:
  const std::vector<double> expectedResults{
      0.0,       -0.324699, -0.614213, -0.837166, -0.9694,
      -0.996584, -0.915773, -0.735724, -0.475947, -0.164595,
      0.164595,  0.475947,  0.735724,  0.915773,  0.996584,
      0.9694,    0.837166,  0.614213,  0.324699,  0.0};

  const auto angles = linspace(-M_PI, M_PI, 20);
  for (size_t i = 0; i < angles.size(); ++i) {
    auto buffer = qalloc(2);
    ansatz_X0X1(buffer, angles[i]);
    auto exp = buffer.exp_val_z();
    print("<X0X1>(", angles[i], ") = ", exp, "; expected:", expectedResults[i]);
    qcor_expect(std::abs(expectedResults[i] - exp) < 0.1);
  }

  // Test by-ref argument...
  auto add_one = qpu_lambda([](qreg q, int &result) {
    print("add_one: result =", result);
    result++;
  });
  
  // capture add_one lambda and use by-ref arguments.
  auto add_two = qpu_lambda(
      [](qreg q, int &result) {
        add_one(q, result);
        add_one(q, result);
      },
      add_one);
  auto buffer_test = qalloc(2);
  int test_val = 1;

  add_one(buffer_test, test_val);
  qcor_expect(test_val == 2);

  add_two(buffer_test, test_val);
  qcor_expect(test_val == 4);

  auto add_one_copy = qpu_lambda([](qreg q, int result) {
    print("add_one: entry result =", result);
    result++;
    print("add_one: exit result =", result);
  });

  auto test_val_const = 12;
  add_one_copy(buffer_test, test_val_const);
  // Should stay the same
  qcor_expect(test_val_const == 12);

  auto count_qubits = qpu_lambda([](qreg q, int &result) {
    result = q.size();
  });

  int nb_qubits = 0;
  count_qubits(qalloc(20), nb_qubits);
  std::cout << "Count = " << nb_qubits << "\n";
  qcor_expect(nb_qubits == 20);

  auto vector_sum =
      qpu_lambda([](qreg q, std::vector<double> input, double &result) {
        result = 0.0;
        for (auto &val : input) {
          result = result + val;
        }
      });

  double check = 0.0;
  std::vector<double> vec_to_check { 1.0, 2.0, 3.0 };
  vector_sum(qalloc(1), vec_to_check, check);
  std::cout << "Sum: " << check << "\n";
  qcor_expect(std::abs(check - 6.0) < 1e-12);
  check = 0.0;
  // Inline construction
  vector_sum(qalloc(1), std::vector<double>{2.0, 4.0, 6.0}, check);
  std::cout << "Sum: " << check << "\n";
  qcor_expect(std::abs(check - 12.0) < 1e-12);
}
+26 −0
Original line number Diff line number Diff line
@@ -59,6 +59,8 @@ class QJIT {

  template <typename... Args>
  void invoke(const std::string &kernel_name, Args... args) {
    // Debug: print the Args... type
    // std::cout << "QJIT Invoke: " << __PRETTY_FUNCTION__ << "\n";
    auto f_ptr = kernel_name_to_f_ptr[kernel_name];
    void (*kernel_functor)(Args...) = (void (*)(Args...))f_ptr;
    kernel_functor(std::forward<Args>(args)...);
@@ -68,6 +70,30 @@ class QJIT {
  void invoke_with_parent(const std::string &kernel_name,
                          std::shared_ptr<xacc::CompositeInstruction> parent,
                          Args... args) {
    // Debug: print the Args... type
    // std::cout << "QJIT Invoke with Parent: " << __PRETTY_FUNCTION__ << "\n";
    auto f_ptr = kernel_name_to_f_ptr_with_parent[kernel_name];
    void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>,
                           Args...) =
        (void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...))f_ptr;
    kernel_functor(parent, std::forward<Args>(args)...);
  }

  // Invoke with type forwarding: Args &&
  template <typename... Args>
  void invoke_forwarding(const std::string &kernel_name, Args &&... args) {
    // std::cout << "QJIT Invoke: " << __PRETTY_FUNCTION__ << "\n";
    auto f_ptr = kernel_name_to_f_ptr[kernel_name];
    void (*kernel_functor)(Args...) = (void (*)(Args...))f_ptr;
    kernel_functor(std::forward<Args>(args)...);
  }

  // Invoke with type forwarding: Args &&
  template <typename... Args>
  void invoke_with_parent_forwarding(
      const std::string &kernel_name,
      std::shared_ptr<xacc::CompositeInstruction> parent, Args &&... args) {
    // std::cout << "QJIT Invoke with Parent: " << __PRETTY_FUNCTION__ << "\n";
    auto f_ptr = kernel_name_to_f_ptr_with_parent[kernel_name];
    void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>,
                           Args...) =
+128 −24
Original line number Diff line number Diff line
#pragma once
#include <optional>
#include "qcor_jit.hpp"
#include "qcor_observable.hpp"
#include "qcor_utils.hpp"
#include "qrt.hpp"
#include <optional>

namespace qcor {
enum class QrtType { NISQ, FTQC };
@@ -191,8 +191,7 @@ public:

  virtual ~QuantumKernel() {}

  template<typename... ArgTypes>
  friend class KernelSignature;
  template <typename... ArgTypes> friend class KernelSignature;
};

// We use the following to enable ctrl operations on our single
@@ -311,6 +310,92 @@ public:
    auto first = src_str.find_first_of("(");
    auto last = src_str.find_first_of(")");
    auto tt = src_str.substr(first, last - first + 1);
    // Parse the argument list
    const auto arg_type_and_names = [](const std::string &arg_string_decl)
        -> std::vector<std::pair<std::string, std::string>> {
      // std::cout << "HOWDY:" << arg_string_decl << "\n";
      std::vector<std::pair<std::string, std::string>> result;
      const auto args_string =
          arg_string_decl.substr(1, arg_string_decl.size() - 2);
      std::stack<char> grouping_chars;
      std::string type_name;
      std::string var_name;
      std::string temp;
      // std::cout << args_string << "\n";
      for (int i = 0; i < args_string.size(); ++i) {
        if (isspace(args_string[i]) && grouping_chars.empty()) {
          type_name = temp;
          temp.clear();
        } else if (args_string[i] == ',') {
          var_name = temp;
          if (var_name[0] == '&') {
            type_name += "&";
            var_name = var_name.substr(1);
          }
          result.emplace_back(std::make_pair(type_name, var_name));
          type_name.clear();
          var_name.clear();
          temp.clear();
        } else {
          temp.push_back(args_string[i]);
        }

        if (args_string[i] == '<') {
          grouping_chars.push(args_string[i]);
        }
        if (args_string[i] == '>') {
          assert(grouping_chars.top() == '<');
          grouping_chars.pop();
        }
      }

      // Last one:
      var_name = temp;
      if (var_name[0] == '&') {
        type_name += "&";
        var_name = var_name.substr(1);
      }
      result.emplace_back(std::make_pair(type_name, var_name));
      return result;
    }(tt);

    // Map simple type to its reference type so that the
    // we can use consistent type-forwarding
    // when casting the JIT raw function pointer.
    // Currently, looks like only these simple types are having problem
    // with perfect type forwarding.
    // i.e. by-value arguments of these types are incompatible with a by-ref
    // casted function.
    static const std::unordered_map<std::string, std::string>
        FORWARD_TYPE_CONVERSION_MAP{{"int", "int&"},
                                    {"double", "double&"}};
    std::vector<std::pair<std::string, std::string>> forward_types;
    // Replicate by-value by create copies and restore the variables.
    std::vector<std::string> byval_casted_arg_names;
    for (const auto &[type, name] : arg_type_and_names) {
      // std::cout << type << " --> " << name << "\n";
      if (FORWARD_TYPE_CONVERSION_MAP.find(type) !=
          FORWARD_TYPE_CONVERSION_MAP.end()) {
        auto iter = FORWARD_TYPE_CONVERSION_MAP.find(type);
        forward_types.emplace_back(std::make_pair(iter->second, name));
        byval_casted_arg_names.emplace_back(name);
      } else {
        forward_types.emplace_back(std::make_pair(type, name));
      }
    }

    // std::cout << "After\n";
    // Construct the new arg signature clause:
    std::string arg_clause_new;
    arg_clause_new.push_back('(');
    for (const auto &[type, name] : forward_types) {
      arg_clause_new.append(type);
      arg_clause_new.push_back(' ');
      arg_clause_new.append(name);
      arg_clause_new.push_back(',');
      // std::cout << type << " --> " << name << "\n";
    }
    arg_clause_new.pop_back();

    // Get the capture type:
    // By default "[]", pass by reference.
@@ -363,17 +448,19 @@ public:
      }

      tt.insert(last - capture_type.size(), args_string);
      arg_clause_new.append(args_string);
    }

    // Extract the function body
    first = src_str.find_first_of("{");
    last = src_str.find_last_of("}");
    auto rr = src_str.substr(first, last - first + 1);

    arg_clause_new.push_back(')');
    // std::cout << "New signature: " << arg_clause_new << "\n";
    // Reconstruct with new args signature and
    // existing function body
    std::stringstream ss;
    ss << "__qpu__ void foo" << tt << rr;
    ss << "__qpu__ void foo" << arg_clause_new << rr;

    // Get as a string, and insert capture
    // preamble if necessary
@@ -382,6 +469,18 @@ public:
    if (!capture_var_names.empty())
      jit_src.insert(first + 1, capture_preamble);
    
    if (!byval_casted_arg_names.empty()) {
      std::stringstream cache_string, restore_string;
      for (const auto& var: byval_casted_arg_names) {
        cache_string << "auto __" <<  var << "__cached__ = " << var << ";\n";
        restore_string << var << " = __" <<  var << "__cached__;\n";
      }
      const auto begin = jit_src.find_first_of("{");
      jit_src.insert(begin + 1, cache_string.str());
      const auto end = jit_src.find_last_of("}");
      jit_src.insert(end, restore_string.str());
    }

    // std::cout << "JITSRC:\n" << jit_src << "\n";
    // JIT Compile, storing the function pointers
    qjit.jit_compile(jit_src);
@@ -389,13 +488,13 @@ public:

  template <typename... FunctionArgs>
  void eval_with_parent(std::shared_ptr<CompositeInstruction> parent,
                        FunctionArgs... args) {
    this->operator()(parent, args...);
                        FunctionArgs &&... args) {
    this->operator()(parent, std::forward<FunctionArgs>(args)...);
  }

  template <typename... FunctionArgs>
  void operator()(std::shared_ptr<CompositeInstruction> parent,
                  FunctionArgs... args) {
                  FunctionArgs &&... args) {
    // Map the function args to a tuple
    auto kernel_args_tuple = std::forward_as_tuple(args...);

@@ -405,39 +504,41 @@ public:
      auto final_args_tuple = std::tuple_cat(kernel_args_tuple, capture_vars);
      std::apply(
          [&](auto &&... args) {
            qjit.invoke_with_parent("foo", parent, args...);
            qjit.invoke_with_parent_forwarding("foo", parent, args...);
          },
          final_args_tuple);
    } else if constexpr (std::conjunction_v<
                             std::is_copy_assignable<CaptureArgs>...>) {
      // constexpr compile-time check to prevent compiler from looking at this code path
      // if the capture variable is non-copyable, e.g. qpu_lambda.
      // constexpr compile-time check to prevent compiler from looking at this
      // code path if the capture variable is non-copyable, e.g. qpu_lambda.
      // By-value:
      auto final_args_tuple =
          std::tuple_cat(kernel_args_tuple, optional_copy_capture_vars.value());
      std::apply(
          [&](auto &&... args) {
            qjit.invoke_with_parent("foo", parent, args...);
            qjit.invoke_with_parent_forwarding("foo", parent, args...);
          },
          final_args_tuple);
    }
  }

  template <typename... FunctionArgs> void operator()(FunctionArgs... args) {
  template <typename... FunctionArgs> void operator()(FunctionArgs &&... args) {
    // Map the function args to a tuple
    auto kernel_args_tuple = std::forward_as_tuple(args...);
    if (!optional_copy_capture_vars.has_value()) {
      // By-ref
      // Merge the function args and the capture vars and execute
      auto final_args_tuple = std::tuple_cat(kernel_args_tuple, capture_vars);
      std::apply([&](auto &&... args) { qjit.invoke("foo", args...); },
      std::apply(
          [&](auto &&... args) { qjit.invoke_forwarding("foo", args...); },
          final_args_tuple);
    } else if constexpr (std::conjunction_v<
                             std::is_copy_assignable<CaptureArgs>...>) {
      // By-value
      auto final_args_tuple =
          std::tuple_cat(kernel_args_tuple, optional_copy_capture_vars.value());
      std::apply([&](auto &&... args) { qjit.invoke("foo", args...); },
      std::apply(
          [&](auto &&... args) { qjit.invoke_forwarding("foo", args...); },
          final_args_tuple);
    }
  }
@@ -465,7 +566,8 @@ public:
            const std::vector<int> &ctrl_idxs, FunctionArgs... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_idxs.size(); i++) {
      ctrl_qubit_vec.push_back({"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
      ctrl_qubit_vec.push_back(
          {"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
    }
    ctrl(ir, ctrl_qubit_vec, args...);
  }
@@ -505,8 +607,9 @@ public:
    return internal::print_kernel(callable, os, args...);
  }

  template <typename... FunctionArgs>
  void print_kernel(FunctionArgs... args) { print_kernel(std::cout, args...); }
  template <typename... FunctionArgs> void print_kernel(FunctionArgs... args) {
    print_kernel(std::cout, args...);
  }

  template <typename... FunctionArgs>
  std::size_t n_instructions(FunctionArgs... args) {
@@ -594,7 +697,8 @@ public:
            const std::vector<int> ctrl_idxs, Args... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_idxs.size(); i++) {
      ctrl_qubit_vec.push_back({"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
      ctrl_qubit_vec.push_back(
          {"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
    }
    ctrl(ir, ctrl_qubit_vec, args...);
  }