Commit f66f8906 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

updating visitor to handle runtime vars via lambda capture

parent 3c08179f
Loading
Loading
Loading
Loading
+18 −23
Original line number Diff line number Diff line
@@ -56,47 +56,33 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
  // Double check... Is this a Quantum Kernel Lambda?
  IsQuantumKernelVisitor isqk(ci.getASTContext());
  isqk.TraverseStmt(LE->getBody());
  //   LE->dump();

  std::map<std::string, InstructionParameter> captures;
  std::vector<std::string> captureNames;
  // If it is, then map it to XACC IR
  if (isqk.isQuantumKernel()) {

    // std::cout << "LAMBDA IS Quantum Kernel\n";
    // LE->dumpColor();
    // exit(0);
    auto cb = LE->capture_begin(); // implicit_capture_begin();
    auto cb = LE->capture_begin();
    auto ce = LE->capture_end();
    VarDecl *v;
    for (auto it = cb; it != ce; ++it) {
      auto varName = it->getCapturedVar()->getNameAsString();

      //   it->getCapturedVar()->dumpColor();
      captureNames.push_back(varName);
      auto e = it->getCapturedVar()->getInit();
      auto int_value = dyn_cast<IntegerLiteral>(e);
      auto float_value = dyn_cast<FloatingLiteral>(e);
      if (int_value) {
        // std::cout << "THIS VALUE IS KNOWN AT COMPILE TIME: "
        //           << (int)int_value->getValue().signedRoundToDouble()
        //           << "\n"; // getAsString(ci.getASTContext(),
        //                    // it->getCapturedVar()->getType()) << "\n";
        captures.insert(
            {varName, (int)int_value->getValue().signedRoundToDouble()});
        continue;
      } else if (float_value) {
        // std::cout << varName << ", THIS DOUBLE VALUE IS KNOWN AT COMPILE
        // TIME: "
        //           << float_value->getValue().convertToDouble() << "\n";
        captures.insert({varName, float_value->getValue().convertToDouble()});
        continue;
      }

      auto varType =
          it->getCapturedVar()->getType().getCanonicalType().getAsString();
      //   std::cout << "TYPE: " << varType << "\n";
      //   it->getCapturedVar()->dumpColor();
      captures.insert({varName, varName});
      //   v = it->getCapturedVar();
    }

    SourceManager &SM = ci.getSourceManager();
@@ -125,6 +111,7 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
    // std::cout << "LAMBDA STR:\n" << xaccKernelLambdaStr << "\n";
    auto compiler = xacc::getCompiler("xasm");
    auto ir = compiler->compile(xaccKernelLambdaStr, targetAccelerator);
    auto runtimeVariables = ir->getRuntimeVariables();

    auto function = ir->getComposites()[0];
    for (auto &inst : function->getInstructions()) {
@@ -175,14 +162,21 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
    std::stringstream ss;
    function->persist(ss);
    std::string replacement =
        "{\nauto irstr = R\"irstr(" + ss.str() + ")irstr\";\n";

    replacement += "if (qcor::__internal::executeKernel) {\n";
        "\n{std::istringstream iss(R\"(" + ss.str() + ")\");\n";
    // "{\nauto irstr = R\"irstr(" + ss.str() + ")irstr\";\n";
    replacement +=
        "auto function = "
        "xacc::getIRProvider(\"quantum\")->createComposite(\"f\");\n";
    replacement += "std::istringstream iss(irstr);\n";
    replacement += "function->load(iss);\n";

    // Do any function expansion work here...
    std::stringstream sss;
    make_pair_visitor vis(sss, captureNames);
    runtimeVariables.visit(vis);
    auto makePairStr = sss.str();
    makePairStr = makePairStr.substr(0, makePairStr.length() - 1);
    replacement += "function->expand({" + makePairStr + "});\n";
    replacement += "if (qcor::__internal::executeKernel) {\n";
    replacement +=
        "auto acc = xacc::getAccelerator(\"" + acceleratorName + "\");\n";
    if (F->getNumParams() > 1) {
@@ -196,7 +190,8 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
    }
    replacement += "acc->execute(" + bufferName + ",function);\n";
    replacement += "}\n";
    replacement += "return irstr;\n";
    replacement += "std::stringstream ss;\nfunction->persist(ss);\n";
    replacement += "return ss.str();\n";
    replacement += "}\n";
    rewriter.ReplaceText(sr, replacement);

+17 −0
Original line number Diff line number Diff line
@@ -20,6 +20,23 @@ class IRProvider;
}
namespace qcor {
namespace compiler {
class make_pair_visitor : public visitor_base<int, double, std::string> {
protected:
  std::stringstream &s;
  std::vector<std::string> captures;

public:
  make_pair_visitor(std::stringstream &ss, std::vector<std::string> &c)
      : s(ss), captures(c) {}
  template <typename T> void operator()(const std::string &key, const T &t) {
    if (std::find(captures.begin(), captures.end(), key) !=
        std::end(captures)) {
      s << "std::make_pair(\"" << key << "\"," << t << "),";
    } else {
      s << "std::make_pair(\"" << key << "\",\"" << t << "\"),";
    }
  }
};

class QCORASTVisitor : public RecursiveASTVisitor<QCORASTVisitor> {

+39 −11
Original line number Diff line number Diff line
@@ -4,16 +4,42 @@ int main(int argc, char **argv) {

  qcor::Initialize(argc, argv);

  auto optimizer = qcor::getOptimizer(
      "nlopt", {{"nlopt-optimizer", "cobyla"}, {"nlopt-maxeval", 1000}});
  auto optimizer =
      qcor::getOptimizer("nlopt", {std::make_pair("nlopt-optimizer", "cobyla"),
                                   std::make_pair("nlopt-maxeval", 2000)});

  const std::string src = R"src(0.7080240949826064
- 1.248846801817026 0^ 0
- 1.248846801817026 1^ 1
- 0.4796778151607899 2^ 2
- 0.4796778151607899 3^ 3
+ 0.33667197218932576 0^ 1^ 1 0
+ 0.0908126658307406 0^ 1^ 3 2
+ 0.09081266583074038 0^ 2^ 0 2
+ 0.331213646878486 0^ 2^ 2 0
+ 0.09081266583074038 0^ 3^ 1 2
+ 0.331213646878486 0^ 3^ 3 0
+ 0.33667197218932576 1^ 0^ 0 1
+ 0.0908126658307406 1^ 0^ 2 3
+ 0.09081266583074038 1^ 2^ 0 3
+ 0.331213646878486 1^ 2^ 2 1
+ 0.09081266583074038 1^ 3^ 1 3
+ 0.331213646878486 1^ 3^ 3 1
+ 0.331213646878486 2^ 0^ 0 2
+ 0.09081266583074052 2^ 0^ 2 0
+ 0.331213646878486 2^ 1^ 1 2
+ 0.09081266583074052 2^ 1^ 3 0
+ 0.09081266583074048 2^ 3^ 1 0
+ 0.34814578469185886 2^ 3^ 3 2
+ 0.331213646878486 3^ 0^ 0 3
+ 0.09081266583074052 3^ 0^ 2 1
+ 0.331213646878486 3^ 1^ 1 3
+ 0.09081266583074052 3^ 1^ 3 1
+ 0.09081266583074048 3^ 2^ 0 1
+ 0.34814578469185886 3^ 2^ 2 3)src";

  auto op = qcor::getObservable("fermion", src);

  auto geom = R"geom(2

H          0.00000        0.00000        0.00000
H          0.00000        0.00000        0.7474)geom";

  auto op = qcor::getObservable("chemistry",
                                {std::make_pair("basis", "sto-3g"), std::make_pair("geometry", geom)});
  int nq = op->nBits();

  std::vector<std::pair<int, int>> coupling{{0, 1}, {1, 2}, {2, 3}};
@@ -21,9 +47,11 @@ H 0.00000 0.00000 0.7474)geom";
  auto future = qcor::submit([&](qcor::qpu_handler &qh) {
    qh.vqe(
        [&](qbit q, std::vector<double> x) {
          hwe(q, x, {{"n-qubits", nq}, {"layers", 1}, {"coupling", coupling}});
          X(q[0]);
          X(q[1]);
          hwe(q, x, {{"nq", nq}, {"layers", 1}, {"coupling", coupling}});
        },
        op, optimizer);
        op, optimizer, std::vector<double>{});
  });

  auto results = future.get();
+36 −11
Original line number Diff line number Diff line
@@ -4,22 +4,47 @@ int main(int argc, char **argv) {

  qcor::Initialize(argc, argv);

  auto optimizer = qcor::getOptimizer(
      "nlopt", {{"nlopt-optimizer", "cobyla"}, {"nlopt-maxeval", 1000}});

  auto geom = R"geom(2

H          0.00000        0.00000        0.00000
H          0.00000        0.00000        0.7474)geom";

  auto op = qcor::getObservable("chemistry",
                                {std::make_pair("basis", "sto-3g"), std::make_pair("geometry", geom)});
  auto optimizer =
      qcor::getOptimizer("nlopt", {std::make_pair("nlopt-optimizer", "cobyla"),
                                   std::make_pair("nlopt-maxeval", 2000)});

  const std::string src = R"src(0.7080240949826064
- 1.248846801817026 0^ 0
- 1.248846801817026 1^ 1
- 0.4796778151607899 2^ 2
- 0.4796778151607899 3^ 3
+ 0.33667197218932576 0^ 1^ 1 0
+ 0.0908126658307406 0^ 1^ 3 2
+ 0.09081266583074038 0^ 2^ 0 2
+ 0.331213646878486 0^ 2^ 2 0
+ 0.09081266583074038 0^ 3^ 1 2
+ 0.331213646878486 0^ 3^ 3 0
+ 0.33667197218932576 1^ 0^ 0 1
+ 0.0908126658307406 1^ 0^ 2 3
+ 0.09081266583074038 1^ 2^ 0 3
+ 0.331213646878486 1^ 2^ 2 1
+ 0.09081266583074038 1^ 3^ 1 3
+ 0.331213646878486 1^ 3^ 3 1
+ 0.331213646878486 2^ 0^ 0 2
+ 0.09081266583074052 2^ 0^ 2 0
+ 0.331213646878486 2^ 1^ 1 2
+ 0.09081266583074052 2^ 1^ 3 0
+ 0.09081266583074048 2^ 3^ 1 0
+ 0.34814578469185886 2^ 3^ 3 2
+ 0.331213646878486 3^ 0^ 0 3
+ 0.09081266583074052 3^ 0^ 2 1
+ 0.331213646878486 3^ 1^ 1 3
+ 0.09081266583074052 3^ 1^ 3 1
+ 0.09081266583074048 3^ 2^ 0 1
+ 0.34814578469185886 3^ 2^ 2 3)src";

  auto op = qcor::getObservable("fermion", src);

  auto future = qcor::submit([&](qcor::qpu_handler &qh) {
    qh.vqe(
        [&](qbit q, double x) {
          X(q[0]);
          X(q[2]);
          X(q[1]);
          exp_i_theta(q, x, {{"pauli", "Y0 X1 X2 X3"}});
        },
        op, optimizer, 0.0);