Commit ff802f26 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

started enabling d-wave code in c++ lambdas

parent 61802c7b
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ namespace compiler {
FuzzyParsingExternalSemaSource::FuzzyParsingExternalSemaSource(
    ASTContext &context)
    : m_Context(context) {
  auto irProvider = xacc::getService<xacc::IRProvider>("gate");
  auto irProvider = xacc::getService<xacc::IRProvider>("quantum");
  validInstructions = irProvider->getInstructions();
  validInstructions.push_back("CX");
  auto irgens = xacc::getRegisteredIds<xacc::IRGenerator>();
+67 −53
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ namespace compiler {

LambdaVisitor::IsQuantumKernelVisitor::IsQuantumKernelVisitor(ASTContext &c)
    : context(c) {
  auto irProvider = xacc::getService<xacc::IRProvider>("gate");
  auto irProvider = xacc::getService<xacc::IRProvider>("quantum");
  validInstructions = irProvider->getInstructions();
  validInstructions.push_back("CX");
  auto irgens = xacc::getRegisteredIds<xacc::IRGenerator>();
@@ -37,6 +37,9 @@ bool LambdaVisitor::IsQuantumKernelVisitor::VisitDeclRefExpr(
    if (std::find(validInstructions.begin(), validInstructions.end(),
                  gateName) != validInstructions.end()) {
      _isQuantumKernel = true;
      if (irType != "anneal" && (gateName == "qmi" || gateName == "anneal")) {
          irType = "anneal";
      }
    }
  }
  return true;
@@ -47,9 +50,13 @@ bool LambdaVisitor::IsQuantumKernelVisitor::VisitLambdaExpr(LambdaExpr *expr) {
  return true;
}

LambdaVisitor::CppToXACCIRVisitor::CppToXACCIRVisitor() {
  provider = xacc::getService<IRProvider>("gate");
  function = provider->createFunction("tmp", {});
LambdaVisitor::CppToXACCIRVisitor::CppToXACCIRVisitor(IsQuantumKernelVisitor& v) {
  provider = xacc::getService<IRProvider>("quantum");
  if (v.irType == "gate") {
    function = provider->createFunction("tmp", {}, {InstructionParameter("gate")});
  } else {
    function = provider->createFunction("tmp", {}, {InstructionParameter("anneal")});
  }

  auto irgens = xacc::getRegisteredIds<xacc::IRGenerator>();
  for (auto &irg : irgens) {
@@ -57,17 +64,60 @@ LambdaVisitor::CppToXACCIRVisitor::CppToXACCIRVisitor() {
  }
}

bool LambdaVisitor::CallExprToGateInstructionVisitor::VisitIntegerLiteral(
bool LambdaVisitor::CppToXACCIRVisitor::VisitCallExpr(CallExpr *expr) {
  auto gate_name = dyn_cast<DeclRefExpr>(*(expr->child_begin()))
                       ->getNameInfo()
                       .getAsString();

  if (std::find(irGeneratorNames.begin(), irGeneratorNames.end(), gate_name) !=
      irGeneratorNames.end()) {

    // This is an IRGenerator
    // Map this CallExpr to an IRGenerator
    CallExprToIRGenerator visitor(gate_name, provider);
    visitor.TraverseStmt(expr);
    auto irg = visitor.getIRGenerator();
    if (irg->validateOptions()) {
      auto generated =
          irg->generate(std::map<std::string, InstructionParameter>{});
      for (auto inst : generated->getInstructions()) {
        function->addInstruction(inst);
      }
    } else {
      function->addInstruction(irg);
    }
  } else {
    // This is a regular gate
    // Map this Call Expr to a Instruction
    if (gate_name == "CX") {
      gate_name = "CNOT";
    }
    CallExprToXACCInstructionVisitor visitor(gate_name, provider);
    visitor.TraverseStmt(expr);
    auto inst = visitor.getInstruction();
    function->addInstruction(inst);
  }

  return true;
}

bool LambdaVisitor::CallExprToXACCInstructionVisitor::VisitIntegerLiteral(
    IntegerLiteral *il) {
  if (name == "anneal") {
      int i = il->getValue().getLimitedValue();
    InstructionParameter p(i);
    parameters.push_back(p);
  } else {
  bits.push_back(il->getValue().getLimitedValue());
  if (name == "Measure") {
    InstructionParameter p(bits[0]);
    parameters.push_back(p);
  }
  }
  return true;
}

bool LambdaVisitor::CallExprToGateInstructionVisitor::VisitUnaryOperator(
bool LambdaVisitor::CallExprToXACCInstructionVisitor::VisitUnaryOperator(
    UnaryOperator *op) {
  if (op->getOpcode() == UnaryOperator::Opcode::UO_Minus) {
    addMinus = true;
@@ -75,7 +125,7 @@ bool LambdaVisitor::CallExprToGateInstructionVisitor::VisitUnaryOperator(
  return true;
}

bool LambdaVisitor::CallExprToGateInstructionVisitor::VisitFloatingLiteral(
bool LambdaVisitor::CallExprToXACCInstructionVisitor::VisitFloatingLiteral(
    FloatingLiteral *literal) {
  double value = literal->getValue().convertToDouble();
  InstructionParameter p(addMinus ? -1.0 * value : value);
@@ -84,7 +134,7 @@ bool LambdaVisitor::CallExprToGateInstructionVisitor::VisitFloatingLiteral(
  return true;
}

bool LambdaVisitor::CallExprToGateInstructionVisitor::VisitDeclRefExpr(
bool LambdaVisitor::CallExprToXACCInstructionVisitor::VisitDeclRefExpr(
    DeclRefExpr *decl) {
  auto declName = decl->getNameInfo().getAsString();
  if (addMinus) {
@@ -93,14 +143,14 @@ bool LambdaVisitor::CallExprToGateInstructionVisitor::VisitDeclRefExpr(
  if (dyn_cast<ParmVarDecl>(decl->getDecl())) {
    parameters.push_back(InstructionParameter(declName));
  } else if (dyn_cast<VarDecl>(decl->getDecl())) {
    std::cout << "THIS IS A VARDECL: " << declName << "\n";
    // std::cout << "THIS IS A VARDECL: " << declName << "\n";
    parameters.push_back(InstructionParameter(declName));
  }
  return true;
}

std::shared_ptr<Instruction>
LambdaVisitor::CallExprToGateInstructionVisitor::getInstruction() {
LambdaVisitor::CallExprToXACCInstructionVisitor::getInstruction() {
  return provider->createInstruction(name, bits, parameters);
}

@@ -253,42 +303,6 @@ bool LambdaVisitor::GetPairVisitor::VisitIntegerLiteral(
  intsFound.push_back((int)literal->getValue().getLimitedValue());
  return true;
}
bool LambdaVisitor::CppToXACCIRVisitor::VisitCallExpr(CallExpr *expr) {
  auto gate_name = dyn_cast<DeclRefExpr>(*(expr->child_begin()))
                       ->getNameInfo()
                       .getAsString();

  if (std::find(irGeneratorNames.begin(), irGeneratorNames.end(), gate_name) !=
      irGeneratorNames.end()) {

    // This is an IRGenerator
    // Map this CallExpr to an IRGenerator
    CallExprToIRGenerator visitor(gate_name, provider);
    visitor.TraverseStmt(expr);
    auto irg = visitor.getIRGenerator();
    if (irg->validateOptions()) {
      auto generated =
          irg->generate(std::map<std::string, InstructionParameter>{});
      for (auto inst : generated->getInstructions()) {
        function->addInstruction(inst);
      }
    } else {
      function->addInstruction(irg);
    }
  } else {
    // This is a regular gate
    // Map this Call Expr to a GateInstruction
    if (gate_name == "CX") {
      gate_name = "CNOT";
    }
    CallExprToGateInstructionVisitor visitor(gate_name, provider);
    visitor.TraverseStmt(expr);
    auto inst = visitor.getInstruction();
    function->addInstruction(inst);
  }

  return true;
}

std::shared_ptr<Function> LambdaVisitor::CppToXACCIRVisitor::getFunction() {
  return function;
@@ -324,16 +338,16 @@ bool LambdaVisitor::VisitLambdaExpr(LambdaExpr *LE) {
      auto int_value = dyn_cast<IntegerLiteral>(e);
      auto float_value = dyn_cast<FloatingLiteral>(e);
      if (int_value) {
        std::cout << "THIS VALUE IS KNOWN AT COMPILE TIME: "
                  << (int)int_value->getValue().signedRoundToDouble()
                  << "\n"; // getAsString(ci.getASTContext(),
                           // it->getCapturedVar()->getType()) << "\n";
        // std::cout << "THIS VALUE IS KNOWN AT COMPILE TIME: "
        //           << (int)int_value->getValue().signedRoundToDouble()
        //           << "\n"; // getAsString(ci.getASTContext(),
        //                    // it->getCapturedVar()->getType()) << "\n";
        captures.insert(
            {varName, (int)int_value->getValue().signedRoundToDouble()});
        continue;
      } else if (float_value) {
        std::cout << varName << ", THIS DOUBLE VALUE IS KNOWN AT COMPILE TIME: "
                  << float_value->getValue().convertToDouble() << "\n";
        // std::cout << varName << ", THIS DOUBLE VALUE IS KNOWN AT COMPILE TIME: "
        //           << float_value->getValue().convertToDouble() << "\n";
        captures.insert({varName, float_value->getValue().convertToDouble()});
        continue;
      }
@@ -347,7 +361,7 @@ bool LambdaVisitor::VisitLambdaExpr(LambdaExpr *LE) {
    }

    // q_kernel_body->dumpColor();
    CppToXACCIRVisitor visitor;
    CppToXACCIRVisitor visitor(isqk);
    visitor.TraverseStmt(LE);

    auto function = visitor.getFunction();
+5 −4
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ protected:
    bool VisitDeclRefExpr(DeclRefExpr *expr);
    bool VisitLambdaExpr(LambdaExpr* expr);
    bool isQuantumKernel() { return _isQuantumKernel; }
    std::string irType = "gate";
  };

  class CppToXACCIRVisitor : public RecursiveASTVisitor<CppToXACCIRVisitor> {
@@ -48,13 +49,13 @@ protected:
    std::vector<std::string> irGeneratorNames;

  public:
    CppToXACCIRVisitor();
    CppToXACCIRVisitor(IsQuantumKernelVisitor& v);
    bool VisitCallExpr(CallExpr *expr);
    std::shared_ptr<Function> getFunction();
  };

  class CallExprToGateInstructionVisitor
      : public RecursiveASTVisitor<CallExprToGateInstructionVisitor> {
  class CallExprToXACCInstructionVisitor
      : public RecursiveASTVisitor<CallExprToXACCInstructionVisitor> {
  protected:
    std::vector<int> bits;
    std::vector<InstructionParameter> parameters;
@@ -63,7 +64,7 @@ protected:
    bool addMinus = false;

  public:
    CallExprToGateInstructionVisitor(const std::string n,
    CallExprToXACCInstructionVisitor(const std::string n,
                                     std::shared_ptr<IRProvider> p)
        : name(n), provider(p) {}
    std::shared_ptr<Instruction> getInstruction();
+48 −0
Original line number Diff line number Diff line
@@ -157,6 +157,17 @@ int main(int argc, char** argv){
    };
    return 0;
})hwe4";

const std::string dw = R"dw(
int main() {
    auto l = [&]() {
        qmi(0,0,2.2);
        qmi(1,1,3.3);
        qmi(0,1,3.4);
        anneal(20,0,0,0);
    };
    return 0;
})dw";
TEST(LambdaVisitorTester, checkSimple) {
  Rewriter rewriter1, rewriter2;
  auto action1 = new TestQCORFrontendAction(rewriter1);
@@ -394,6 +405,43 @@ int main() {

  EXPECT_EQ(expectedSrc, src2);
}

TEST(LambdaVisitorTester, checkSimpleDW) {
  Rewriter rewriter1, rewriter2;
  auto action1 = new TestQCORFrontendAction(rewriter1);
  auto action2 = new TestQCORFrontendAction(rewriter2);

  xacc::setOption("qcor-compiled-filename", "lambda_visitor_tester");

  std::vector<std::string> args{"-std=c++11"};

  EXPECT_TRUE(tooling::runToolOnCodeWithArgs(action1, dw, args));

  const std::string expectedSrc = R"expectedSrc(
int main() {
    auto l = [&](){return "lambda_visitor_tester";};
    return 0;
})expectedSrc";
  std::ifstream t(".output.cpp");
  std::string src((std::istreambuf_iterator<char>(t)),
                  std::istreambuf_iterator<char>());
  std::remove(".output.cpp");

  std::cout << "SOURCE\n" << src << "\n";
  EXPECT_EQ(expectedSrc, src);

  auto function = qcor::loadCompiledCircuit("lambda_visitor_tester");
  EXPECT_EQ(4, function->nInstructions());
  EXPECT_EQ(0, function->getInstruction(0)->bits()[0]);
  EXPECT_EQ(0, function->getInstruction(0)->bits()[1]);
  EXPECT_EQ(1, function->getInstruction(1)->bits()[0]);
  EXPECT_EQ(1, function->getInstruction(1)->bits()[1]);
  EXPECT_EQ(0, function->getInstruction(2)->bits()[0]);
  EXPECT_EQ(1, function->getInstruction(2)->bits()[1]);

  std::cout << "HOWDY:\n" << function->getInstruction(3)->toString() <<"\n";
}

int main(int argc, char **argv) {
  qcor::Initialize(argc, argv);
  ::testing::InitGoogleTest(&argc, argv);
+7 −2
Original line number Diff line number Diff line
@@ -2,7 +2,12 @@ FROM theiaide/theia-full:next
USER root
RUN apt-get -y update \
    && apt-get -y update && apt-get install -y libcurl4-openssl-dev libssl-dev \
              python3 libpython3-dev python3-pip cmake gdb gfortran libblas-dev \
              liblapack-dev pkg-config
              python3 libpython3-dev python3-pip gdb gfortran libblas-dev \
              liblapack-dev pkg-config software-properties-common 
RUN apt-get update \ 
    && wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - \
    && add-apt-repository "deb http://apt.llvm.org/jessie/ llvm-toolchain-jessie main" \
    && apt-get -y update && apt-get -y install libclang-9-dev llvm-9-dev clang-9 \
    && ln -s /usr/bin/llvm-config-9 /usr/bin/llvm-config && python3 -m pip install cmake
ADD settings.json /home/.theia/
Loading