Commit 631ecf09 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

initial prototype of QuantumKernel template class

parent e70f27e2
Loading
Loading
Loading
Loading
+223 −0
Original line number Diff line number Diff line
#include "qcor.hpp"

// __qpu__ void measure_qbits(qreg q) {
//   for (int i = 0; i < 2; i++) {
//     Measure(q[i]);
//   }
// }

// __qpu__ void quantum_kernel(qreg q, double x) {
//     X(q[0]);
//     Ry(q[1], x);
//     CX(q[1],q[0]);
// }

// __qpu__ void z0z1(qreg q, double x) {
//     quantum_kernel(q, x);
//     measure_qbits(q);
// }

// __qpu__ void check_adjoint(qreg q, double x) {
//     quantum_kernel(q,x);
//     quantum_kernel::adjoint(q,x);
//     measure_qbits(q);
// }

// translated to the following

// the functions will remain, just empty
void measure_bits(qreg q) {
    return;
}

void quantum_kernel(qreg q, double x) {
    return;
}

void z0z1(qreg q, double x) {
    return;
}

void check_adjoint(qreg q, double x) {
    return;
}

class measure_qbits : public qcor::QuantumKernel<measure_qbits, qreg> {
  friend class qcor::QuantumKernel<measure_qbits, qreg>;

protected:
  void operator()(qreg q) {
    if (!parent_kernel) {
      // if has no parent, then create the parent
      // this means this is callable
      parent_kernel = qcor::__internal__::create_composite(kernel_name);
      q.setNameAndStore("q");
    }

    quantum::set_current_program(parent_kernel);

    for (int i = 0; i < 2; i++) {
      quantum::mz(q[i]);
    }
  }

public:
  inline static const std::string kernel_name = "measure_qbits";
  measure_qbits(qreg q) : QuantumKernel<measure_qbits, qreg>(q) {}
  measure_qbits(std::shared_ptr<qcor::CompositeInstruction> p, qreg q)
      : QuantumKernel<measure_qbits, qreg>(p, q) {}

  virtual ~measure_qbits() {
    if (disable_destructor) {
      return;
    }

    quantum::set_backend("qpp", 1024);
    auto [q] = args_tuple;
    operator()(q);
    if (is_callable) {
      quantum::submit(q.results());
    }
  }
};

class quantum_kernel
    : public qcor::QuantumKernel<quantum_kernel, qreg, double> {
  friend class qcor::QuantumKernel<quantum_kernel, qreg, double>;

protected:
  void operator()(qreg q, double t) {
    if (!parent_kernel) {
      // if has no parent, then create the parent
      // this means this is callable
      parent_kernel = qcor::__internal__::create_composite(kernel_name);
      q.setNameAndStore("q");
    }

    quantum::set_current_program(parent_kernel);

    quantum::x(q[0]);
    quantum::ry(q[1], t);
    quantum::cnot(q[1], q[0]);
  }

public:
  inline static const std::string kernel_name = "quantum_kernel";

  quantum_kernel(qreg q, double t)
      : QuantumKernel<quantum_kernel, qreg, double>(q, t) {}
  quantum_kernel(std::shared_ptr<qcor::CompositeInstruction> p, qreg q,
                 double t)
      : QuantumKernel<quantum_kernel, qreg, double>(p, q, t) {}
  quantum_kernel() : QuantumKernel<quantum_kernel, qreg, double>() {}

  virtual ~quantum_kernel() {
    if (disable_destructor) {
      return;
    }
    quantum::set_backend("qpp", 1024);
    auto [q, t] = args_tuple;
    operator()(q, t);
    if (is_callable) {
      quantum::submit(q.results());
    }
  }
};

