Commit 5e414798 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

Setup compiler to read more complicated irgenerator options, added vqe test


Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 8113b3c9
Pipeline #46783 passed with stages
in 2 minutes and 19 seconds
......@@ -89,14 +89,24 @@ LambdaVisitor::CallExprToGateInstructionVisitor::getInstruction() {
bool LambdaVisitor::CallExprToIRGenerator::VisitInitListExpr(
InitListExpr *expr) {
if (haveSeenFirstInit) {
ScanInitListExpr visitor;
visitor.TraverseStmt(expr);
options.insert({visitor.key, visitor.value});
std::cout << "Inserting " << visitor.key << ", " << visitor.value.toString()
<< "\n";
if (haveSeenFirstInit && keepSearching) {
for (auto child : immediate_children) {
ScanInitListExpr visitor;
visitor.TraverseStmt(child);
options.insert({visitor.key, visitor.value});
std::cout << "Inserting " << visitor.key << ", "
<< visitor.value.toString() << "\n";
}
keepSearching = false;
} else {
haveSeenFirstInit = true;
auto children = expr->children();
for (auto it = children.begin(); it != children.end(); ++it) {
immediate_children.push_back(*it);
}
}
return true;
}
......@@ -113,47 +123,109 @@ bool LambdaVisitor::ScanInitListExpr::VisitDeclRefExpr(DeclRefExpr *expr) {
value = InstructionParameter(expr->getNameInfo().getAsString());
return true;
}
// bool LambdaVisitor::ScanInitListExpr::VisitInitListExpr(InitListExpr *expr) {
// ScanInitListExpr visitor(true);
// visitor.TraverseStmt(expr);
bool LambdaVisitor::ScanInitListExpr::VisitInitListExpr(InitListExpr *expr) {
// if (!visitor.intsFound.empty()) {
// // for (int i = 0; i < )
// } else if (!visitor.realsFound.empty()) {
// } else {
// xacc::error("Invalid pair type.");
// }
if (skipSubInits) {
return true;
}
if (hasSeenFirstIL) {
HasSubInitListExpr visitor;
visitor.TraverseStmt(*expr->children().begin());
if (visitor.hasSubInitLists) {
isVectorValue = true;
// this is a vector of pairs or doubles.
GetPairVisitor visitor;
visitor.TraverseStmt(expr);
if (!visitor.intsFound.empty()) {
std::vector<std::pair<int, int>> tmp;
for (int i = 0; i < visitor.intsFound.size(); i += 2) {
tmp.push_back({visitor.intsFound[i], visitor.intsFound[i + 1]});
}
value = InstructionParameter(tmp);
} else if (!visitor.realsFound.empty()) {
std::vector<std::pair<double, double>> tmp;
for (int i = 0; i < visitor.realsFound.size(); i += 2) {
tmp.push_back({visitor.realsFound[i], visitor.realsFound[i + 1]});
}
value = InstructionParameter(tmp);
} else {
xacc::error("invalid vector<pair> type for IRGenerator options.");
}
// return true;
skipSubInits = true;
} else {
// }
// this is a vector...
ScanInitListExpr visitor(true);
visitor.TraverseStmt(expr);
if (!visitor.intsFound.empty()) {
value = InstructionParameter(visitor.intsFound);
} else if (!visitor.realsFound.empty()) {
value = InstructionParameter(visitor.realsFound);
} else if (!visitor.stringsFound.empty()) {
value = InstructionParameter(visitor.stringsFound);
} else {
xacc::error("invalid vector type for IRGenerator options.");
}
}
} else {
hasSeenFirstIL = true;
}
return true;
}
bool LambdaVisitor::ScanInitListExpr::VisitStringLiteral(
StringLiteral *literal) {
if (isFirstStringLiteral) {
isFirstStringLiteral = false;
key = literal->getString().str();
if (isVectorValue) {
stringsFound.push_back(literal->getString().str());
} else {
value = InstructionParameter(literal->getString().str());
if (isFirstStringLiteral) {
isFirstStringLiteral = false;
key = literal->getString().str();
} else {
value = InstructionParameter(literal->getString().str());
}
}
return true;
}
bool LambdaVisitor::ScanInitListExpr::VisitFloatingLiteral(
FloatingLiteral *literal) {
value = InstructionParameter(literal->getValue().convertToDouble());
if (isVectorValue) {
realsFound.push_back(literal->getValue().convertToDouble());
} else {
value = InstructionParameter(literal->getValue().convertToDouble());
}
return true;
}
bool LambdaVisitor::ScanInitListExpr::VisitIntegerLiteral(
IntegerLiteral *literal) {
value = InstructionParameter((int)literal->getValue().getLimitedValue());
if (isVectorValue) {
intsFound.push_back((int)literal->getValue().getLimitedValue());
} else {
value = InstructionParameter((int)literal->getValue().getLimitedValue());
}
return true;
}
bool LambdaVisitor::GetPairVisitor::VisitFloatingLiteral(
FloatingLiteral *literal) {
realsFound.push_back(literal->getValue().convertToDouble());
return true;
}
bool LambdaVisitor::GetPairVisitor::VisitIntegerLiteral(
IntegerLiteral *literal) {
intsFound.push_back((int)literal->getValue().getLimitedValue());
return true;
}
bool LambdaVisitor::CppToXACCIRVisitor::VisitCallExpr(CallExpr *expr) {
auto gate_name = dyn_cast<DeclRefExpr>(*(expr->child_begin()))
->getNameInfo()
......@@ -216,11 +288,11 @@ bool LambdaVisitor::VisitLambdaExpr(LambdaExpr *LE) {
auto cb = LE->capture_begin(); // implicit_capture_begin();
auto ce = LE->capture_end();
VarDecl* v;
VarDecl *v;
for (auto it = cb; it != ce; ++it) {
auto varName = it->getCapturedVar()->getNameAsString();
// it->getCapturedVar()->dumpColor();
// it->getCapturedVar()->dumpColor();
auto e = it->getCapturedVar()->getInit();
auto int_value = dyn_cast<IntegerLiteral>(e);
auto float_value = dyn_cast<FloatingLiteral>(e);
......@@ -239,10 +311,10 @@ bool LambdaVisitor::VisitLambdaExpr(LambdaExpr *LE) {
auto varType =
it->getCapturedVar()->getType().getCanonicalType().getAsString();
// std::cout << "TYPE: " << varType << "\n";
// it->getCapturedVar()->dumpColor();
// std::cout << "TYPE: " << varType << "\n";
// it->getCapturedVar()->dumpColor();
captures.insert({varName, varName});
// v = it->getCapturedVar();
// v = it->getCapturedVar();
}
// q_kernel_body->dumpColor();
......@@ -254,8 +326,9 @@ bool LambdaVisitor::VisitLambdaExpr(LambdaExpr *LE) {
// Check if we have IRGenerators in the tree
if (function->hasIRGenerators()) {
// std::cout << "We have IRGenerators, checking to see if we know enough to "
// "generate it\n";
// std::cout << "We have IRGenerators, checking to see if we know enough
// to "
// "generate it\n";
int idx = 0;
std::shared_ptr<IRGenerator> irg;
for (auto &inst : function->getInstructions()) {
......@@ -307,21 +380,21 @@ bool LambdaVisitor::VisitLambdaExpr(LambdaExpr *LE) {
std::string replacement = "[&]() {\n";
std::shared_ptr<Instruction> irg;
for (auto i : function->getInstructions()) {
if (std::dynamic_pointer_cast<IRGenerator>(i)) {
irg = i;
break;
}
if (std::dynamic_pointer_cast<IRGenerator>(i)) {
irg = i;
break;
}
}
for (auto &kv : captures) {
std::string key = "";
auto opts = irg->getOptions();
for (auto &kv2 : opts) {
if (kv2.second.isVariable() &&
kv2.second.toString() == kv.first) {
key = kv2.first;
}
if (kv2.second.isVariable() && kv2.second.toString() == kv.first) {
key = kv2.first;
}
}
replacement += "qcor::storeRuntimeVariable(\"" + key +"\", " + kv.first + ");\n";
replacement +=
"qcor::storeRuntimeVariable(\"" + key + "\", " + kv.first + ");\n";
}
replacement += "return \"" + fileName + "\";\n}";
rewriter.ReplaceText(sr, replacement);
......@@ -355,10 +428,9 @@ bool LambdaVisitor::VisitLambdaExpr(LambdaExpr *LE) {
auto rtrn = ReturnStmt::Create(ci.getASTContext(), SourceLocation(),
fnameSL, nullptr);
auto cs = LE->getCallOperator()->getBody();
for (auto it = cs->child_begin(); it != cs->child_end(); ++it) {
svec.push_back(*it);
svec.push_back(*it);
}
// svec.push_back(LE->getCallOperator()->getBody());
......
......@@ -79,6 +79,8 @@ protected:
std::map<std::string, InstructionParameter> options;
bool haveSeenFirstDeclRef = false;
bool haveSeenFirstInit = false;
bool keepSearching = true;
std::vector<Stmt*> immediate_children;
public:
CallExprToIRGenerator(const std::string n, std::shared_ptr<IRProvider> p)
......@@ -90,20 +92,39 @@ protected:
class ScanInitListExpr : public RecursiveASTVisitor<ScanInitListExpr> {
protected:
bool isFirstStringLiteral = true;
// bool isSubInit;
bool isVectorValue;
bool hasSeenFirstIL = false;
bool skipSubInits = false;
public:
// std::vector<int> intsFound;
// std::vector<double> realsFound;
std::vector<int> intsFound;
std::vector<double> realsFound;
std::vector<std::string> stringsFound;
std::string key;
InstructionParameter value;
// ScanInitListExpr(bool isSubInitList = false) :isSubInit(isSubInitList) {}
ScanInitListExpr(bool isVecValued = false) :isVectorValue(isVecValued) {}
bool VisitDeclRefExpr(DeclRefExpr *expr);
bool VisitStringLiteral(StringLiteral *literal);
bool VisitFloatingLiteral(FloatingLiteral *literal);
bool VisitIntegerLiteral(IntegerLiteral *literal);
// bool VisitInitListExpr(InitListExpr *initList);
bool VisitInitListExpr(InitListExpr *initList);
};
class HasSubInitListExpr : public RecursiveASTVisitor<HasSubInitListExpr> {
public:
bool hasSubInitLists = false;
bool VisitInitListExpr(InitListExpr *initList) {
hasSubInitLists = true;
return true;
}
};
class GetPairVisitor : public RecursiveASTVisitor<GetPairVisitor> {
public:
std::vector<int> intsFound;
std::vector<double> realsFound;
bool VisitFloatingLiteral(FloatingLiteral *literal);
bool VisitIntegerLiteral(IntegerLiteral *literal);
};
public:
LambdaVisitor(CompilerInstance &c, Rewriter &rw);
......
......@@ -120,6 +120,20 @@ int main(int argc, char** argv){
};
return 0;
})hwe3";
const std::string hwe4 = R"hwe4(#include <vector>
int main(int argc, char** argv){
int nq = argc;
auto l = [&](std::vector<double> x) {
hwe(x, {
{"n-qubits", nq},
{"layers",1},
{"coupling", {{1,0}, {0,1}} },
{"testVector", {1,2,3,4,5,6} },
});
};
return 0;
})hwe4";
TEST(LambdaVisitorTester, checkSimple) {
Rewriter rewriter1, rewriter2;
auto action1 = new TestQCORFrontendAction(rewriter1);
......@@ -256,6 +270,39 @@ TEST(LambdaVisitorTester, checkRuntimeGeneratorWithVectorPair) {
std::istreambuf_iterator<char>());
std::remove(".output.cpp");
std::cout << "HELLO:\n" << src2 <<"\n";
const std::string exp1 = R"exp1(#include <vector>
int main(int argc, char** argv){
int nq = argc;
std::vector<std::pair<int,int>> c{{1,0}};
auto l = [&]() {
qcor::storeRuntimeVariable("coupling", c);
qcor::storeRuntimeVariable("n-qubits", nq);
return "lambda_visitor_tester";
};
return 0;
})exp1";
EXPECT_EQ(exp1,src2);
}
TEST(LambdaVisitorTester, checkRuntimeGeneratorWithVectorPairAndVector) {
Rewriter rewriter1, rewriter2;
auto action1 = new TestQCORFrontendAction(rewriter1);
xacc::setOption("qcor-compiled-filename", "lambda_visitor_tester");
std::vector<std::string> args{"-std=c++11"};
std::cout << "Source Code:\n" << hwe4 << "\n";
// first case, I know compile time values, so ahead-of-time compilation
EXPECT_TRUE(tooling::runToolOnCodeWithArgs(action1, hwe4, args));
std::ifstream t1(".output.cpp");
std::string src2((std::istreambuf_iterator<char>(t1)),
std::istreambuf_iterator<char>());
std::remove(".output.cpp");
std::cout << "HELLO:\n" << src2 <<"\n";
const std::string exp1 = R"exp1(#include <vector>
int main(int argc, char** argv){
......@@ -267,7 +314,7 @@ return "lambda_visitor_tester";
return 0;
})exp1";
// EXPECT_EQ(exp1,src2);
EXPECT_EQ(exp1,src2);
}
int main(int argc, char **argv) {
qcor::Initialize(argc, argv);
......
......@@ -31,7 +31,7 @@ usfunctionembedresources(TARGET
qcor_enable_rpath(${LIBRARY_NAME})
if(QCOR_BUILD_TESTS)
#add_subdirectory(tests)
add_subdirectory(tests)
endif()
install(TARGETS ${LIBRARY_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/plugins)
include_directories(${CMAKE_SOURCE_DIR}/runtime/algorithms/vqe)
include_directories(${CMAKE_SOURCE_DIR}/runtime)
add_xacc_test(VQE)
target_link_libraries(VQETester qcor-algorithm-vqe qcor)
\ No newline at end of file
#include <gtest/gtest.h>
#include "XACC.hpp"
#include "qcor.hpp"
#include "vqe.hpp"
#include "xacc_service.hpp"
using namespace qcor;
using namespace qcor::algorithm;
const std::string rucc = R"rucc(def f(buffer, theta):
X(0)
X(1)
Rx(1.5707,0)
H(1)
H(2)
H(3)
CNOT(0,1)
CNOT(1,2)
CNOT(2,3)
Rz(theta,3)
CNOT(2,3)
CNOT(1,2)
CNOT(0,1)
Rx(-1.5707,0)
H(1)
H(2)
H(3)
)rucc";
TEST(VQETester, checkSimple) {
if (xacc::hasAccelerator("tnqvm") && xacc::hasCompiler("xacc-py")) {
auto acc = xacc::getAccelerator("tnqvm");
auto buffer = acc->createBuffer("q", 4);
auto compiler = xacc::getService<xacc::Compiler>("xacc-py");
auto ir = compiler->compile(rucc, nullptr);
auto ruccsd = ir->getKernel("f");
auto optimizer = qcor::getOptimizer("nlopt");
auto observable = qcor::getObservable(
"(0.174073,0) Z2 Z3 + (0.1202,0) Z1 Z3 + (0.165607,0) Z1 Z2 + "
"(0.165607,0) Z0 Z3 + (0.1202,0) Z0 Z2 + (-0.0454063,0) Y0 Y1 X2 X3 + "
"(-0.220041,0) Z3 + (-0.106477,0) + (0.17028,0) Z0 + (-0.220041,0) Z2 "
"+ (0.17028,0) Z1 + (-0.0454063,0) X0 X1 Y2 Y3 + (0.0454063,0) X0 Y1 "
"Y2 X3 + (0.168336,0) Z0 Z1 + (0.0454063,0) Y0 X1 X2 Y3");
VQE vqe;
vqe.initialize(ruccsd, acc, buffer);
vqe.execute(*observable.get(), *optimizer.get());
EXPECT_NEAR(-1.13717, mpark::get<double>(buffer->getInformation("opt-val")), 1e-4);
}
}
int main(int argc, char **argv) {
qcor::Initialize(argc, argv);
::testing::InitGoogleTest(&argc, argv);
auto ret = RUN_ALL_TESTS();
return ret;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment