Commit 9c150ffe authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

adding CallableKernel type, enabling one to create kernels that take other...


adding CallableKernel type, enabling one to create kernels that take other kernels as a function argument. added qpe example for demo

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 96986db7
Loading
Loading
Loading
Loading
Loading
+56 −0
Original line number Diff line number Diff line
#include <qcor_qft>

using QPEOracle = CallableKernel<qreg, int>;

__qpu__ void QuantumPhaseEstimation(qreg q, QPEOracle oracle) {
  // We have nQubits, the last one we use
  // as the state qubit, the others we use as the counting qubits
  const auto nQubits = q.size();
  const auto nCounting = nQubits - 1;
  const auto state_qubit_idx = nQubits - 1;

  // Put it in |1> eigenstate
  X(q[state_qubit_idx]);

  // Create uniform superposition
  for (auto i : range(nCounting)) {
    H(q[i]);
  }

  for (auto i : range(nCounting)) {
    const int nbCalls = 1 << i;
    for (auto j : range(nbCalls)) {
      int ctlBit = i;
      // Will be fixing this parent_kernel thing...
      oracle.ctrl(ctlBit, q, state_qubit_idx);
    }
  }

  // Run Inverse QFT, on 0:nCounting qubits
  int startIdx = 0;
  int shouldSwap = 1;
  iqft(q, startIdx, nCounting, shouldSwap);

  for (int i : range(nCounting)) {
    Measure(q[i]);
  }
}

// QPE Problem
// In this example, we demonstrate a simple QPE algorithm, i.e.
// i.e. Oracle(|State>) = exp(i*Phase)*|State>
// and we need to estimate that Phase value.
// The Oracle in this case is a T gate and the eigenstate is |1>
// i.e. T|1> = exp(i*pi/4)|1>
// We use 3 counting bits => totally 4 qubits.

// Oracle I want to consider
__qpu__ void compositeOp(qreg q, int idx) { T(q[idx]); }

int main(int argc, char **argv) {
  auto q = qalloc(4);
  // very cool, implicit conversion works here
  // so you can just pass the kernel function
  QuantumPhaseEstimation(q, compositeOp);
  q.print();
}
+9 −13
Original line number Diff line number Diff line
@@ -85,19 +85,15 @@ void QCORSyntaxHandler::GetReplacement(
  // with XACC api calls
  qcor::append_kernel(kernel_name, program_arg_types, program_parameters);

  auto new_src = qcor::run_token_collector(PP, Toks, bufferNames);
  for (int i = 0; i < program_arg_types.size(); i++) {
    if (program_arg_types[i].find("CallableKernel") != std::string::npos) {
      // we have a kernel we can call, need to add it to 
      // append_kernel call. 
      qcor::append_kernel(program_parameters[i], {}, {});
    }
  }

  //   auto random_string = [](size_t length) {
  //     auto randchar = []() -> char {
  //       const char charset[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  //                              "abcdefghijklmnopqrstuvwxyz";
  //       const size_t max_index = (sizeof(charset) - 1);
  //       return charset[rand() % max_index];
  //     };
  //     std::string str(length, 0);
  //     std::generate_n(str.begin(), length, randchar);
  //     return str;
  //   };
  auto new_src = qcor::run_token_collector(PP, Toks, bufferNames);

  // Rewrite the original function
  OS << "void " << kernel_name << "(" << program_arg_types[0] << " "
+20 −22
Original line number Diff line number Diff line
#pragma once
#include <regex>

#include "IRProvider.hpp"
#include "qrt.hpp"
#include "xacc.hpp"
#include "xasm_singleVisitor.h"
#include <regex>

using namespace xasm;

@@ -19,8 +20,8 @@ protected:
 public:
  xasm_single_result_type result;

  antlrcpp::Any
  visitStatement(xasm_singleParser::StatementContext *context) override {
  antlrcpp::Any visitStatement(
      xasm_singleParser::StatementContext *context) override {
    // should only have 1 child, if it is qinst
    // we expect a xacc Instruction return type
    // if cinst we expect a Cinst
@@ -41,7 +42,6 @@ public:
    }

    if (xacc::container::contains(provider->getInstructions(), inst_name)) {

      // We don't really care about Instruction::bits(), qrt_mapper
      // will look for bit expressions and use those, so just set
      // everything as a string...
@@ -93,7 +93,6 @@ public:
          counter++;
        }
      } else {

        // I don't want to use xasm circuit gen any more...
        // So use it as a fallback, but first look for previous
        if (xacc::container::contains(quantum::kernels_in_translation_unit,
@@ -125,7 +124,6 @@ public:

      result.second = inst;
    } else {

      std::stringstream ss;

      if (xacc::container::contains(quantum::kernels_in_translation_unit,
@@ -155,7 +153,6 @@ public:
  }

  antlrcpp::Any visitCinst(xasm_singleParser::CinstContext *context) override {

    // Strategy here is simple, we just want to
    // preserve all classical code statements in
    // the original quantum kernel
@@ -170,7 +167,8 @@ public:
          ss << c->getText() << " ";
        }
      }
    } else if (context->getText().find("::ctrl") != std::string::npos) {
    } else if (context->getText().find("::ctrl") != std::string::npos ||
               context->getText().find(".ctrl") != std::string::npos) {
      for (auto c : context->children) {
        if (c->getText() == "(") {
          ss << c->getText() << "parent_kernel, ";
@@ -222,22 +220,22 @@ public:
    return 0;
  }

  antlrcpp::Any
  visitComment(xasm_singleParser::CommentContext *context) override {
  antlrcpp::Any visitComment(
      xasm_singleParser::CommentContext *context) override {
    return 0;
  }
  antlrcpp::Any
  visitCompare(xasm_singleParser::CompareContext *context) override {
  antlrcpp::Any visitCompare(
      xasm_singleParser::CompareContext *context) override {
    return 0;
  }

  antlrcpp::Any
  visitCpp_type(xasm_singleParser::Cpp_typeContext *context) override {
  antlrcpp::Any visitCpp_type(
      xasm_singleParser::Cpp_typeContext *context) override {
    return 0;
  }

  antlrcpp::Any
  visitExplist(xasm_singleParser::ExplistContext *context) override {
  antlrcpp::Any visitExplist(
      xasm_singleParser::ExplistContext *context) override {
    return 0;
  }

@@ -245,8 +243,8 @@ public:
    return 0;
  }

  antlrcpp::Any
  visitUnaryop(xasm_singleParser::UnaryopContext *context) override {
  antlrcpp::Any visitUnaryop(
      xasm_singleParser::UnaryopContext *context) override {
    return 0;
  }

@@ -258,8 +256,8 @@ public:
    return 0;
  }

  antlrcpp::Any
  visitString(xasm_singleParser::StringContext *context) override {
  antlrcpp::Any visitString(
      xasm_singleParser::StringContext *context) override {
    return 0;
  }
};
 No newline at end of file
+86 −6
Original line number Diff line number Diff line
#pragma once

#include "qcor_observable.hpp"
#include "qcor_utils.hpp"
#include "qrt.hpp"
#include "qcor_observable.hpp"

namespace qcor {
enum class QrtType { NISQ, FTQC };
@@ -236,8 +236,7 @@ class QuantumKernel {
    if (!std::all_of(
            instructions.cbegin(), instructions.cend(),
            [](const auto &inst) { return inst->name() != "Measure"; })) {
      error(
          "Unable to observe kernels that already have Measure operations.");
      error("Unable to observe kernels that already have Measure operations.");
    }

    xacc::internal_compiler::execute_pass_manager();
@@ -262,8 +261,7 @@ class QuantumKernel {
    if (!std::all_of(
            instructions.cbegin(), instructions.cend(),
            [](const auto &inst) { return inst->name() != "Measure"; })) {
      error(
          "Unable to observe kernels that already have Measure operations.");
      error("Unable to observe kernels that already have Measure operations.");
    }

    xacc::internal_compiler::execute_pass_manager();
@@ -285,4 +283,86 @@ class QuantumKernel {
  virtual ~QuantumKernel() {}
};

template <typename... Args>
using callable_function_ptr =
    void (*)(std::shared_ptr<xacc::CompositeInstruction>, Args...);

template <typename... Args>
class CallableKernel {
 protected:
  callable_function_ptr<Args...> &function_pointer;

 public:
  CallableKernel(callable_function_ptr<Args...> &&f) : function_pointer(f) {}
  void operator()(std::shared_ptr<xacc::CompositeInstruction> ir,
                  Args... args) {
    function_pointer(ir, args...);
  }
  void ctrl(std::shared_ptr<xacc::CompositeInstruction> ir, int ctrl_qbit,
            Args... args) {
    auto tempKernel = qcor::__internal__::create_composite("temp_control");
    function_pointer(tempKernel, args...);

    auto ctrlKernel = qcor::__internal__::create_ctrl_u();
    ctrlKernel->expand({
        std::make_pair("U", tempKernel),
        std::make_pair("control-idx", ctrl_qbit),
    });

    for (int instId = 0; instId < ctrlKernel->nInstructions(); ++instId) {
      ir->addInstruction(ctrlKernel->getInstruction(instId)->clone());
    }
  }

  void ctrl(std::shared_ptr<xacc::CompositeInstruction> ir, qubit ctrl_qbit,
            Args... args) {
    int ctrl_bit = (int)ctrl_qbit.second;
    ctrl(ir, ctrl_bit, args...);
  }

  void adjoint(std::shared_ptr<CompositeInstruction> ir, Args... args) {
    auto tempKernel = qcor::__internal__::create_composite("temp_adjoint");
    function_pointer(tempKernel, args...);
 
    // get the instructions
    auto instructions = tempKernel->getInstructions();
    std::shared_ptr<CompositeInstruction> program = tempKernel;

    // Assert that we don't have measurement
    if (!std::all_of(
            instructions.cbegin(), instructions.cend(),
            [](const auto &inst) { return inst->name() != "Measure"; })) {
      error(
          "Unable to create Adjoint for kernels that have Measure operations.");
    }

    auto provider = qcor::__internal__::get_provider();
    for (int i = 0; i < instructions.size(); i++) {
      auto inst = tempKernel->getInstruction(i);
      // Parametric gates:
      if (inst->name() == "Rx" || inst->name() == "Ry" ||
          inst->name() == "Rz" || inst->name() == "CPHASE" ||
          inst->name() == "U1" || inst->name() == "CRZ") {
        inst->setParameter(0, -inst->getParameter(0).template as<double>());
      }
      // Handles T and S gates, etc... => T -> Tdg
      else if (inst->name() == "T") {
        auto tdg = provider->createInstruction("Tdg", inst->bits());
        program->replaceInstruction(i, tdg);
      } else if (inst->name() == "S") {
        auto sdg = provider->createInstruction("Sdg", inst->bits());
        program->replaceInstruction(i, sdg);
      }
    }

    // We update/replace instructions in the derived.parent_kernel composite,
    // hence collecting these new instructions and reversing the sequence.
    auto new_instructions = tempKernel->getInstructions();
    std::reverse(new_instructions.begin(), new_instructions.end());

    // add the instructions to the current parent kernel
    ir->addInstructions(new_instructions);
  }
};

}  // namespace qcor
 No newline at end of file