class z0z1 : public qcor::QuantumKernel<z0z1, qreg, double> {

protected:
  void operator()(qreg q, double t) {
    if (!parent_kernel) {
      // if has no parent, then create the parent
      // this means this is callable
      parent_kernel = qcor::__internal__::create_composite(kernel_name);
      q.setNameAndStore("r");
    }

    quantum::set_current_program(parent_kernel);

    // FIXME Will require qrt_mapper to add parent_kernel to 
    // argument list

    quantum_kernel(parent_kernel, q, t);
    measure_qbits(parent_kernel, q);
  }

public:
  inline static const std::string kernel_name = "z0z1";

  z0z1(qreg q, double t) : QuantumKernel<z0z1, qreg, double>(q, t) {}

  virtual ~z0z1() {
    if (disable_destructor) {
      return;
    }
    quantum::set_backend("qpp", 1024);
    auto [q, t] = args_tuple;
    operator()(q, t);
    if (is_callable) {
      quantum::submit(q.results());
    }
  }
};

class check_adjoint : public qcor::QuantumKernel<check_adjoint, qreg, double> {

protected:
  void operator()(qreg q, double t) {
    if (!parent_kernel) {
      // if has no parent, then create the parent
      // this means this is callable
      parent_kernel = qcor::__internal__::create_composite(kernel_name);
      q.setNameAndStore("v");
    }
    quantum::set_current_program(parent_kernel);

    // FIXMEWill require qrt_mapper to add parent_kernel to 
    // argument list, for adjoint too

    quantum_kernel(parent_kernel, q, t);
    quantum_kernel::adjoint(parent_kernel, q, t);
    std::cout << "check here\n" << parent_kernel->toString() << "\n";
    measure_qbits(parent_kernel, q);
  }

public:
  inline static const std::string kernel_name = "check_adjoint";

  check_adjoint(qreg q, double t)
      : QuantumKernel<check_adjoint, qreg, double>(q, t) {}

  virtual ~check_adjoint() {
    if (disable_destructor) {
      return;
    }
    quantum::set_backend("qpp", 1024);
    auto [q, t] = args_tuple;
    operator()(q, t);
    if (is_callable) {
      //   quantum::program = parent_kernel;
      std::cout << quantum::program->toString() << "\n";
      quantum::submit(q.results());
    }
  }
};

int main() {
  auto q = qalloc(2);

  quantum_kernel(q, 2.2);

  q.print();

  auto r = qalloc(2);

  z0z1(r, 2.2);
  r.print();

  auto v = qalloc(2);

  check_adjoint(v, 2.2);
  v.print();
}
 No newline at end of file
+275 −17
Original line number Diff line number Diff line
@@ -43,12 +43,6 @@ public:
    SourceManager &sm = PP.getSourceManager();
    auto lo = PP.getLangOpts();

    // auto src_txt = Lexer::getSourceText(
    //     CharSourceRange::getTokenRange(SourceRange(
    //         Toks[0].getLocation(), Toks[Toks.size() - 1].getLocation())),
    //     sm, lo);


    // Get the Function Type Info from the Declarator,
    // If the function has no arguments, then we throw an error
    const DeclaratorChunk::FunctionTypeInfo &FTI = D.getFunctionTypeInfo();
@@ -71,20 +65,23 @@ public:
      PP.getRawToken(paramInfo.IdentLoc, IdentToken);
      PP.getRawToken(decl->getBeginLoc(), TypeToken);

      function_prototype +=
          PP.getSpelling(TypeToken) + " " + PP.getSpelling(IdentToken);
      auto type = PP.getSpelling(TypeToken);
      auto var = PP.getSpelling(IdentToken);

      function_prototype += type + " " + var;

      auto parm_var_decl = cast<ParmVarDecl>(decl);
      if (parm_var_decl) {
        auto type = QualType::getAsString(parm_var_decl->getType().split(),
                                          PrintingPolicy{{}});
      program_arg_types.push_back(type);
        program_parameters.push_back(ident->getName().str());
        if (type == "class xacc::internal_compiler::qreg") {
      program_parameters.push_back(var);

      auto parm_var_decl = cast<ParmVarDecl>(decl);

      if (parm_var_decl &&
          QualType::getAsString(parm_var_decl->getType().split(),
                                PrintingPolicy{{}}) ==
              "class xacc::internal_compiler::qreg") {
        bufferNames.push_back(ident->getName().str());
      }
    }
    }
    function_prototype += ")";

    // Get Tokens as a string, rewrite code
