Commit d1d4b092 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

qpu_lambda needs to be registered to the token collector so that its...


qpu_lambda needs to be registered to the token collector so that its invocation (inside another lambda for example) has the parent_kernel attached

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 000f4678
Loading
Loading
Loading
Loading
+38 −3
Original line number Diff line number Diff line
@@ -36,6 +36,8 @@ __qpu__ void run_grover(qreg q, GroverPhaseOracle oracle,
}

int main() {
  set_shots(1024);

  const int N = 3;

  // Write the oracle as a quantum lambda function
@@ -47,10 +49,43 @@ int main() {

  // Allocate some qubits
  auto q = qalloc(N);

  oracle.print_kernel(q);
  int iterations = 1;
  // Call grover given the oracle and n iterations
  run_grover(q, oracle, 1);

  run_grover(q, oracle, iterations);
  // print the histogram
  q.print();

  // Grover lambda:
  // amplification lambda
  auto amplification_lambda = qpu_lambda([](qreg q) {
    print("hey from amplification_lambda");
    compute {
      H(q);
      X(q);
    }
    action {
      auto ctrl_bits = q.head(q.size() - 1);
      auto last_qubit = q.tail();
      Z::ctrl(ctrl_bits, last_qubit);
    }
  });

  // Capture the grover lambda and iterations directly from the enclosing scope.
  auto grover_lambda = qpu_lambda([](qreg q) {
        H(q);
        for (int i = 0; i < iterations; i++) {
          oracle(q);
          amplification_lambda(q);
        }

        Measure(q);
      }, oracle, iterations, amplification_lambda);

  auto q_lambda = qalloc(N);

  std::cout << "Lamda result:\n";
  grover_lambda.print_kernel(q_lambda);
  grover_lambda(q_lambda);
  q_lambda.print();
}
 No newline at end of file
+10 −0
Original line number Diff line number Diff line
@@ -77,6 +77,16 @@ void QCORSyntaxHandler::GetReplacement(
    std::vector<std::string> program_parameters,
    std::vector<std::string> bufferNames, CachedTokens &Toks,
    llvm::raw_string_ostream &OS, bool add_het_map_ctor) {
  
  // Add any KernelSignature or qpu_lambda to the list of known kernels.
  // Note: this GetReplacement overload is called directly from QJIT
  // vs. the standard SyntaxHandler API.
  for (int i = 0; i < program_arg_types.size(); ++i) {
    if (program_arg_types[i].find("KernelSignature") != std::string::npos ||
        program_arg_types[i].find("qcor::_qpu_lambda") != std::string::npos) {
      qcor::append_kernel(program_parameters[i], {}, {});
    }
  }
  // Get the Diagnostics engine and create a few custom
  // error messgaes
  auto &diagnostics = PP.getDiagnostics();
+7 −3
Original line number Diff line number Diff line
@@ -19,10 +19,14 @@ namespace qcor {
void append_kernel(const std::string name,
                   const std::vector<std::string> &program_arg_types,
                   const std::vector<std::string> &program_parameters) {
  // Just ignored if we have tracked this kernel already.
  if (!xacc::container::contains(::quantum::kernels_in_translation_unit,
                                 name)) {
    ::quantum::kernels_in_translation_unit.push_back(name);
    ::quantum::kernel_signatures_in_translation_unit[name] =
        std::make_pair(program_arg_types, program_parameters);
  }
}

void set_verbose(bool verbose) { xacc::set_verbose(verbose); }
void info(const std::string &s) { xacc::info(s); }
+2 −2
Original line number Diff line number Diff line
@@ -67,12 +67,12 @@ class QJIT {
  template <typename... Args>
  void invoke_with_parent(const std::string &kernel_name,
                          std::shared_ptr<xacc::CompositeInstruction> parent,
                          Args... args) {
                          Args &&... args) {
    auto f_ptr = kernel_name_to_f_ptr_with_parent[kernel_name];
    void (*kernel_functor)(std::shared_ptr<xacc::CompositeInstruction>,
                           Args...) =
        (void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...))f_ptr;
    kernel_functor(parent, args...);
    kernel_functor(parent, std::forward<Args>(args)...);
  }

  int invoke_main(int argc, char **argv) {
+18 −2
Original line number Diff line number Diff line
@@ -244,8 +244,24 @@ const std::pair<std::string, std::string> QJIT::run_syntax_handler(
               arg_var[0] == "xacc::internal_compiler::qubit") {
      bufferNames.push_back(arg_var[1]);
    }

    // Handle type templated type:
    // e.g. qcor::_qpu_lambda<int const>& arg_0
    // we need to get the last one as arg name.
    if (arg_var.size() == 2) {
      arg_types.push_back(arg_var[0]);
      arg_vars.push_back(arg_var[1]);
    } else {
      // More than 2 sub-strings after split...
      // Reconstruct the full type name after the split.
      std::string arg_type = arg_var[0];
      for (int i = 1; i < arg_var.size() - 1; ++i) {
        arg_type = arg_type + " " + arg_var[i];
      }
      arg_types.push_back(arg_type);
      // Var name is the last
      arg_vars.push_back(arg_var[arg_var.size() - 1]);
    }
  }

  // second, lex the kernel_src
Loading