quantum_kernel.hpp 44.2 KB
Newer Older
1
#pragma once
2
3
#include <optional>

4
#include "qcor_jit.hpp"
5
#include "qcor_observable.hpp"
6
#include "qcor_utils.hpp"
7
#include "qrt.hpp"
8
9

namespace qcor {
10
// Runtime flavor a kernel runs under. QuantumKernel's constructors pick FTQC
// when __qrt_env == "ftqc" and NISQ otherwise.
enum class QrtType { NISQ, FTQC };
11

12
// Forward declare
13
14
// Forward declaration only: the class is defined further down in this header,
// after the kernel types that construct it.
template <typename... Args>
class KernelSignature;
15
16
17
18

namespace internal {
// KernelSignature is the base of all kernel-like objects
// and we use it to implement kernel modifiers & utilities.
19
// i.e., anything that is KernelSignature-constructible can use these methods.
20
21
22
23
// Declaration only; the definition is not in this part of the header.
// Called by the ctrl(...) helpers of QuantumKernel and _qpu_lambda with the
// kernel wrapped as a KernelSignature, the control qubits and the kernel
// arguments. Presumably appends the controlled circuit to parent_kernel --
// TODO confirm against the definition.
template <typename... Args>
void apply_control(std::shared_ptr<CompositeInstruction> parent_kernel,
                   const std::vector<qubit> &ctrl_qbits,
                   KernelSignature<Args...> &kernelCallable, Args... args);
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

// The following are declarations only (definitions are not in this part of
// the header). Each one takes the kernel as a KernelSignature plus the kernel
// arguments, and backs the same-named method on QuantumKernel / _qpu_lambda.

// Backs adjoint(parent_kernel, args...).
template <typename... Args>
void apply_adjoint(std::shared_ptr<CompositeInstruction> parent_kernel,
                   KernelSignature<Args...> &kernelCallable, Args... args);

// Backs observe(obs, args...); returns a double.
template <typename... Args>
double observe(Observable &obs, KernelSignature<Args...> &kernelCallable,
               Args... args);

// Backs as_unitary_matrix(args...); returns a complex Eigen matrix.
template <typename... Args>
Eigen::MatrixXcd as_unitary_matrix(KernelSignature<Args...> &kernelCallable,
                                   Args... args);
// Backs openqasm(args...); returns a string.
template <typename... Args>
std::string openqasm(KernelSignature<Args...> &kernelCallable, Args... args);

// Backs print_kernel(os, args...); writes to the given stream.
template <typename... Args>
void print_kernel(KernelSignature<Args...> &kernelCallable, std::ostream &os,
                  Args... args);

// Backs n_instructions(args...); returns a count.
template <typename... Args>
std::size_t n_instructions(KernelSignature<Args...> &kernelCallable,
                           Args... args);
46
}  // namespace internal
47

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// The QuantumKernel represents the super-class of all qcor
// quantum kernel functors. Subclasses of this are auto-generated
// via the Clang Syntax Handler capability. Derived classes
// provide a destructor implementation that builds up and
// submits quantum instructions to the specified backend. This enables
// functor-like capability whereby programmers instantiate temporary
// instances of this via the constructor call, and the destructor is
// immediately called. More advanced usage is of course possible for
// qcor developers.
//
// This class works by taking the Derived type (CRTP) and the kernel function
// arguments as template parameters. The Derived type is therefore available for
// instantiation within provided static methods on QuantumKernel. The Args...
// are stored in a member tuple, and are available for use when evaluating the
// kernel. Importantly, QuantumKernel provides static adjoint and ctrl methods
// for auto-generating those circuits.
//
// The Syntax Handler will take kernels like this
// __qpu__ void foo(qreg q) { H(q[0]); }
// and create a derived type of QuantumKernel like this
// class foo : public qcor::QuantumKernel<class foo, qreg> {...};
// with an appropriate implementation of constructors and destructors.
// Users can then call for adjoint/ctrl methods like this
// foo::adjoint(q); foo::ctrl(1, q);
72
73
74
template <typename Derived, typename... Args>
class QuantumKernel {
 protected:
75
76
77
78
79
80
81
82
83
84
85
  // Tuple holder for variadic kernel arguments
  std::tuple<Args...> args_tuple;

  // Parent kernel - null if this is the top-level kernel
  // not null if this is a nested kernel call
  std::shared_ptr<qcor::CompositeInstruction> parent_kernel;

  // Default, submit this kernel, if parent is given
  // turn this to false
  bool is_callable = true;

86
  // Turn off destructor execution, useful for
87
  // qcor developers, not to be used by clients / programmers
88
  bool disable_destructor = false;
89

90
 public:
91
  // Flag to indicate we only want to
92
93
  // run the pass manager and not execute
  bool optimize_only = false;
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
94
  QrtType runtime_env = QrtType::NISQ;
95
  // Default constructor, takes quantum kernel function arguments
96
97
98
  QuantumKernel(Args... args) : args_tuple(std::forward_as_tuple(args...)) {
    runtime_env = (__qrt_env == "ftqc") ? QrtType::FTQC : QrtType::NISQ;
  }
99
100
101
102
103
104

  // Internal constructor, provide parent kernel, this
  // kernel now represents a nested kernel call and
  // appends to the parent kernel
  QuantumKernel(std::shared_ptr<qcor::CompositeInstruction> _parent_kernel,
                Args... args)
105
      : args_tuple(std::forward_as_tuple(args...)),
106
107
        parent_kernel(_parent_kernel),
        is_callable(false) {
108
109
    runtime_env = (__qrt_env == "ftqc") ? QrtType::FTQC : QrtType::NISQ;
  }
110

111
  // Static method for printing this kernel as a flat qasm string
112
113
  static void print_kernel(std::ostream &os, Args... args) {
    Derived derived(args...);
114
    KernelSignature<Args...> callable(derived);
115
    return internal::print_kernel<Args...>(callable, os, args...);
116
  }
117

118
  static void print_kernel(Args... args) { print_kernel(std::cout, args...); }
119

120
121
122
  // Static method to query how many instructions are in this kernel
  static std::size_t n_instructions(Args... args) {
    Derived derived(args...);
123
    KernelSignature<Args...> callable(derived);
124
    return internal::n_instructions<Args...>(callable, args...);
125
126
  }

127
  // Create the Adjoint of this quantum kernel
128
129
  static void adjoint(std::shared_ptr<CompositeInstruction> parent_kernel,
                      Args... args) {
130
    Derived derived(args...);
131
    KernelSignature<Args...> callable(derived);
132
    return internal::apply_adjoint<Args...>(parent_kernel, callable, args...);
133
134
  }

135
136
  // Create the controlled version of this quantum kernel
  static void ctrl(std::shared_ptr<CompositeInstruction> parent_kernel,
137
                   const std::vector<int> &ctrlIdx, Args... args) {
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
138
139
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrlIdx.size(); i++) {
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
140
      ctrl_qubit_vec.push_back({"q", static_cast<size_t>(ctrlIdx[i]), nullptr});
141
    }
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
142
    ctrl(parent_kernel, ctrl_qubit_vec, args...);
143
144
145
146
147
  }

  // Single-qubit overload
  static void ctrl(std::shared_ptr<CompositeInstruction> parent_kernel,
                   int ctrlIdx, Args... args) {
148
    ctrl(parent_kernel, std::vector<int>{ctrlIdx}, args...);
149
150
  }

151
152
153
154
155
156
157
158
  static void ctrl(std::shared_ptr<CompositeInstruction> parent_kernel,
                   qreg ctrl_qbits, Args... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_qbits.size(); i++)
      ctrl_qubit_vec.push_back(ctrl_qbits[i]);

    ctrl(parent_kernel, ctrl_qubit_vec, args...);
  }
159

160
161
162
163
  static void ctrl(std::shared_ptr<CompositeInstruction> parent_kernel,
                   const std::vector<qubit> &ctrl_qbits, Args... args) {
    // instantiate and don't let it call the destructor
    Derived derived(args...);
164
    KernelSignature<Args...> callable(derived);
165
166
    internal::apply_control<Args...>(parent_kernel, ctrl_qbits, callable,
                                     args...);
167
  }
168

Mccaskey, Alex's avatar
Mccaskey, Alex committed
169
170
171
  // Create the controlled version of this quantum kernel
  static void ctrl(std::shared_ptr<CompositeInstruction> parent_kernel,
                   qubit ctrl_qbit, Args... args) {
172
    ctrl(parent_kernel, std::vector<qubit>{ctrl_qbit}, args...);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
173
  }
174
175
176

  static Eigen::MatrixXcd as_unitary_matrix(Args... args) {
    Derived derived(args...);
177
    KernelSignature<Args...> callable(derived);
178
    return internal::as_unitary_matrix<Args...>(callable, args...);
179
180
  }

181
  static double observe(Observable &obs, Args... args) {
Mccaskey, Alex's avatar
Mccaskey, Alex committed
182
    Derived derived(args...);
183
    KernelSignature<Args...> callable(derived);
184
    return internal::observe<Args...>(obs, callable, args...);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
185
186
187
  }

  static double observe(std::shared_ptr<Observable> obs, Args... args) {
188
    return observe(*obs, args...);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
189
  }
190

191
192
193
194
195
196
197
  // Simple autograd support for kernel with simple type: double or
  // vector<double>. Other signatures must provide a translator...
  static double autograd(Observable &obs, std::vector<double> &dx, qreg q,
                         double x) {
    std::function<std::shared_ptr<xacc::CompositeInstruction>(
        std::vector<double>)>
        kernel_eval = [q](std::vector<double> x) {
198
199
          auto tempKernel =
              qcor::__internal__::create_composite("__temp__autograd__");
200
          Derived derived(q, x[0]);
201
202
203
204
          derived.disable_destructor = true;
          derived(q, x[0]);
          tempKernel->addInstructions(derived.parent_kernel->getInstructions());
          return tempKernel;
205
206
207
208
209
210
211
212
213
        };

    auto gradiend_method = qcor::__internal__::get_gradient_method(
        qcor::__internal__::DEFAULT_GRADIENT_METHOD, kernel_eval, obs);
    const double cost_val = observe(obs, q, x);
    dx = (*gradiend_method)({x}, cost_val);
    return cost_val;
  }

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
  static double autograd(Observable &obs, std::vector<double> &dx, qreg q,
                         std::vector<double> x) {
    std::function<std::shared_ptr<xacc::CompositeInstruction>(
        std::vector<double>)>
        kernel_eval = [q](std::vector<double> x) {
          auto tempKernel =
              qcor::__internal__::create_composite("__temp__autograd__");
          Derived derived(q, x);
          derived.disable_destructor = true;
          derived(q, x);
          tempKernel->addInstructions(derived.parent_kernel->getInstructions());
          return tempKernel;
        };

    auto gradiend_method = qcor::__internal__::get_gradient_method(
        qcor::__internal__::DEFAULT_GRADIENT_METHOD, kernel_eval, obs);
    const double cost_val = observe(obs, q, x);
    dx = (*gradiend_method)(x, cost_val);
    return cost_val;
  }

235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
  static double autograd(Observable &obs, std::vector<double> &dx,
                         std::vector<double> x,
                         ArgsTranslator<Args...> args_translator) {
    std::function<std::shared_ptr<xacc::CompositeInstruction>(
        std::vector<double>)>
        kernel_eval = [&](std::vector<double> x_vec) {
          auto eval_lambda = [&](Args... args) {
            auto tempKernel =
                qcor::__internal__::create_composite("__temp__autograd__");
            Derived derived(args...);
            derived.disable_destructor = true;
            derived(args...);
            tempKernel->addInstructions(
                derived.parent_kernel->getInstructions());
            return tempKernel;
          };
          auto args_tuple = args_translator(x_vec);
          return std::apply(eval_lambda, args_tuple);
        };

    auto gradiend_method = qcor::__internal__::get_gradient_method(
        qcor::__internal__::DEFAULT_GRADIENT_METHOD, kernel_eval, obs);

    auto kernel_observe = [&](Args... args) { return observe(obs, args...); };

    auto args_tuple = args_translator(x);
    const double cost_val = std::apply(kernel_observe, args_tuple);
    dx = (*gradiend_method)(x, cost_val);
    return cost_val;
  }

266
267
  static std::string openqasm(Args... args) {
    Derived derived(args...);
268
    KernelSignature<Args...> callable(derived);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
269
    return internal::openqasm<Args...>(callable, args...);
270
271
  }

272
  virtual ~QuantumKernel() {}
273

274
275
  template <typename... ArgTypes>
  friend class KernelSignature;
276
};
277

278
279
// We use the following to enable ctrl operations on our single
// qubit gates, X::ctrl(), Z::ctrl(), H::ctrl(), etc....
280
281
282
// Shorthand for a QuantumKernel whose only kernel argument is one qubit.
template <typename Derived>
using OneQubitKernel = QuantumKernel<Derived, qubit>;

283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
// Defines a kernel class CLASSNAME deriving from OneQubitKernel<CLASSNAME>,
// so that the static QuantumKernel helpers (ctrl, adjoint, ...) are available
// on a single-qubit gate.
//  - CLASSNAME(qubit): regular construction.
//  - CLASSNAME(parent_kernel, qubit): always throws std::runtime_error.
//  - operator()(qubit): creates a fresh composite, stores it in parent_kernel,
//    makes it the current program, sets the current buffer to q.results()
//    when running under FTQC, then calls ::quantum::QRTNAME(q).
// (Comments live outside the macro body because of the line continuations.)
#define ONE_QUBIT_KERNEL_CTRL_ENABLER(CLASSNAME, QRTNAME)                 \
  class CLASSNAME : public OneQubitKernel<class CLASSNAME> {              \
   public:                                                                \
    CLASSNAME(qubit q) : OneQubitKernel<CLASSNAME>(q) {}                  \
    CLASSNAME(std::shared_ptr<qcor::CompositeInstruction> _parent_kernel, \
              qubit q)                                                    \
        : OneQubitKernel<CLASSNAME>(_parent_kernel, q) {                  \
      throw std::runtime_error("you cannot call this.");                  \
    }                                                                     \
    void operator()(qubit q) {                                            \
      parent_kernel = qcor::__internal__::create_composite(               \
          "__tmp_one_qubit_ctrl_enabler");                                \
      quantum::set_current_program(parent_kernel);                        \
      if (runtime_env == QrtType::FTQC) {                                 \
        quantum::set_current_buffer(q.results());                         \
      }                                                                   \
      ::quantum::QRTNAME(q);                                              \
      return;                                                             \
    }                                                                     \
    virtual ~CLASSNAME() {}                                               \
  };
304
305
306
307
308
309
310
311
312
313

// Generate one kernel class per single-qubit gate; the second macro argument
// names the ::quantum:: function that the class's operator() calls.
ONE_QUBIT_KERNEL_CTRL_ENABLER(X, x)
ONE_QUBIT_KERNEL_CTRL_ENABLER(Y, y)
ONE_QUBIT_KERNEL_CTRL_ENABLER(Z, z)
ONE_QUBIT_KERNEL_CTRL_ENABLER(H, h)
ONE_QUBIT_KERNEL_CTRL_ENABLER(T, t)
ONE_QUBIT_KERNEL_CTRL_ENABLER(Tdg, tdg)
ONE_QUBIT_KERNEL_CTRL_ENABLER(S, s)
ONE_QUBIT_KERNEL_CTRL_ENABLER(Sdg, sdg)

314
315
316
317
318
319
320
321
322
323
324
325
326
// The following is a first pass at enabling qcor
// quantum lambdas. The goal is to mimic lambda functionality
// via our QJIT infrastructure. The lambda class takes
// as input a lambda of desired kernel signature calling
// a specific macro which expands to return the function body
// expression as a string, which we use with QJIT jit_compile.
// The lambda class is templated on the types of any capture variables
// the programmer would like to specify, and takes a second constructor
// argument indicating the variable names of all kernel arguments and
// capture variables. Finally, all capture variables must be passed to the
// trailing variadic argument for the lambda class constructor. Once
// instantiated lambda invocation looks just like kernel invocation.

327
328
329
template <typename... CaptureArgs>
class _qpu_lambda {
 private:
330
  // Private inner class for getting the type
  // of a capture variable as a string at runtime.
  //
  // Used as a visitor over the capture tuple: each call appends
  // "<demangled type>& <name>," to the string bound at construction.
  // <name> is var_names[i] when a name was supplied for position i,
  // otherwise the generic "arg_<i>".
  class TupleToTypeArgString {
   protected:
    // Output string, owned by the caller.
    std::string &tmp;
    // Optional explicit variable names, by visit position.
    std::vector<std::string> var_names;
    // Number of elements visited so far.
    int counter = 0;

    // Demangled name of T with any reference stripped; falls back to the
    // raw typeid name if demangling fails.
    template <class T>
    std::string type_name() {
      using TR = typename std::remove_reference<T>::type;
      std::unique_ptr<char, void (*)(void *)> own(
          abi::__cxa_demangle(typeid(TR).name(), nullptr, nullptr, nullptr),
          std::free);
      return own != nullptr ? std::string(own.get())
                            : std::string(typeid(TR).name());
    }

   public:
    TupleToTypeArgString(std::string &t) : tmp(t) {}
    TupleToTypeArgString(std::string &t, std::vector<std::string> &_var_names)
        : tmp(t), var_names(_var_names) {}

    template <typename T>
    void operator()(T &t) {
      // Fall back to the generic name when no name exists for this position
      // (previously an out-of-range read if var_names was too short).
      const auto position = static_cast<std::size_t>(counter);
      const std::string name = position < var_names.size()
                                   ? var_names[position]
                                   : "arg_" + std::to_string(counter);
      tmp += type_name<decltype(t)>() + "& " + name + ",";
      ++counter;
    }
  };

362
363
  // Kernel lambda source string, has arg structure and body
  std::string &src_str;
364

365
366
  // Capture variable names
  std::string &capture_var_names;
367

368
  // By-ref capture variables, stored in tuple
369
  std::tuple<CaptureArgs &...> capture_vars;
370

371
372
373
374
375
376
377
378
379
380
381
382
  // Optional capture *by-value* variables:
  // We don't want to make unnecessary copies of capture variables
  // unless explicitly requested ("[=]").
  // Also, some types may not be copy-constructable...
  // Notes:
  // (1) we must copy at the lambda declaration point (i.e. _qpu_lambda
  // constructor)
  // (2) our JIT code chain is constructed using the by-reference convention,
  // just need to handle by-value (copy) at the top-level (i.e., in this tuple
  // storage)
  std::optional<std::tuple<CaptureArgs...>> optional_copy_capture_vars;

383
384
  // Quantum Just-in-Time Compiler :)
  QJIT qjit;
385

386
 public:
387
388
389
390
  // Variational information, i.e. is this lambda compatible with VQE
  // e.g. single double or single vector double input.
  enum class Variational_Arg_Type { Double, Vec_Double, None };
  Variational_Arg_Type var_type = Variational_Arg_Type::None;
391

392
393
394
  // Constructor, capture vars should be deduced without
  // specifying them since we're using C++17
  _qpu_lambda(std::string &&ff, std::string &&_capture_var_names,
395
396
397
              CaptureArgs &..._capture_vars)
      : src_str(ff),
        capture_var_names(_capture_var_names),
398
399
400
401
402
        capture_vars(std::forward_as_tuple(_capture_vars...)) {
    // Get the original args list
    auto first = src_str.find_first_of("(");
    auto last = src_str.find_first_of(")");
    auto tt = src_str.substr(first, last - first + 1);
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
    // Parse the argument list
    const auto arg_type_and_names = [](const std::string &arg_string_decl)
        -> std::vector<std::pair<std::string, std::string>> {
      // std::cout << "HOWDY:" << arg_string_decl << "\n";
      std::vector<std::pair<std::string, std::string>> result;
      const auto args_string =
          arg_string_decl.substr(1, arg_string_decl.size() - 2);
      std::stack<char> grouping_chars;
      std::string type_name;
      std::string var_name;
      std::string temp;
      // std::cout << args_string << "\n";
      for (int i = 0; i < args_string.size(); ++i) {
        if (isspace(args_string[i]) && grouping_chars.empty()) {
          type_name = temp;
          temp.clear();
        } else if (args_string[i] == ',') {
          var_name = temp;
          if (var_name[0] == '&') {
            type_name += "&";
            var_name = var_name.substr(1);
          }
          result.emplace_back(std::make_pair(type_name, var_name));
          type_name.clear();
          var_name.clear();
          temp.clear();
        } else {
          temp.push_back(args_string[i]);
        }

        if (args_string[i] == '<') {
          grouping_chars.push(args_string[i]);
        }
        if (args_string[i] == '>') {
          assert(grouping_chars.top() == '<');
          grouping_chars.pop();
        }
      }

      // Last one:
      var_name = temp;
      if (var_name[0] == '&') {
        type_name += "&";
        var_name = var_name.substr(1);
      }
      result.emplace_back(std::make_pair(type_name, var_name));
      return result;
    }(tt);

452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
    // Determine if this lambda has a VQE-compatible type:
    // QReg then variational params.
    if (arg_type_and_names.size() == 2) {
      const auto trim_space = [](std::string &stripString) {
        while (!stripString.empty() && std::isspace(*stripString.begin())) {
          stripString.erase(stripString.begin());
        }

        while (!stripString.empty() && std::isspace(*stripString.rbegin())) {
          stripString.erase(stripString.length() - 1);
        }
      };

      auto type_name = arg_type_and_names[1].first;
      trim_space(type_name);
      // Use a relax search to handle using namespace std...
      // FIXME: this is quite hacky.
      if (type_name.find("vector<double>") != std::string::npos) {
        var_type = Variational_Arg_Type::Vec_Double;
      } else if (type_name == "double") {
        var_type = Variational_Arg_Type::Double;
      }
    }

476
477
478
479
480
481
482
483
    // Map simple type to its reference type so that the
    // we can use consistent type-forwarding
    // when casting the JIT raw function pointer.
    // Currently, looks like only these simple types are having problem
    // with perfect type forwarding.
    // i.e. by-value arguments of these types are incompatible with a by-ref
    // casted function.
    static const std::unordered_map<std::string, std::string>
484
        FORWARD_TYPE_CONVERSION_MAP{{"int", "int&"}, {"double", "double&"}};
485
    std::vector<std::pair<std::string, std::string>> forward_types;
486
487
    // Replicate by-value by create copies and restore the variables.
    std::vector<std::string> byval_casted_arg_names;
488
489
490
491
492
493
    for (const auto &[type, name] : arg_type_and_names) {
      // std::cout << type << " --> " << name << "\n";
      if (FORWARD_TYPE_CONVERSION_MAP.find(type) !=
          FORWARD_TYPE_CONVERSION_MAP.end()) {
        auto iter = FORWARD_TYPE_CONVERSION_MAP.find(type);
        forward_types.emplace_back(std::make_pair(iter->second, name));
494
        byval_casted_arg_names.emplace_back(name);
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
      } else {
        forward_types.emplace_back(std::make_pair(type, name));
      }
    }

    // std::cout << "After\n";
    // Construct the new arg signature clause:
    std::string arg_clause_new;
    arg_clause_new.push_back('(');
    for (const auto &[type, name] : forward_types) {
      arg_clause_new.append(type);
      arg_clause_new.push_back(' ');
      arg_clause_new.append(name);
      arg_clause_new.push_back(',');
      // std::cout << type << " --> " << name << "\n";
    }
    arg_clause_new.pop_back();
512

513
514
515
516
517
518
519
520
521
    // Get the capture type:
    // By default "[]", pass by reference.
    // [=]: pass by value
    // [&]: pass by reference (same as default)
    const auto first_square_bracket = src_str.find_first_of("[");
    const auto last_square_bracket = src_str.find_first_of("]");
    const auto capture_type = src_str.substr(
        first_square_bracket, last_square_bracket - first_square_bracket + 1);
    if (!capture_type.empty() && capture_type == "[=]") {
522
523
524
525
526
527
528
529
530
531
      // We must check this at compile-time to prevent the compiler from
      // *attempting* to compile this code path even when by-val capture is not
      // in use. The common scenario is a qpu_lambda captures other qpu_lambda.
      // Copying of qpu_lambda by value is prohibitied.
      // We'll report a runtime error for this case.
      if constexpr (std::conjunction_v<
                        std::is_copy_assignable<CaptureArgs>...>) {
        // Store capture vars (by-value)
        optional_copy_capture_vars = std::forward_as_tuple(_capture_vars...);
      } else {
532
533
534
        error(
            "Capture variable type is non-copyable. Cannot use capture by "
            "value.");
535
      }
536
537
    }

538
    // Need to append capture vars to this arg signature
539
    std::string capture_preamble = "";
540
541
542
543
544
545
546
    const auto replaceVarName = [](std::string &str, const std::string &from,
                                   const std::string &to) {
      size_t start_pos = str.find(from);
      if (start_pos != std::string::npos) {
        str.replace(start_pos, from.length(), to);
      }
    };
547
548
549
550
551
    if (!capture_var_names.empty()) {
      std::string args_string = "";
      TupleToTypeArgString co(args_string);
      __internal__::tuple_for_each(capture_vars, co);
      args_string = "," + args_string.substr(0, args_string.length() - 1);
552
553
554
555
556
557

      // Replace the generic argument names (tuple foreach)
      // with the actual capture var name.
      // We need to do this so that the SyntaxHandler can properly detect if
      // a capture var is a Kernel-like ==> add the list of in-flight kernels
      // and add parent_kernel to the invocation.
558
559
      for (auto [i, capture_name] :
           qcor::enumerate(xacc::split(capture_var_names, ','))) {
560
561
        const auto old_name = "arg_" + std::to_string(i);
        replaceVarName(args_string, old_name, capture_name);
562
      }
563
564

      tt.insert(last - capture_type.size(), args_string);
565
      arg_clause_new.append(args_string);
566
567
    }

568
569
570
571
    // Extract the function body
    first = src_str.find_first_of("{");
    last = src_str.find_last_of("}");
    auto rr = src_str.substr(first, last - first + 1);
572
573
    arg_clause_new.push_back(')');
    // std::cout << "New signature: " << arg_clause_new << "\n";
574
    // Reconstruct with new args signature and
575
576
    // existing function body
    std::stringstream ss;
577
    ss << "__qpu__ void foo" << arg_clause_new << rr;
578

579
    // Get as a string, and insert capture
580
581
582
    // preamble if necessary
    auto jit_src = ss.str();
    first = jit_src.find_first_of("{");
583
584
    if (!capture_var_names.empty()) jit_src.insert(first + 1, capture_preamble);

585
586
    if (!byval_casted_arg_names.empty()) {
      std::stringstream cache_string, restore_string;
587
588
589
      for (const auto &var : byval_casted_arg_names) {
        cache_string << "auto __" << var << "__cached__ = " << var << ";\n";
        restore_string << var << " = __" << var << "__cached__;\n";
590
591
592
593
594
595
      }
      const auto begin = jit_src.find_first_of("{");
      jit_src.insert(begin + 1, cache_string.str());
      const auto end = jit_src.find_last_of("}");
      jit_src.insert(end, restore_string.str());
    }
596
597
598
599
600

    // std::cout << "JITSRC:\n" << jit_src << "\n";
    // JIT Compile, storing the function pointers
    qjit.jit_compile(jit_src);
  }
601

602
603
  template <typename... FunctionArgs>
  void eval_with_parent(std::shared_ptr<CompositeInstruction> parent,
604
                        FunctionArgs &&...args) {
605
    this->operator()(parent, std::forward<FunctionArgs>(args)...);
606
607
608
609
  }

  template <typename... FunctionArgs>
  void operator()(std::shared_ptr<CompositeInstruction> parent,
610
                  FunctionArgs &&...args) {
611
    // Map the function args to a tuple
612
613
614
615
616
617
    auto kernel_args_tuple = std::forward_as_tuple(args...);
    if (!optional_copy_capture_vars.has_value()) {
      // By-ref:
      // Merge the function args and the capture vars and execute
      auto final_args_tuple = std::tuple_cat(kernel_args_tuple, capture_vars);
      std::apply(
618
          [&](auto &&...args) {
619
            qjit.invoke_with_parent_forwarding("foo", parent, args...);
620
621
          },
          final_args_tuple);
622

623
624
    } else if constexpr (std::conjunction_v<
                             std::is_copy_assignable<CaptureArgs>...>) {
625
626
      // constexpr compile-time check to prevent compiler from looking at this
      // code path if the capture variable is non-copyable, e.g. qpu_lambda.
627
628
629
630
      // By-value:
      auto final_args_tuple =
          std::tuple_cat(kernel_args_tuple, optional_copy_capture_vars.value());
      std::apply(
631
          [&](auto &&...args) {
632
            qjit.invoke_with_parent_forwarding("foo", parent, args...);
633
634
635
          },
          final_args_tuple);
    }
636
637
  }

638
639
  template <typename... FunctionArgs>
  void operator()(FunctionArgs &&...args) {
640
    // Map the function args to a tuple
641
642
643
644
645
    auto kernel_args_tuple = std::forward_as_tuple(args...);
    if (!optional_copy_capture_vars.has_value()) {
      // By-ref
      // Merge the function args and the capture vars and execute
      auto final_args_tuple = std::tuple_cat(kernel_args_tuple, capture_vars);
646
      std::apply(
647
          [&](auto &&...args) { qjit.invoke_forwarding("foo", args...); },
648
          final_args_tuple);
649
650
    } else if constexpr (std::conjunction_v<
                             std::is_copy_assignable<CaptureArgs>...>) {
651
652
653
      // By-value
      auto final_args_tuple =
          std::tuple_cat(kernel_args_tuple, optional_copy_capture_vars.value());
654
      std::apply(
655
          [&](auto &&...args) { qjit.invoke_forwarding("foo", args...); },
656
          final_args_tuple);
657
    }
658
  }
659

660
  template <typename... FunctionArgs>
661
662
663
664
  double observe(std::shared_ptr<Observable> obs, FunctionArgs... args) {
    return observe(*obs.get(), args...);
  }

665
666
  template <typename... FunctionArgs>
  double observe(Observable &obs, FunctionArgs... args) {
667
    KernelSignature<FunctionArgs...> callable(*this);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
668
    return internal::observe<FunctionArgs...>(obs, callable, args...);
669
  }
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
670
671
672
673
674

  template <typename... FunctionArgs>
  void ctrl(std::shared_ptr<CompositeInstruction> ir,
            const std::vector<qubit> &ctrl_qbits, FunctionArgs... args) {
    KernelSignature<FunctionArgs...> callable(*this);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
675
    internal::apply_control<FunctionArgs...>(ir, ctrl_qbits, callable, args...);
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
676
677
678
679
680
681
682
  }

  template <typename... FunctionArgs>
  void ctrl(std::shared_ptr<CompositeInstruction> ir,
            const std::vector<int> &ctrl_idxs, FunctionArgs... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_idxs.size(); i++) {
683
684
      ctrl_qubit_vec.push_back(
          {"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
685
686
687
688
689
690
691
    }
    ctrl(ir, ctrl_qubit_vec, args...);
  }

  template <typename... FunctionArgs>
  void ctrl(std::shared_ptr<CompositeInstruction> ir, int ctrl_qbit,
            FunctionArgs... args) {
692
    ctrl(ir, std::vector<int>{ctrl_qbit}, args...);
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
693
694
695
696
697
  }

  template <typename... FunctionArgs>
  void ctrl(std::shared_ptr<CompositeInstruction> ir, qubit ctrl_qbit,
            FunctionArgs... args) {
698
    ctrl(ir, std::vector<qubit>{ctrl_qbit}, args...);
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
699
700
701
702
703
704
705
706
707
708
709
  }

  template <typename... FunctionArgs>
  void ctrl(std::shared_ptr<CompositeInstruction> ir, qreg ctrl_qbits,
            FunctionArgs... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_qbits.size(); i++) {
      ctrl_qubit_vec.push_back(ctrl_qbits[i]);
    }
    ctrl(ir, ctrl_qubit_vec, args...);
  }
710
711
712
713
714

  template <typename... FunctionArgs>
  void adjoint(std::shared_ptr<CompositeInstruction> parent_kernel,
               FunctionArgs... args) {
    KernelSignature<FunctionArgs...> callable(*this);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
715
    return internal::apply_adjoint<FunctionArgs...>(parent_kernel, callable, args...);
716
717
718
719
720
  }

  template <typename... FunctionArgs>
  void print_kernel(std::ostream &os, FunctionArgs... args) {
    KernelSignature<FunctionArgs...> callable(*this);
721
    return internal::print_kernel<FunctionArgs...>(callable, os, args...);
722
723
  }

724
725
  template <typename... FunctionArgs>
  void print_kernel(FunctionArgs... args) {
726
727
    print_kernel(std::cout, args...);
  }
728
729
730
731

  template <typename... FunctionArgs>
  std::size_t n_instructions(FunctionArgs... args) {
    KernelSignature<FunctionArgs...> callable(*this);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
732
    return internal::n_instructions<FunctionArgs...>(callable, args...);
733
734
735
736
737
  }

  // Matrix form of this lambda's kernel via internal::as_unitary_matrix.
  template <typename... FunctionArgs>
  Eigen::MatrixXcd as_unitary_matrix(FunctionArgs... args) {
    KernelSignature<FunctionArgs...> callable(*this);
    return internal::as_unitary_matrix<FunctionArgs...>(callable, args...);
  }

  template <typename... FunctionArgs>
  std::string openqasm(FunctionArgs... args) {
    KernelSignature<FunctionArgs...> callable(*this);
Mccaskey, Alex's avatar
Mccaskey, Alex committed
744
    return internal::openqasm<FunctionArgs...>(callable, args...);
745
  }
746
747
};

748
// Convenience macro that expands to a `_qpu_lambda(...)` expression.
// EXPR and the variadic arguments are each passed in stringified form
// (#EXPR, #__VA_ARGS__), followed by the variadic arguments themselves.
// `##__VA_ARGS__` is the GNU-style form that drops the preceding comma when
// no variadic arguments are supplied.
#define qpu_lambda(EXPR, ...) _qpu_lambda(#EXPR, #__VA_ARGS__, ##__VA_ARGS__)
749

750
751
752
753
// Plain function-pointer type of a kernel-like callable: it takes the
// CompositeInstruction to operate on, followed by the kernel arguments
// Args..., and returns nothing.
template <typename... Args>
using callable_function_ptr =
    void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...);

754
755
756
template <typename... Args>
class KernelSignature {
 private:
757
  callable_function_ptr<Args...> *readOnly = 0;
758
  callable_function_ptr<Args...> &function_pointer;
759
760
  std::function<void(std::shared_ptr<xacc::CompositeInstruction>, Args...)>
      lambda_func;
761
  std::shared_ptr<xacc::CompositeInstruction> parent_kernel;
762

763
 public:
764
765
  // Here we set function_pointer to null and instead
  // only use lambda_func. If we set lambda_func, function_pointer
766
767
  // will never be used, so we should be good.
  template <typename... CaptureArgs>
768
  KernelSignature(_qpu_lambda<CaptureArgs...> &lambda)
769
770
771
772
      : function_pointer(*readOnly),
        lambda_func([&](std::shared_ptr<xacc::CompositeInstruction> pp,
                        Args... a) { lambda(pp, a...); }) {}

773
  KernelSignature(callable_function_ptr<Args...> &&f) : function_pointer(f) {}
774

775
776
777
778
779
780
  // CTor from a QCOR QuantumKernel instance:
  template <
      typename KernelType,
      std::enable_if_t<
          std::is_base_of_v<QuantumKernel<KernelType, Args...>, KernelType>,
          bool> = true>
781
782
783
784
  KernelSignature(KernelType &kernel)
      : function_pointer(*readOnly),
        lambda_func(
            [&](std::shared_ptr<xacc::CompositeInstruction> pp, Args... a) {
785
786
              // Expand the kernel and append to the *externally-provided*
              // parent kernel as a KernelSignature one.
787
788
              kernel.disable_destructor = true;
              kernel(a...);
789
              pp->addInstructions(kernel.parent_kernel->getInstructions());
790
791
792
            }) {}

  // Ctor from raw void* function pointer.
793
794
795
796
797
798
799
  // IMPORTANT: since function_pointer is kept as a *reference*,
  // we must keep a reference to the original f_ptr void* as well.
  KernelSignature(void *&f_ptr)
      : function_pointer((callable_function_ptr<Args...> &)f_ptr) {}

  void operator()(std::shared_ptr<xacc::CompositeInstruction> ir,
                  Args... args) {
800
801
802
803
804
    if (lambda_func) {
      lambda_func(ir, args...);
      return;
    }

805
806
    function_pointer(ir, args...);
  }
807

808
  void operator()(Args... args) { operator()(parent_kernel, args...); }
809
810
811
812
813

  void set_parent_kernel(std::shared_ptr<xacc::CompositeInstruction> ir) {
    parent_kernel = ir;
  }

Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
814
815
  void ctrl(std::shared_ptr<xacc::CompositeInstruction> ir,
            const std::vector<qubit> &ctrl_qbits, Args... args) {
Mccaskey, Alex's avatar
Mccaskey, Alex committed
816
    internal::apply_control<Args...>(ir, ctrl_qbits, *this, args...);
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
817
  }
818

Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
819
820
821
822
  void ctrl(std::shared_ptr<xacc::CompositeInstruction> ir,
            const std::vector<int> ctrl_idxs, Args... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_idxs.size(); i++) {
823
824
      ctrl_qubit_vec.push_back(
          {"q", static_cast<size_t>(ctrl_idxs[i]), nullptr});
825
    }
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
826
827
828
829
    ctrl(ir, ctrl_qubit_vec, args...);
  }
  void ctrl(std::shared_ptr<xacc::CompositeInstruction> ir, int ctrl_qbit,
            Args... args) {
830
    ctrl(ir, std::vector<int>{ctrl_qbit}, args...);
831
832
833
834
  }

  void ctrl(std::shared_ptr<xacc::CompositeInstruction> ir, qubit ctrl_qbit,
            Args... args) {
835
    ctrl(ir, std::vector<qubit>{ctrl_qbit}, args...);
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
836
837
838
839
840
841
842
843
844
  }

  void ctrl(std::shared_ptr<xacc::CompositeInstruction> ir, qreg ctrl_qbits,
            Args... args) {
    std::vector<qubit> ctrl_qubit_vec;
    for (int i = 0; i < ctrl_qbits.size(); i++) {
      ctrl_qubit_vec.push_back(ctrl_qbits[i]);
    }
    ctrl(ir, ctrl_qubit_vec, args...);
845
846
847
  }

  void adjoint(std::shared_ptr<CompositeInstruction> ir, Args... args) {
848
    internal::apply_adjoint<Args...>(ir, *this, args...);
849
  }
850

851
  void print_kernel(std::ostream &os, Args... args) {
852
    return internal::print_kernel<Args...>(*this, os, args...);
853
854
855
  }

  void print_kernel(Args... args) { print_kernel(std::cout, args...); }
856

857
  std::size_t n_instructions(Args... args) {
858
    return internal::n_instructions<Args...>(*this, args...);
859
  }
860

861
  Eigen::MatrixXcd as_unitary_matrix(Args... args) {
862
    return internal::as_unitary_matrix<Args...>(*this, args...);
863
  }
864

865
  std::string openqasm(Args... args) {
866
    return internal::openqasm<Args...>(*this, args...);
867
  }
868
869
870
871
872
873

  double observe(std::shared_ptr<Observable> obs, Args... args) {
    return observe(*obs.get(), args...);
  }

  double observe(Observable &obs, Args... args) {
874
    return internal::observe<Args...>(obs, *this, args...);
875
  }