@@ -92,6 +89,8 @@ public:

    auto new_src = qcor::run_token_collector(PP, Toks, bufferNames);

    OS << function_prototype << "{\n";

    OS << "quantum::initialize(\"" << qpu_name << "\", \"" << kernel_name
       << "\");\n";
    for (auto &buf : bufferNames) {
@@ -112,6 +111,9 @@ public:
        OS << ", " << bufferNames[k] << ".results()";
      }
      OS << "};\n";
      OS << "std::cout << \"execing: \" << quantum::getProgram()->toString() "
            "<< \"\\n\";\n";

      OS << "quantum::submit(buffers," << bufferNames.size();
    } else {
      OS << "quantum::submit(" << bufferNames[0] << ".results()";
@@ -119,10 +121,62 @@ public:

    OS << ");\n";
    OS << "}";
    OS << "\n}\n";

    OS << "class " << kernel_name << "{\n";
    OS << "public:\n";
    OS << "static void adjoint(";
    for (int i = 0; i < program_arg_types.size(); i++) {
      if (i > 0) {
        OS << ",";
      }
      auto arg_type = program_arg_types[i];
      auto arg_var = program_parameters[i];

      OS << arg_type << " " << arg_var;
    }
    OS << ") {\n";

    OS << "quantum::initialize(\"" << qpu_name << "\", \"" << kernel_name
       << "\");\n";
    for (auto &buf : bufferNames) {
      OS << buf << ".setNameAndStore(\"" + buf + "\");\n";
    }

    if (shots > 0) {
      OS << "quantum::set_shots(" << shots << ");\n";
    }

    OS << new_src;

    OS << "quantum::adjoint();\n";

    OS << "if (__execute) {\n";

    if (bufferNames.size() > 1) {
      OS << "xacc::AcceleratorBuffer * buffers[" << bufferNames.size()
         << "] = {";
      OS << bufferNames[0] << ".results()";
      for (unsigned int k = 1; k < bufferNames.size(); k++) {
        OS << ", " << bufferNames[k] << ".results()";
      }
      OS << "};\n";
      OS << "quantum::submit(buffers," << bufferNames.size();
    } else {
      OS << "quantum::submit(" << bufferNames[0] << ".results()";
    }

    OS << ");\n";
    OS << "}\n";

    // close adjoint()
    OS << "}\n";
    // close class
    OS << "};";

    auto s = OS.str();
    qcor::info("[qcor syntax-handler] Rewriting " + kernel_name + " to\n\n" +
               function_prototype + "{\n" + s.substr(2, s.length()) + "\n}");
               s);
  }

  void AddToPredefines(llvm::raw_string_ostream &OS) override {
@@ -181,6 +235,210 @@ public:
  }
};

//// =================== COMMON CODE ==============/////
// Base class of all kernels:
// This will handle Adjoint() and Control() in an AUTOMATIC way,
// i.e. not taking into account if it is self-adjoint.
// Technically, we can define sub-classes for those special cases
// and then allow users to annotate kernels as self-adjoint for instance.
// class KernelBase {
// public:
//   KernelBase(xacc::internal_compiler::qreg q,
//              std::shared_ptr<xacc::CompositeInstruction> bodyComposite)
//       : m_qreg(q), m_usedAsCallable(true), m_body(bodyComposite) {}
//   // Adjoint
//   virtual KernelBase adjoint() {
//     // Copy this
//     KernelBase adjointKernel(this, "KernelName_ADJ");
//     // Reverse all instructions in m_body and replace instructions
//     // with their adjoint:
//     // T -> Tdag; Rx(theta) -> Rx(-theta), etc.
//     auto instructions = adjointKernel.m_body->getInstructions();
//     // Assert that we don't have measurement
//     if (!std::all_of(
//             instructions.cbegin(), instructions.cend(),
//             [](const auto &inst) { return inst->name() != "Measure"; })) {
//       xacc::error(
//           "Unable to create Adjoint for kernels that have Measure operations.");
//     }
//     std::reverse(instructions.begin(), instructions.end());
//     for (const auto &inst : instructions) {
//       // Parametric gates:
//       if (inst->name() == "Rx" || inst->name() == "Ry" ||
//           inst->name() == "Rz" || inst->name() == "CPHASE") {
//         inst->setParameter(0, -inst->getParameter(0).as<double>());
//       }
//       // TODO: Handles T and S gates, etc... => T -> Tdg
//     }
//     adjointKernel.m_body->clear();
//     adjointKernel.m_body->addInstructions(instructions);
//     return adjointKernel;
//   }
//   virtual KernelBase ctrl(size_t ctrlIdx) {
//     // Copy this
//     KernelBase controlledKernel(this, "KernelName_CTRL");
//     // Use the controlled gate module of XACC to transform
//     // controlledKernel.m_body
//     auto ctrlKernel = quantum::controlledKernel(m_body, ctrlIdx);
//     // Set the body of the returned kernel instance.
//     controlledKernel.m_body = ctrlKernel;
//     return controlledKernel;
//   }
//   // Destructor:
//   // called right after the object invocation:
//   // e.g.
//   // Case 1: free-standing invocation:
//   //  ... code ...
//   // kernelFuncClass(abc);
//   // -> DTor called here
//   // ... code ...
//   // Cade 2: chaining
//   //  ... code ...
//   // kernelFuncClass(abc).adjoint();
//   // -> DTor of the Adjoint instance called here (m_usedAsCallable = true)
//   // hence adding the adjoint body to the global composite.
//   // -> DTor of the kernelFuncClass(abc) instance called here (m_usedAsCallable
//   // = false) hence having no effect.
//   // ... code ...
//   virtual ~KernelBase() {
//     // This is used as a CALLABLE
//     if (m_usedAsCallable) {
//       // Add all instructions to the global program.
//       quantum::program->addInstructions(m_body->getInstructions());
//     }
//   }
//   // Default move CTor
//   KernelBase(KernelBase &&) = default;

