Commit d3dee29c authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

major improvement on the qpu_lambda, now takes lambda as string, processes it,...


major improvement on the qpu_lambda, now takes lambda as string, processes it, automatically handles variable names and capture vars passed in. performance improvement too, we only call qjit compile once

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 871299ac
Loading
Loading
Loading
Loading
Loading
+9 −14
Original line number Diff line number Diff line
#include "qcor.hpp"
using namespace qcor;

int main() {
  qpu_lambda<> ansatz_X0X1(
      [](qreg q, double x) {
        qpu_lambda_body({

  auto ansatz_X0X1 = qpu_lambda([](qreg q, double x) {
    X(q[0]);
    Ry(q[1], x);
    CX(q[1], q[0]);
    H(q);
    Measure(q);
        })
      },
      qpu_lambda_variables({"q", "x"}, {}));
  });

  OptFunction obj(
      [&](const std::vector<double> &x, std::vector<double> &) {
        print("running ", x[0]);
        auto q = qalloc(2);
        ansatz_X0X1(q, x[0]);
        auto exp = q.exp_val_z();
        print(x[0], exp);
        print("<X0X1(",x[0],") = ", exp);
        return exp;
      },
      1);
+19 −20
Original line number Diff line number Diff line
@@ -4,24 +4,23 @@ int main(int argc, char** argv) {
  int n = argc;
  double m = 22;

  using namespace qcor;
 
  qpu_lambda<int, double> superposition(
      [](qreg q) {          // Provide the kernel lambda
        qpu_lambda_body({  // wrap function body in this macro
          print("n = ", n);
          print("m = ", m);
  auto a = qpu_lambda([](qreg q) {
      print("n was captured, and is ", n);
      print("m was captured, and is ", m);
      for (int i = 0; i < n; i++) {
        H(q[0]);
      }
      Measure(q[0]);
        })
      },
      qpu_lambda_variables({"q"},
                           {"n", "m"}),  // Must provide variable names in order
      n, m);                             // Must provide the captured variables
  }, n, m);

  auto q = qalloc(1);
  superposition(q);
  a(q);
  q.print();

  n = 2;
  m = 33.0;
  auto r = qalloc(1);
  print("running again to show capture variables are captured by reference");
  a(r);
  r.print();
}
+1 −0
Original line number Diff line number Diff line
@@ -234,6 +234,7 @@ const std::pair<std::string, std::string> QJIT::run_syntax_handler(
  std::vector<std::string> arg_types, arg_vars, bufferNames;
  auto args_split = split_args_signature(args_signature);
  for (auto &arg : args_split) {
    trim(arg);
    auto arg_var = split(arg, ' ');
    if (arg_var[0] == "qreg" || arg_var[0] == "xacc::internal_compiler::qreg") {
      bufferNames.push_back(arg_var[1]);
+63 −67
Original line number Diff line number Diff line
@@ -484,30 +484,12 @@ ONE_QUBIT_KERNEL_CTRL_ENABLER(Sdg, sdg)
// trailing variadic argument for the lambda class constructor. Once
// instantiated lambda invocation looks just like kernel invocation.

template <typename... Args>
using GenerateKernelBodyPtr = std::string (*)(Args...);

// This class is used to simplify the syntax of 
// passing kernel and capture variable names to the qpu_lambda
class qpu_lambda_variables {
 protected:
  std::vector<std::string> kernel_args;
  std::vector<std::string> capture_args;

 public:
  qpu_lambda_variables(std::initializer_list<std::string> k)
      : kernel_args(std::vector<std::string>(k)) {}
  qpu_lambda_variables(std::initializer_list<std::string> k,
                       std::initializer_list<std::string> c)
      : kernel_args(std::vector<std::string>(k)),
        capture_args(std::vector<std::string>(c)) {}
  std::vector<std::string> get_kernel_args() { return kernel_args; }
  std::vector<std::string> get_capture_args() { return capture_args; }
};

template <typename... Args>
class qpu_lambda {
template <typename... CaptureArgs>
class _qpu_lambda {
 private:

  // Private inner class for getting the type 
  // of a capture variable as a string at runtime
  class TupleToTypeArgString {
   protected:
    std::string &tmp;
@@ -538,66 +520,80 @@ class qpu_lambda {
    }
  };

 protected:
  void *f;
  std::tuple<Args...> capture_vars;
  qpu_lambda_variables variable_names_map;
  // Kernel lambda source string, has arg structure and body
  std::string &src_str;

 public:
  template <typename LambdaType>
  qpu_lambda(LambdaType &&ff, qpu_lambda_variables &&variable_names,
             Args... _capture_vars)
      : variable_names_map(std::move(variable_names)),
        capture_vars(std::forward_as_tuple(_capture_vars...)) {
    f = reinterpret_cast<void *>(+ff);
  }
  // Capture variable names
  std::string &capture_var_names;

  template <typename... FunctionArgs>
  void operator()(FunctionArgs... args) {
    auto casted = reinterpret_cast<GenerateKernelBodyPtr<FunctionArgs...>>(f);
    std::stringstream ss;
    QJIT qjit;
    // Map the args to a tuple
    auto kernel_args_tuple = std::make_tuple(args...);
  // Capture variables, stored in tuple
  std::tuple<CaptureArgs &...> capture_vars;

    // Get the kernel body as a string
    auto s = casted(args...);
  // Quantum Just-in-Time Compiler :)
  QJIT qjit;

    // Extract the kernel and capture variable names
    auto kernel_var_names = variable_names_map.get_kernel_args();
    auto capture_var_names = variable_names_map.get_capture_args();
 public:
  // Constructor, capture vars should be deduced without
  // specifying them since we're using C++17
  _qpu_lambda(std::string &&ff, std::string &&_capture_var_names,
              CaptureArgs &..._capture_vars)
      : src_str(ff),
        capture_var_names(_capture_var_names),
        capture_vars(std::forward_as_tuple(_capture_vars...)) {
    // Get the original args list
    auto first = src_str.find_first_of("(");
    auto last = src_str.find_first_of(")");
    auto tt = src_str.substr(first, last - first + 1);

    // Build up the function argument signature string
    // Need to append capture vars to this arg signature
    std::string capture_preamble = "";
    if (!capture_var_names.empty()) {
      std::string args_string = "";
    TupleToTypeArgString to(args_string, kernel_var_names);
    __internal__::tuple_for_each(kernel_args_tuple, to);
      TupleToTypeArgString co(args_string);
      __internal__::tuple_for_each(capture_vars, co);
    args_string = args_string.substr(0, args_string.length() - 1);

    std::string capture_preamble = "";
    for (auto [i, capture_name] : qcor::enumerate(capture_var_names)) {
      args_string = "," + args_string.substr(0, args_string.length() - 1);
      tt.insert(last - 2, args_string);
      capture_preamble += "\n";
      for (auto [i, capture_name] :
           qcor::enumerate(xacc::split(capture_var_names, ','))) {
        capture_preamble +=
            "auto " + capture_name + " = arg_" + std::to_string(i) + ";\n";
      }
    }

    // Extract the function body
    first = src_str.find_first_of("{");
    last = src_str.find_last_of("}");
    auto rr = src_str.substr(first, last - first + 1);

    // Reconstruct with new args signature and 
    // existing function body
    std::stringstream ss;
    ss << "__qpu__ void foo" << tt << rr;

    // Insert the capture preamble code
    s.insert(1, "\n" + capture_preamble);
    // Get as a string, and insert capture 
    // preamble if necessary
    auto jit_src = ss.str();
    first = jit_src.find_first_of("{");
    if (!capture_var_names.empty()) jit_src.insert(first + 1, capture_preamble);

    // Create the kernel string for QJIT
    ss << "__qpu__ void foo(" << args_string << ")\n" << s;
    // std::cout << "JITSRC:\n" << jit_src << "\n";
    // JIT Compile, storing the function pointers
    qjit.jit_compile(jit_src);
  }

    // Compile
    qjit.jit_compile(ss.str());
  template <typename... FunctionArgs>
  void operator()(FunctionArgs... args) {
    // Map the function args to a tuple
    auto kernel_args_tuple = std::make_tuple(args...);

    // Merge the kernel args and the capture vars and execute
    // Merge the function args and the capture vars and execute
    auto final_args_tuple = std::tuple_cat(kernel_args_tuple, capture_vars);
    std::apply([&](auto &&...args) { qjit.invoke("foo", args...); },
               final_args_tuple);
  }
};

#define STRINGIZE(A) #A
#define qpu_lambda_body(EXPR) return std::string(STRINGIZE(EXPR));
#define qpu_lambda(EXPR, ...) _qpu_lambda(#EXPR, #__VA_ARGS__, ##__VA_ARGS__)

}  // namespace qcor