Commit f3d19512 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Enable type forwarding for qpu_lambda arguments as well



This will enable by-ref passing via lambda arguments in addition to capture vars.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 6af087a7
Loading
Loading
Loading
Loading
+61 −0
Original line number Diff line number Diff line
@@ -88,4 +88,65 @@ int main(int argc, char** argv) {
    print("<X0X1>(", angles[i], ") = ", exp, "; expected:", expectedResults[i]);
    qcor_expect(std::abs(expectedResults[i] - exp) < 0.1);
  }

  // Test by-ref argument...
  auto add_one = qpu_lambda([](qreg q, int &result) {
    print("add_one: result =", result);
    result++;
  });
  
  // capture add_one lambda and use by-ref arguments.
  auto add_two = qpu_lambda(
      [](qreg q, int &result) {
        add_one(q, result);
        add_one(q, result);
      },
      add_one);
  auto buffer_test = qalloc(2);
  int test_val = 1;

  add_one(buffer_test, test_val);
  qcor_expect(test_val == 2);

  add_two(buffer_test, test_val);
  qcor_expect(test_val == 4);

  auto add_one_copy = qpu_lambda([](qreg q, int result) {
    print("add_one: entry result =", result);
    result++;
    print("add_one: exit result =", result);
  });

  auto test_val_const = 12;
  add_one_copy(buffer_test, test_val_const);
  // Should stay the same
  qcor_expect(test_val_const == 12);

  auto count_qubits = qpu_lambda([](qreg q, int &result) {
    result = q.size();
  });

  int nb_qubits = 0;
  count_qubits(qalloc(20), nb_qubits);
  std::cout << "Count = " << nb_qubits << "\n";
  qcor_expect(nb_qubits == 20);

  auto vector_sum =
      qpu_lambda([](qreg q, std::vector<double> input, double &result) {
        result = 0.0;
        for (auto &val : input) {
          result = result + val;
        }
      });

  double check = 0.0;
  std::vector<double> vec_to_check { 1.0, 2.0, 3.0 };
  vector_sum(qalloc(1), vec_to_check, check);
  std::cout << "Sum: " << check << "\n";
  qcor_expect(std::abs(check - 6.0) < 1e-12);
  check = 0.0;
  // Inline construction
  vector_sum(qalloc(1), std::vector<double>{2.0, 4.0, 6.0}, check);
  std::cout << "Sum: " << check << "\n";
  qcor_expect(std::abs(check - 12.0) < 1e-12);
}
+27 −10
Original line number Diff line number Diff line
@@ -367,15 +367,18 @@ public:
    // i.e. by-value arguments of these types are incompatible with a by-ref
    // casted function.
    static const std::unordered_map<std::string, std::string>
        FORWARD_TYPE_CONVERSION_MAP{{"int", "const int&"},
                                    {"double", "const double&"}};
        FORWARD_TYPE_CONVERSION_MAP{{"int", "int&"},
                                    {"double", "double&"}};
    std::vector<std::pair<std::string, std::string>> forward_types;
    // Replicate by-value by create copies and restore the variables.
    std::vector<std::string> byval_casted_arg_names;
    for (const auto &[type, name] : arg_type_and_names) {
      // std::cout << type << " --> " << name << "\n";
      if (FORWARD_TYPE_CONVERSION_MAP.find(type) !=
          FORWARD_TYPE_CONVERSION_MAP.end()) {
        auto iter = FORWARD_TYPE_CONVERSION_MAP.find(type);
        forward_types.emplace_back(std::make_pair(iter->second, name));
        byval_casted_arg_names.emplace_back(name);
      } else {
        forward_types.emplace_back(std::make_pair(type, name));
      }
@@ -466,6 +469,18 @@ public:
    if (!capture_var_names.empty())
      jit_src.insert(first + 1, capture_preamble);
    
    if (!byval_casted_arg_names.empty()) {
      std::stringstream cache_string, restore_string;
      for (const auto& var: byval_casted_arg_names) {
        cache_string << "auto __" <<  var << "__cached__ = " << var << ";\n";
        restore_string << var << " = __" <<  var << "__cached__;\n";
      }
      const auto begin = jit_src.find_first_of("{");
      jit_src.insert(begin + 1, cache_string.str());
      const auto end = jit_src.find_last_of("}");
      jit_src.insert(end, restore_string.str());
    }

    // std::cout << "JITSRC:\n" << jit_src << "\n";
    // JIT Compile, storing the function pointers
    qjit.jit_compile(jit_src);
@@ -473,13 +488,13 @@ public:

  template <typename... FunctionArgs>
  void eval_with_parent(std::shared_ptr<CompositeInstruction> parent,
                        FunctionArgs... args) {
    this->operator()(parent, args...);
                        FunctionArgs &&... args) {
    this->operator()(parent, std::forward<FunctionArgs>(args)...);
  }

  template <typename... FunctionArgs>
  void operator()(std::shared_ptr<CompositeInstruction> parent,
                  FunctionArgs... args) {
                  FunctionArgs &&... args) {
    // Map the function args to a tuple
    auto kernel_args_tuple = std::forward_as_tuple(args...);

@@ -507,21 +522,23 @@ public:
    }
  }

  template <typename... FunctionArgs> void operator()(FunctionArgs... args) {
  template <typename... FunctionArgs> void operator()(FunctionArgs &&... args) {
    // Map the function args to a tuple
    auto kernel_args_tuple = std::forward_as_tuple(args...);
    if (!optional_copy_capture_vars.has_value()) {
      // By-ref
      // Merge the function args and the capture vars and execute
      auto final_args_tuple = std::tuple_cat(kernel_args_tuple, capture_vars);
      std::apply([&](auto &&... args) { qjit.invoke_forwarding("foo", args...); },
      std::apply(
          [&](auto &&... args) { qjit.invoke_forwarding("foo", args...); },
          final_args_tuple);
    } else if constexpr (std::conjunction_v<
                             std::is_copy_assignable<CaptureArgs>...>) {
      // By-value
      auto final_args_tuple =
          std::tuple_cat(kernel_args_tuple, optional_copy_capture_vars.value());
      std::apply([&](auto &&... args) { qjit.invoke_forwarding("foo", args...); },
      std::apply(
          [&](auto &&... args) { qjit.invoke_forwarding("foo", args...); },
          final_args_tuple);
    }
  }