Commit 14549108 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

Adding a new more flexible visitor pattern for instructions, setup firetensoracc to use it

parent 21ef2a44
......@@ -40,6 +40,8 @@ include_directories(${CMAKE_SOURCE_DIR}/quantum/gate/gateqir/functions)
link_directories(${CLANG_LIBRARY_DIRS})
#set_source_files_properties(teleport_scaffold.cpp PROPERTIES COMPILE_FLAGS -fno-rtti)
add_executable(teleport_scaffold teleport_scaffold.cpp)
target_link_libraries(teleport_scaffold xacc-scaffold ${CLANG_LIBS} ${LLVM_LIBS} ${Boost_LIBRARIES} xacc-firetensor xacc-gateqir dl pthread)
......@@ -29,7 +29,6 @@
*
**********************************************************************************/
#include "FireTensorAccelerator.hpp"
#include <boost/variant.hpp>
namespace xacc {
namespace quantum {
......@@ -223,9 +222,11 @@ void FireTensorAccelerator::execute(std::shared_ptr<AcceleratorBuffer> buffer,
// Create a Visitor that will execute our lambdas when
// we encounter one
auto visitor = std::make_shared<GateInstructionVisitor>(hadamard, cnot, x,
measure, z, cond);
// auto visitor = std::make_shared<GateInstructionVisitor>(hadamard, cnot, x,
// measure, z, cond);
auto visitor = std::make_shared<FunctionalGateInstructionVisitor>(hadamard,
cnot, x, measure, z, cond);
// Our QIR is really a tree structure
// so create a pre-order tree traversal
// InstructionIterator to walk it
......@@ -244,7 +245,7 @@ void FireTensorAccelerator::execute(std::shared_ptr<AcceleratorBuffer> buffer,
}
// Register the ScaffoldCompiler with the CompilerRegistry.
static xacc::RegisterAccelerator<xacc::quantum::FireTensorAccelerator> X(
static xacc::RegisterAccelerator<xacc::quantum::FireTensorAccelerator> FIRETEMP(
"firetensor");
}
......
......@@ -31,12 +31,21 @@
#ifndef QUANTUM_GATE_ACCELERATORS_EIGENACCELERATOR_HPP_
#define QUANTUM_GATE_ACCELERATORS_EIGENACCELERATOR_HPP_
#include "Hadamard.hpp"
#include "Measure.hpp"
#include "CNOT.hpp"
#include "Rz.hpp"
#include "Z.hpp"
#include "X.hpp"
#include "ConditionalFunction.hpp"
#include "Z.hpp"
#include "QPUGate.hpp"
#include "QuantumCircuit.hpp"
#include "SimulatedQubits.hpp"
#include <random>
#include "InstructionIterator.hpp"
#include "GateInstructionVisitor.hpp"
using namespace xacc;
namespace xacc {
namespace quantum {
......@@ -45,6 +54,48 @@ double sqrt2 = std::sqrt(2.0);
using ProductList = std::vector<fire::Tensor<2, fire::EigenProvider, std::complex<double>>>;
using ComplexTensor = fire::Tensor<2, fire::EigenProvider, std::complex<double>>;
class FunctionalGateInstructionVisitor: public BaseInstructionVisitor,
public InstructionVisitor<CNOT>,
public InstructionVisitor<Hadamard>,
public InstructionVisitor<X>,
public InstructionVisitor<Z>,
public InstructionVisitor<Measure>,
public InstructionVisitor<ConditionalFunction> {
protected:
std::function<void(Hadamard&)> hAction;
std::function<void(CNOT&)> cnotAction;
std::function<void(X&)> xAction;
std::function<void(Z&)> zAction;
std::function<void(Measure&)> measureAction;
std::function<void(ConditionalFunction&)> condAction;
public:
template<typename HF, typename CNF, typename XF, typename MF, typename ZF, typename CF>
FunctionalGateInstructionVisitor(HF h, CNF cn, XF x, MF m, ZF z, CF c) :
hAction(h), cnotAction(cn), xAction(x), zAction(z), measureAction(m), condAction(c) {
}
void visit(Hadamard& h) {
hAction(h);
}
void visit(CNOT& cn) {
cnotAction(cn);
}
void visit(X& x) {
xAction(x);
}
void visit(Z& z) {
zAction(z);
}
void visit(Measure& m) {
measureAction(m);
}
void visit(ConditionalFunction& c) {
condAction(c);
}
virtual ~FunctionalGateInstructionVisitor() {}
};
/**
* The FireTensorAccelerator is an XACC Accelerator that simulates
* gate based quantum computing circuits. It models the QPUGate Accelerator
......
......@@ -37,7 +37,7 @@ file (GLOB SRC *.cpp)
set(CMAKE_CXX_FLAGS "-D_DEBUG -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -O3 -w -fomit-frame-pointer -fvisibility-inlines-hidden -fPIC -Woverloaded-virtual -Wcast-qual")
set_source_files_properties(ScaffoldASTConsumer.cpp PROPERTIES COMPILE_FLAGS -fno-rtti)
#set_source_files_properties(ScaffoldASTConsumer.cpp PROPERTIES COMPILE_FLAGS -fno-rtti)
include_directories(${CLANG_INCLUDE_DIRS})
include_directories(${CLANG_INCLUDE_DIRS}/extra-tools)
......@@ -53,4 +53,5 @@ link_directories(${CLANG_LIBRARY_DIRS})
# Gather tests
file (GLOB test_files tests/*.cpp)
add_tests("${test_files}" "${CMAKE_CURRENT_SOURCE_DIR}" "${LIBRARY_NAME};xacc-gateqir;${CLANG_LIBS};${LLVM_LIBS};dl;pthread")
#add_tests_with_flags("${test_files}" "${CMAKE_CURRENT_SOURCE_DIR}" "${LIBRARY_NAME};xacc-gateqir;${CLANG_LIBS};${LLVM_LIBS};dl;pthread" "-fno-rtti")
......@@ -35,7 +35,7 @@
#include "Utils.hpp"
#include <boost/algorithm/string.hpp>
#include "../scaffold/ScaffoldASTConsumer.hpp"
#include "ScaffoldASTConsumer.hpp"
#include "Accelerator.hpp"
......
#include "../scaffold/ScaffoldASTConsumer.hpp"
#include <iostream>
#include <boost/algorithm/string.hpp>
#include "ParameterizedGateInstruction.hpp"
#include "Utils.hpp"
using namespace xacc;
namespace scaffold {
bool ScaffoldASTConsumer::HandleTopLevelDecl(DeclGroupRef DR) {
for (DeclGroupRef::iterator b = DR.begin(), e = DR.end(); b != e; ++b) {
// Traverse the declaration using our AST visitor.
TraverseDecl(*b);
}
return true;
}
bool ScaffoldASTConsumer::VisitDecl(Decl *d) {
if (isa<VarDecl>(d)) {
auto varDecl = cast<VarDecl>(d);
auto varType = varDecl->getType().getAsString();
if (boost::contains(varType, "cbit")) {
cbitVarName = varDecl->getDeclName().getAsString();
// std::cout << "Found " << cbitVarName << "\n";
} else if (boost::contains(varType, "qbit")) {
qbitVarName = varDecl->getDeclName().getAsString();
// std::cout << "Found " << qbitVarName << "\n";
}
} else if (isa<FunctionDecl>(d)) {
auto c = cast<FunctionDecl>(d);
function = std::make_shared<xacc::quantum::GateFunction>(
c->getDeclName().getAsString());
}
return true;
}
bool ScaffoldASTConsumer::VisitStmt(Stmt *s) {
if (isa<IfStmt>(s)) {
auto ifStmt = cast<IfStmt>(s);
clang::LangOptions lo;
clang::PrintingPolicy policy(lo);
std::string ifStr;
llvm::raw_string_ostream ifS(ifStr);
ifStmt->printPretty(ifS, nullptr, policy);
// std::cout << "HELLO IF:\n" << ifS.str() << "\n";
if (const auto binOp = llvm::dyn_cast<BinaryOperator>(
ifStmt->getCond())) {
if (binOp->getOpcode() == BO_EQ) {
// We have an equality check...
auto LHS = binOp->getLHS();
std::string str;
llvm::raw_string_ostream s(str);
LHS->printPretty(s, nullptr, policy);
// std::cout << "LHS IF: " << s.str() << "\n";
if (boost::contains(s.str(), cbitVarName)) {
auto RHS = binOp->getRHS();
std::string rhsstr;
llvm::raw_string_ostream rhss(rhsstr);
RHS->printPretty(rhss, nullptr, policy);
// std::cout << "RHS IF: " << rhss.str() << "\n";
auto thenCode = ifStmt->getThen();
std::string thenStr;
llvm::raw_string_ostream thenS(thenStr);
thenCode->printPretty(thenS, nullptr, policy);
auto then = thenS.str();
// std::cout << "ThenStmt:\n" << then << "\n";
then.erase(std::remove(then.begin(), then.end(), '\t'),
then.end());
boost::replace_all(then, "{\n", "");
boost::replace_all(then, "}\n", "");
boost::replace_all(then, " ", "");
boost::trim(then);
int conditionalQubit = cbitRegToMeasuredQubit[s.str()];
currentConditional = std::make_shared<
xacc::quantum::ConditionalFunction>(
conditionalQubit);
std::vector<std::string> vec;
boost::split(vec, then, boost::is_any_of("\n"));
nCallExprToSkip = vec.size();
// std::cout << "NCALLEXPRTOSKIP = " << nCallExprToSkip << "\n";
}
}
}
}
return true;
}
bool ScaffoldASTConsumer::VisitCallExpr(CallExpr* c) {
clang::LangOptions lo;
clang::PrintingPolicy policy(lo);
auto q = c->getType();
auto t = q.getTypePtrOrNull();
if (t != NULL) {
bool isParameterizedInst = false;
auto fd = c->getDirectCallee();
auto gateName = fd->getNameInfo().getAsString();
std::vector<int> qubits;
std::vector<double> params;
for (auto i = c->arg_begin(); i != c->arg_end(); ++i) {
std::string arg;
llvm::raw_string_ostream argstream(arg);
i->printPretty(argstream, nullptr, policy);
auto argStr = argstream.str();
// std::cout << "Arg: " << argstream.str() << "\n";
if (boost::contains(argStr, qbitVarName)) {
boost::replace_all(argStr, qbitVarName, "");
boost::replace_all(argStr, "[", "");
boost::replace_all(argStr, "]", "");
qubits.push_back(std::stoi(argStr));
} else {
// This is a gate parameter!!!
isParameterizedInst = true;
params.push_back(std::stod(argStr));
}
}
std::shared_ptr<xacc::quantum::GateInstruction> inst;
if (isParameterizedInst) {
if (params.size() == 1) {
inst = xacc::quantum::ParameterizedGateInstructionRegistry<
double>::instance()->create(gateName, qubits,
params[0]);
} else if (params.size() == 2) {
inst = xacc::quantum::ParameterizedGateInstructionRegistry<
double, double>::instance()->create(gateName, qubits,
params[0], params[1]);
} else {
XACCError(
"Can only handle 1 and 2 parameter gates... and only doubles... for now.");
}
// std::cout << "CREATED A " << gateName << " parameterized gate\n";
} else if (gateName != "MeasZ") {
inst = xacc::quantum::GateInstructionRegistry::instance()->create(
gateName, qubits);
}
if (gateName != "MeasZ") {
if (nCallExprToSkip == 0) {
function->addInstruction(inst);
} else {
// std::cout << "Adding Conditional Inst: " << gateName << "\n";
currentConditional->addInstruction(inst);
nCallExprToSkip--;
if (nCallExprToSkip == 0) {
function->addInstruction(currentConditional);
}
}
}
}
return true;
}
bool ScaffoldASTConsumer::VisitBinaryOperator(BinaryOperator * b) {
clang::LangOptions lo;
clang::PrintingPolicy policy(lo);
if (b->isAssignmentOp()) {
auto rhs = b->getRHS();
std::string rhsstr;
llvm::raw_string_ostream rhss(rhsstr);
rhs->printPretty(rhss, nullptr, policy);
auto rhsString = rhss.str();
if (boost::contains(rhsString, "MeasZ")) {
auto lhs = b->getLHS();
std::string lhsstr;
llvm::raw_string_ostream lhss(lhsstr);
lhs->printPretty(lhss, nullptr, policy);
auto lhsString = lhss.str();
// std::cout << "HELLO BINOP LHS: " << lhsString << "\n";
boost::replace_all(lhsString, cbitVarName, "");
boost::replace_all(lhsString, "[", "");
boost::replace_all(lhsString, "]", "");
boost::replace_all(rhsString, "MeasZ", "");
boost::replace_all(rhsString, "(", "");
boost::replace_all(rhsString, ")", "");
boost::replace_all(rhsString, qbitVarName, "");
boost::replace_all(rhsString, "[", "");
boost::replace_all(rhsString, "]", "");
// lhsString now just contains the classical index bit
auto inst =
xacc::quantum::ParameterizedGateInstructionRegistry<int>::instance()->create(
"Measure",
std::vector<int> { std::stoi(rhsString) },
std::stoi(lhsString));
cbitRegToMeasuredQubit.insert(
std::make_pair(lhss.str(), std::stoi(rhsString)));
// std::cout << "ADDING A MEASUREMENT GATE " << lhss.str() << "\n";
function->addInstruction(inst);
}
}
return true;
}
}
......@@ -33,8 +33,8 @@
#include "clang/Frontend/DiagnosticOptions.h"
#include "clang/Frontend/TextDiagnosticPrinter.h"
//#include "GateQIR.hpp"
//
#include "GateInstruction.hpp"
#include "ParameterizedGateInstruction.hpp"
#include "GateFunction.hpp"
#include "ConditionalFunction.hpp"
......@@ -48,12 +48,208 @@ public:
// Override the method that gets called for each parsed top-level
// declaration.
virtual bool HandleTopLevelDecl(DeclGroupRef DR);
virtual bool HandleTopLevelDecl(DeclGroupRef DR) {
for (DeclGroupRef::iterator b = DR.begin(), e = DR.end(); b != e; ++b) {
// Traverse the declaration using our AST visitor.
TraverseDecl(*b);
}
return true;
}
bool VisitStmt(clang::Stmt *s) {
if (isa<IfStmt>(s)) {
auto ifStmt = cast<IfStmt>(s);
clang::LangOptions lo;
clang::PrintingPolicy policy(lo);
std::string ifStr;
llvm::raw_string_ostream ifS(ifStr);
ifStmt->printPretty(ifS, nullptr, policy);
// std::cout << "HELLO IF:\n" << ifS.str() << "\n";
if (const auto binOp = llvm::dyn_cast<BinaryOperator>(
ifStmt->getCond())) {
if (binOp->getOpcode() == BO_EQ) {
// We have an equality check...
auto LHS = binOp->getLHS();
std::string str;
llvm::raw_string_ostream s(str);
LHS->printPretty(s, nullptr, policy);
// std::cout << "LHS IF: " << s.str() << "\n";
if (boost::contains(s.str(), cbitVarName)) {
auto RHS = binOp->getRHS();
std::string rhsstr;
llvm::raw_string_ostream rhss(rhsstr);
RHS->printPretty(rhss, nullptr, policy);
// std::cout << "RHS IF: " << rhss.str() << "\n";
auto thenCode = ifStmt->getThen();
std::string thenStr;
llvm::raw_string_ostream thenS(thenStr);
thenCode->printPretty(thenS, nullptr, policy);
auto then = thenS.str();
// std::cout << "ThenStmt:\n" << then << "\n";
then.erase(std::remove(then.begin(), then.end(), '\t'),
then.end());
boost::replace_all(then, "{\n", "");
boost::replace_all(then, "}\n", "");
boost::replace_all(then, " ", "");
boost::trim(then);
int conditionalQubit = cbitRegToMeasuredQubit[s.str()];
currentConditional = std::make_shared<
xacc::quantum::ConditionalFunction>(
conditionalQubit);
std::vector<std::string> vec;
boost::split(vec, then, boost::is_any_of("\n"));
nCallExprToSkip = vec.size();
// std::cout << "NCALLEXPRTOSKIP = " << nCallExprToSkip << "\n";
}
}
}
}
return true;
}
bool VisitDecl(clang::Decl *d) {
if (isa<VarDecl>(d)) {
auto varDecl = cast<VarDecl>(d);
auto varType = varDecl->getType().getAsString();
if (boost::contains(varType, "cbit")) {
cbitVarName = varDecl->getDeclName().getAsString();
// std::cout << "Found " << cbitVarName << "\n";
} else if (boost::contains(varType, "qbit")) {
qbitVarName = varDecl->getDeclName().getAsString();
// std::cout << "Found " << qbitVarName << "\n";
}
} else if (isa<FunctionDecl>(d)) {
auto c = cast<FunctionDecl>(d);
function =
std::make_shared<xacc::quantum::GateFunction>(
c->getDeclName().getAsString());
}
return true;
}
bool VisitCallExpr(CallExpr * c) {
clang::LangOptions lo;
clang::PrintingPolicy policy(lo);
auto q = c->getType();
auto t = q.getTypePtrOrNull();
if (t != NULL) {
bool isParameterizedInst = false;
auto fd = c->getDirectCallee();
auto gateName = fd->getNameInfo().getAsString();
std::vector<int> qubits;
std::vector<double> params;
for (auto i = c->arg_begin(); i != c->arg_end(); ++i) {
std::string arg;
llvm::raw_string_ostream argstream(arg);
i->printPretty(argstream, nullptr, policy);
auto argStr = argstream.str();
// std::cout << "Arg: " << argstream.str() << "\n";
if (boost::contains(argStr, qbitVarName)) {
boost::replace_all(argStr, qbitVarName, "");
boost::replace_all(argStr, "[", "");
boost::replace_all(argStr, "]", "");
qubits.push_back(std::stoi(argStr));
} else {
// This is a gate parameter!!!
isParameterizedInst = true;
params.push_back(std::stod(argStr));
}
}
std::shared_ptr<xacc::quantum::GateInstruction> inst;
if (isParameterizedInst) {
if (params.size() == 1) {
inst = xacc::quantum::ParameterizedGateInstructionRegistry<
double>::instance()->create(gateName, qubits,
params[0]);
} else if (params.size() == 2) {
inst = xacc::quantum::ParameterizedGateInstructionRegistry<
double, double>::instance()->create(gateName, qubits,
params[0], params[1]);
} else {
XACCError(
"Can only handle 1 and 2 parameter gates... and only doubles... for now.");
}
// std::cout << "CREATED A " << gateName << " parameterized gate\n";
} else if (gateName != "MeasZ") {
inst = xacc::quantum::GateInstructionRegistry::instance()->create(
gateName, qubits);
}
if (gateName != "MeasZ") {
bool VisitStmt(clang::Stmt *s);
bool VisitDecl(clang::Decl *d);
bool VisitCallExpr(CallExpr * c);
bool VisitBinaryOperator(BinaryOperator *b);
if (nCallExprToSkip == 0) {
function->addInstruction(inst);
} else {
// std::cout << "Adding Conditional Inst: " << gateName << "\n";
currentConditional->addInstruction(inst);
nCallExprToSkip--;
if (nCallExprToSkip == 0) {
function->addInstruction(currentConditional);
}
}
}
}
return true;
}
bool VisitBinaryOperator(BinaryOperator *b) {
clang::LangOptions lo;
clang::PrintingPolicy policy(lo);
if (b->isAssignmentOp()) {
auto rhs = b->getRHS();
std::string rhsstr;
llvm::raw_string_ostream rhss(rhsstr);
rhs->printPretty(rhss, nullptr, policy);
auto rhsString = rhss.str();
if (boost::contains(rhsString, "MeasZ")) {
auto lhs = b->getLHS();
std::string lhsstr;
llvm::raw_string_ostream lhss(lhsstr);
lhs->printPretty(lhss, nullptr, policy);
auto lhsString = lhss.str();
// std::cout << "HELLO BINOP LHS: " << lhsString << "\n";
boost::replace_all(lhsString, cbitVarName, "");
boost::replace_all(lhsString, "[", "");
boost::replace_all(lhsString, "]", "");
boost::replace_all(rhsString, "MeasZ", "");
boost::replace_all(rhsString, "(", "");
boost::replace_all(rhsString, ")", "");
boost::replace_all(rhsString, qbitVarName, "");
boost::replace_all(rhsString, "[", "");
boost::replace_all(rhsString, "]", "");
// lhsString now just contains the classical index bit
auto inst =
xacc::quantum::ParameterizedGateInstructionRegistry<int>::instance()->create(
"Measure",
std::vector<int> { std::stoi(rhsString) },
std::stoi(lhsString));
cbitRegToMeasuredQubit.insert(
std::make_pair(lhss.str(), std::stoi(rhsString)));
// std::cout << "ADDING A MEASUREMENT GATE " << lhss.str() << "\n";
function->addInstruction(inst);
}
}
return true;
}
std::shared_ptr<xacc::Function> getFunction() {
return function;
}
......
#include "GateFunction.hpp"
#include "GateInstructionVisitor.hpp"