876
877
};

878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
// Templated helpers that attach parent_kernel to KernelSignature arguments,
// including those nested in a std::vector<KernelSignature>.
// The Token Collector relies on a list of kernel names in the translation
// unit to attach parent_kernel to the operator() call. For KernelSignature
// objects provided in a container, tracking these at the TokenCollector
// level is error-prone (e.g. it would need to track any array accesses).
// Hence, we iterate over all kernel arguments and attach the parent_kernel
// to any KernelSignature argument at the top of each kernel's operator() call
// in a type-safe manner.

// Terminal overload: no arguments left to inspect.
inline void init_kernel_signature_args_impl(
    std::shared_ptr<xacc::CompositeInstruction> ir) {
  // Intentionally empty.
}
template <typename T, typename... ArgsType>
void init_kernel_signature_args_impl(
893
    std::shared_ptr<xacc::CompositeInstruction> ir, T &t, ArgsType &...Args);
894
895
896
897
898

// Main function: to be added by the token collector at the beginning
// of each kernel operator().
template <typename... T>
void init_kernel_signature_args(std::shared_ptr<xacc::CompositeInstruction> ir,
899
                                T &...multi_inputs) {
900
901
902
903
904
905
906
  init_kernel_signature_args_impl(ir, multi_inputs...);
}

// Base case: generic type T,
// just ignore, proceed to the next arg.
template <typename T, typename... ArgsType>
void init_kernel_signature_args_impl(
907
    std::shared_ptr<xacc::CompositeInstruction> ir, T &t, ArgsType &...Args) {
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
  init_kernel_signature_args(ir, Args...);
}

// Special case: the leading argument is a std::vector.
// Each element is run through init_kernel_signature_args_impl on its own,
// then the remaining arguments are processed.
//
// The remaining arguments are taken by reference: taking them by value would
// process copies, so a parent kernel set on a later argument would never
// reach the caller's object.
template <typename T, typename... ArgsType>
void init_kernel_signature_args_impl(
    std::shared_ptr<xacc::CompositeInstruction> ir, std::vector<T> &vec_arg,
    ArgsType &...Args) {
  // `auto &&` also binds to proxy references (e.g. std::vector<bool>).
  for (auto &&el : vec_arg) {
    // Iterate the vector elements.
    init_kernel_signature_args_impl(ir, el);
  }
  // Proceed with the rest.
  init_kernel_signature_args(ir, Args...);
}

// Handle KernelSignature arg => set the parent kernel.
template <typename... ArgsType>
void init_kernel_signature_args_impl(
    std::shared_ptr<xacc::CompositeInstruction> ir,
    KernelSignature<ArgsType...> &kernel_signature) {
  kernel_signature.set_parent_kernel(ir);
}

933
934
935
936
937
938
939
940
941
namespace internal {
// KernelSignature is the base of all kernel-like objects
// and we use it to implement kernel modifiers & utilities.
// These are free utility functions so that implicit conversion to
// KernelSignature occurs automatically.
template <typename... Args>
void apply_control(std::shared_ptr<CompositeInstruction> parent_kernel,
                   const std::vector<qubit> &ctrl_qbits,
                   KernelSignature<Args...> &kernelCallable, Args... args) {
942
  std::vector<std::pair<std::string, size_t>> ctrl_qubits;
943
  for (const auto &qb : ctrl_qbits) {
944
    ctrl_qubits.emplace_back(std::make_pair(qb.first, qb.second));
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
  }

  // Is is in a **compute** segment?
  // i.e. doing control within the compute block itself.
  // need to by-pass the compute marking in order for the control gate to
  // work.
  const bool cached_is_compute_section =
      ::quantum::qrt_impl->isComputeSection();
  if (cached_is_compute_section) {
    ::quantum::qrt_impl->__end_mark_segment_as_compute();
  }

  // Use the controlled gate module of XACC to transform
  auto tempKernel = qcor::__internal__::create_composite("temp_control");
  kernelCallable(tempKernel, args...);

  if (cached_is_compute_section) {
    ::quantum::qrt_impl->__begin_mark_segment_as_compute();
  }

  auto ctrlKernel = qcor::__internal__::create_ctrl_u();
966
  ctrlKernel->expand({{"U", tempKernel}, {"control-idx", ctrl_qubits}});
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984

  // Mark all the *Controlled* instructions as compute segment
  // if it was in the compute_section.
  // i.e. we have bypassed the marker previously to make C-U to work,
  // now we mark all the generated instructions.
  if (cached_is_compute_section) {
    for (int instId = 0; instId < ctrlKernel->nInstructions(); ++instId) {
      ctrlKernel->getInstruction(instId)->attachMetadata(
          {{"__qcor__compute__segment__", true}});
    }
  }

  for (int instId = 0; instId < ctrlKernel->nInstructions(); ++instId) {
    parent_kernel->addInstruction(ctrlKernel->getInstruction(instId));
  }
  // Need to reset and point current program to the parent
  quantum::set_current_program(parent_kernel);
}
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000

template <typename... Args>
void apply_adjoint(std::shared_ptr<CompositeInstruction> parent_kernel,
                   KernelSignature<Args...> &kernelCallable, Args... args) {
  auto tempKernel = qcor::__internal__::create_composite("temp_adjoint");
  kernelCallable(tempKernel, args...);

  // get the instructions
  auto instructions = tempKernel->getInstructions();
  std::shared_ptr<CompositeInstruction> program = tempKernel;

  // Assert that we don't have measurement
  if (!std::all_of(
          instructions.cbegin(), instructions.cend(),
          [](const auto &inst) { return inst->name() != "Measure"; })) {
    error("Unable to create Adjoint for kernels that have Measure operations.");