// protected:
//   // Copy ctor:
//   // Deep copy of the CompositeInstruction to prevent dangling references.
//   KernelBase(KernelBase *other, const std::string &in_optional_newName = "") {
//     const auto kernelName = in_optional_newName.empty() ? other->m_body->name()
//                                                         : in_optional_newName;
//     auto provider = xacc::getIRProvider("quantum");
//     m_body = provider->createComposite(kernelName);
//     for (const auto &inst : other->m_body->getInstructions()) {
//       m_body->addInstruction(inst->clone());
//     }
//     m_qreg = other->m_qreg;
//     m_usedAsCallable = true;
//     // The copied kernel becomes *INACTIVE*
//     other->m_usedAsCallable = false;
//   }
//   // Denote if this instance was use as a *Callable*
//   // i.e.
//   // kernelFuncClass(qubitReg); => TRUE
//   // kernelFuncClass(qubitReg).adjoint(); => FALSE (on the original
//   // kernelFuncClass instance) but TRUE for the one returned by the adjoint()
//   // member function. This will allow arbitrary chaining: e.g.
//   // kernelFuncClass(qubitReg).adjoint().ctrl(k); only the last kernel returned
//   // by ctrl() will be the *Callable*;
//   bool m_usedAsCallable;
//   // The XACC composite instruction described by this kernel body:
//   std::shared_ptr<xacc::CompositeInstruction> m_body;
//   // From kernel params:
//   xacc::internal_compiler::qreg m_qreg;
// };
// //// =================== END COMMON CODE ==============/////
// // The above code can be placed in a header file which is then injected.
// /// ============= ORIGINAL CODE ===================
// // Assume we are rewriting this:
// // __qpu__ void kernelFunc(qreg q, double angle) {
// //   H(q[0]);
// //   CNOT(q[0], q[1]);
// //   Rx(q[0], angle);
// // }
// /// =============  CODE GEN ======================////
// // kernel function: returns the class object.
// KernelBase kernelFunc(xacc::internal_compiler::qreg q, double angle) {
//   quantum::initialize("qpp", "KERNEL_NAME");
//   q.setNameAndStore("q");
//   auto provider = xacc::getIRProvider("quantum");
//   // Kernel name (function name)
//   // BODY to denote it's the original body
//   auto kernelBody = provider->createComposite("KernelName_BODY");
//   // Rewrite from function body:
//   // TODO: QRT functions to take an composite instruction arg
//   // hence added instructions to that composite.
//   // HACK: for testing, swapping the *GLOBAL* program.
//   auto cachedGlobalProgram = quantum::program;
//   // Set the program to this body composite,
//   // hence we can listen to all the QRT instructions below.
//   quantum::program = kernelBody;
//   // ======== QRT Code ===========
//   // Rewrite from the __qpu__ body
//   quantum::h(q[0]); // Ideally, we'll do quantum::h(q[0], kernelBody);
//   quantum::cnot(q[0], q[1]);
//   quantum::rx(q[0], angle);
//   // ======== QRT Code ===========
//   // Restore the global program
//   quantum::program = cachedGlobalProgram;
//   KernelBase instance(q, kernelBody);
//   return instance;
// }
// /// =============  END CODE GEN ======================////
// // Note: The most difficult part is to *CHANGE* the function signature from
// // returning *void* to returing *KernelBase*,
// // i.e. we need to be able to rewrite:
// // "__qpu__ void" ==> "__qpu__ KernelBase" (__qpu__ is handled by the
// // pre-processor) Possibility: the *qcor* script to do that before calling clang
// // ??
// //////////////////////////////////////////////////
// // TEST kernel-in-kernel
// // __qpu__ void nestedFunc(qreg q, double angle1, double angle2) {
// //   kernelFunc(q, angle1).adjoint();
// //   Ry(q[1], angle2);
// //   Measure(q[0]);
// // }
// /// =============  CODE GEN ======================////
// KernelBase nestedFunc(xacc::internal_compiler::qreg q, double angle1,
//                       double angle2) {
//   quantum::initialize("qpp", "KERNEL_NAME");
//   q.setNameAndStore("q");
//   auto provider = xacc::getIRProvider("quantum");
//   // Kernel name (function name)
//   // BODY to denote it's the original body
//   auto kernelBody = provider->createComposite("KernelName_BODY");
//   // Rewrite from function body:
//   auto cachedGlobalProgram = quantum::program;
//   // Set the program to this body composite,
//   // hence we can listen to all the QRT instructions below.
//   quantum::program = kernelBody;
//   // ======== QRT Code ===========
//   // Call other kernels (i.e. left unchanged)
//   // Support arbitrary chaining here as well.
//   kernelFunc(q, angle1).adjoint();
//   // Some more gates:
//   quantum::ry(q[1], angle2);
//   quantum::mz(q[0]);
//   // ======== QRT Code ===========
//   // Restore the global program
//   quantum::program = cachedGlobalProgram;
//   KernelBase instance(q, kernelBody);
//   return instance;
// }
// /// ============= END CODE GEN ======================////
// // Classical code:
// int main(int argc, char **argv) {
//   // Allocate 3 qubits
//   auto q = qalloc(3);
//   // Can try any of the following things:
//   // kernelFunc(q, 1.234);
//   // kernelFunc(q, 1.234).adjoint();
//   // kernelFunc(q, 1.234).ctrl(2);
//   // I'm crazy :)
//   // kernelFunc(q, 1.234).adjoint().ctrl(2).adjoint();
//   // Nested case:
//   // Note: we cannot `adjoint` or `control` the nestedFunc
//   // since it contains Measure (throw).
//   nestedFunc(q, 1.23, 4.56);
//   // This should include instructions from the above kernel,
//   // which is added when the Dtor is called.
//   std::cout << "Program: \n" << quantum::program->toString() << "\n";
//   // dump the results
//   q.print();
// }
} // namespace

static SyntaxHandlerRegistry::Add<QCORSyntaxHandler>
+27 −3
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
#include "xacc.hpp"
#include "xasm_singleVisitor.h"
#include <IRProvider.hpp>
#include <regex>

using namespace xasm;

@@ -11,6 +12,9 @@ using xasm_single_result_type =
    std::pair<std::string, std::shared_ptr<xacc::Instruction>>;

class xasm_single_visitor : public xasm::xasm_singleVisitor {
protected:
  int n_cached_execs = 0;

public:
  xasm_single_result_type result;

@@ -26,6 +30,7 @@ public:
    if (!xacc::isInitialized()) {
      xacc::Initialize();
    }

    // if not in instruction registry, then forward to classical instructions
    auto inst_name = context->inst_name->getText();
    auto provider = xacc::getIRProvider("quantum");
@@ -107,9 +112,13 @@ public:
      for (auto c : context->children) {
        ss << c->getText() << " ";
      }
      ss << "\n";

      result.first = ss.str();
      // always wrap in execute false
      result.first =
          "const auto cached_exec_" + context->inst_name->getText() +
          " = __execute;\n__execute = false;\n" + ss.str() +
          "\n__execute = cached_exec_" + context->inst_name->getText() + ";\n";
      n_cached_execs++;
    }

    return 0;
@@ -123,12 +132,27 @@ public:

    std::stringstream ss;

    bool wrap_false_exec = false;
    std::string adjoint_call_name = "";
    for (auto c : context->children) {
      if (c->getText().find("::adjoint") != std::string::npos) {
        wrap_false_exec = true;
        adjoint_call_name = c->getText();
        adjoint_call_name = std::regex_replace(adjoint_call_name, std::regex("::"), "_");
      }
      ss << c->getText() << " ";
    }
    ss << "\n";

    if (wrap_false_exec) {
      result.first =
          "const auto cached_exec_" + adjoint_call_name +
          " = __execute;\n__execute = false;\n" + ss.str() +
          "__execute = cached_exec_" + adjoint_call_name + ";\n";
      n_cached_execs++;
    } else {
      result.first = ss.str();
    }
    return 0;
  }

+1 −2
Original line number Diff line number Diff line
@@ -136,8 +136,7 @@ void XasmTokenCollector::collect(clang::Preprocessor &PP,
  // or quantum IR from xacc.
  using namespace antlr4;
  for (const auto &line : lines) {
    // xasm_single_result_type result;
    xasm_single_visitor visitor; //(result);
    xasm_single_visitor visitor;

    ANTLRInputStream input(line);
    xasm_singleLexer lexer(&input);
+6 −0
Original line number Diff line number Diff line
@@ -13,8 +13,14 @@ namespace qcor {
void set_verbose(bool verbose) { xacc::set_verbose(verbose); }
bool get_verbose() {return xacc::verbose;}
void set_shots(const int shots) { ::quantum::set_shots(shots); }
void error(const std::string& msg) {
    xacc::error(msg);
}

namespace __internal__ {
std::shared_ptr<qcor::CompositeInstruction> create_composite(std::string name) {
    return xacc::getIRProvider("quantum")->createComposite(name);
}
std::shared_ptr<ObjectiveFunction> get_objective(const std::string &type) {
  if (!xacc::isInitialized())
    xacc::internal_compiler::compiler_InitializeXACC();
Loading