Commit 01d477e5 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

Starting work on parameterized kernels. Cleaned up json persistence with...

Starting work on parameterized kernels. Cleaned up json persistence with JsonVisitor. Cleaned up Program.hpp
parent 63fc5b5e
......@@ -41,8 +41,29 @@ using namespace clang;
namespace scaffold {
class KernelParameter {
public:
std::string type;
std::string varName;
};
class ScaffoldASTConsumer: public clang::ASTConsumer,
public clang::RecursiveASTVisitor<ScaffoldASTConsumer> {
protected:
std::string cbitVarName;
std::string qbitVarName;
std::shared_ptr<xacc::quantum::GateFunction> function;
std::shared_ptr<xacc::quantum::ConditionalFunction> currentConditional;
int nCallExprToSkip = 0;
std::map<std::string, int> cbitRegToMeasuredQubit;
std::vector<KernelParameter> parameters;
public:
// Override the method that gets called for each parsed top-level
......@@ -63,7 +84,7 @@ public:
std::string ifStr;
llvm::raw_string_ostream ifS(ifStr);
ifStmt->printPretty(ifS, nullptr, policy);
// std::cout << "HELLO IF:\n" << ifS.str() << "\n";
// std::cout << "HELLO IF:\n" << ifS.str() << "\n";
if (const auto binOp = llvm::dyn_cast<BinaryOperator>(
ifStmt->getCond())) {
......@@ -73,21 +94,21 @@ public:
std::string str;
llvm::raw_string_ostream s(str);
LHS->printPretty(s, nullptr, policy);
// std::cout << "LHS IF: " << s.str() << "\n";
// std::cout << "LHS IF: " << s.str() << "\n";
if (boost::contains(s.str(), cbitVarName)) {
auto RHS = binOp->getRHS();
std::string rhsstr;
llvm::raw_string_ostream rhss(rhsstr);
RHS->printPretty(rhss, nullptr, policy);
// std::cout << "RHS IF: " << rhss.str() << "\n";
// std::cout << "RHS IF: " << rhss.str() << "\n";
auto thenCode = ifStmt->getThen();
std::string thenStr;
llvm::raw_string_ostream thenS(thenStr);
thenCode->printPretty(thenS, nullptr, policy);
auto then = thenS.str();
// std::cout << "ThenStmt:\n" << then << "\n";
// std::cout << "ThenStmt:\n" << then << "\n";
then.erase(std::remove(then.begin(), then.end(), '\t'),
then.end());
boost::replace_all(then, "{\n", "");
......@@ -103,7 +124,7 @@ public:
std::vector<std::string> vec;
boost::split(vec, then, boost::is_any_of("\n"));
nCallExprToSkip = vec.size();
// std::cout << "NCALLEXPRTOSKIP = " << nCallExprToSkip << "\n";
// std::cout << "NCALLEXPRTOSKIP = " << nCallExprToSkip << "\n";
}
}
}
......@@ -116,16 +137,29 @@ public:
auto varType = varDecl->getType().getAsString();
if (boost::contains(varType, "cbit")) {
cbitVarName = varDecl->getDeclName().getAsString();
// std::cout << "Found " << cbitVarName << "\n";
// std::cout << "Found " << cbitVarName << "\n";
} else if (boost::contains(varType, "qbit")) {
qbitVarName = varDecl->getDeclName().getAsString();
// std::cout << "Found " << qbitVarName << "\n";
// std::cout << "Found " << qbitVarName << "\n";
}
} else if (isa<FunctionDecl>(d)) {
auto c = cast<FunctionDecl>(d);
function =
std::make_shared<xacc::quantum::GateFunction>(
function = std::make_shared<xacc::quantum::GateFunction>(
c->getDeclName().getAsString());
clang::LangOptions lo;
clang::PrintingPolicy policy(lo);
std::string arg;
llvm::raw_string_ostream argstream(arg);
clang::FunctionDecl::param_iterator pBegin = c->param_begin();
clang::FunctionDecl::param_iterator pEnd = c->param_end();
std::cout << "TRYING FUNC " << c->getDeclName().getAsString() << " params " << c->param_size() << "\n";
for (auto i = pBegin; i != pEnd; ++i) {
KernelParameter p;
p.type = (*i)->getType().getAsString();
p.varName = (*i)->getNameAsString();
std::cout << "HELLO World: " << (*i)->getType().getAsString() << " : " << (*i)->getNameAsString() << "\n";
parameters.push_back(p);
}
}
return true;
}
......@@ -146,7 +180,7 @@ public:
llvm::raw_string_ostream argstream(arg);
i->printPretty(argstream, nullptr, policy);
auto argStr = argstream.str();
// std::cout << "Arg: " << argstream.str() << "\n";
std::cout << "Arg: " << argstream.str() << "\n";
if (boost::contains(argStr, qbitVarName)) {
boost::replace_all(argStr, qbitVarName, "");
......@@ -156,10 +190,22 @@ public:
} else {
// This is a gate parameter!!!
isParameterizedInst = true;
params.push_back(std::stod(argStr));
// This parameter could just be a hard-coded value
// or it could be a reference to a variable parameter...
try {
double d = boost::lexical_cast<double>(argStr);
params.push_back(d);
} catch (const boost::bad_lexical_cast &) {
std::cout << "SETTING 0\n";
params.push_back(0.0);
}
// params.push_back(std::stod(argStr));
}
}
std::cout << "WE HAVE " << qubits.size() << " qubit(s) and " << params.size() << " param(s).\n";
std::shared_ptr<xacc::quantum::GateInstruction> inst;
if (isParameterizedInst) {
if (params.size() == 1) {
......@@ -168,18 +214,18 @@ public:
params[0]);
} else if (params.size() == 2) {
inst = xacc::quantum::ParameterizedGateInstructionRegistry<
double, double>::instance()->create(gateName, qubits,
params[0], params[1]);
double, double>::instance()->create(gateName,
qubits, params[0], params[1]);
} else {
XACCError(
"Can only handle 1 and 2 parameter gates... and only doubles... for now.");
}
// std::cout << "CREATED A " << gateName << " parameterized gate\n";
} else if (gateName != "MeasZ") {
inst = xacc::quantum::GateInstructionRegistry::instance()->create(
gateName, qubits);
inst =
xacc::quantum::GateInstructionRegistry::instance()->create(
gateName, qubits);
}
if (gateName != "MeasZ") {
......@@ -187,7 +233,7 @@ public:
if (nCallExprToSkip == 0) {
function->addInstruction(inst);
} else {
// std::cout << "Adding Conditional Inst: " << gateName << "\n";
// std::cout << "Adding Conditional Inst: " << gateName << "\n";
currentConditional->addInstruction(inst);
nCallExprToSkip--;
......@@ -200,7 +246,7 @@ public:
return true;
}
bool VisitBinaryOperator(BinaryOperator *b) {
bool VisitBinaryOperator(BinaryOperator *b) {
clang::LangOptions lo;
clang::PrintingPolicy policy(lo);
......@@ -218,7 +264,7 @@ public:
llvm::raw_string_ostream lhss(lhsstr);
lhs->printPretty(lhss, nullptr, policy);
auto lhsString = lhss.str();
// std::cout << "HELLO BINOP LHS: " << lhsString << "\n";
// std::cout << "HELLO BINOP LHS: " << lhsString << "\n";
boost::replace_all(lhsString, cbitVarName, "");
boost::replace_all(lhsString, "[", "");
......@@ -232,16 +278,14 @@ public:
boost::replace_all(rhsString, "]", "");
// lhsString now just contains the classical index bit
auto inst =
xacc::quantum::ParameterizedGateInstructionRegistry<int>::instance()->create(
"Measure",
std::vector<int> { std::stoi(rhsString) },
std::stoi(lhsString));
auto inst = xacc::quantum::ParameterizedGateInstructionRegistry<
int>::instance()->create("Measure", std::vector<int> {
std::stoi(rhsString) }, std::stoi(lhsString));
cbitRegToMeasuredQubit.insert(
std::make_pair(lhss.str(), std::stoi(rhsString)));
// std::cout << "ADDING A MEASUREMENT GATE " << lhss.str() << "\n";
// std::cout << "ADDING A MEASUREMENT GATE " << lhss.str() << "\n";
function->addInstruction(inst);
}
......@@ -260,18 +304,6 @@ public:
virtual ~ScaffoldASTConsumer() {
}
protected:
std::string cbitVarName;
std::string qbitVarName;
std::shared_ptr<xacc::quantum::GateFunction> function;
std::shared_ptr<xacc::quantum::ConditionalFunction> currentConditional;
int nCallExprToSkip = 0;
std::map<std::string, int> cbitRegToMeasuredQubit;
};
}
......
......@@ -121,6 +121,25 @@ BOOST_AUTO_TEST_CASE(checkWithMeasurementIf) {
auto k = gateqir->getKernel("teleport");
}
BOOST_AUTO_TEST_CASE(checkWithParameter) {
const std::string src("module gateWithParam (qbit qreg[3], double phi) {\n"
" // Init qubit 0 to 1\n"
" X(qreg[0]);\n"
" // Now teleport...\n"
" H(qreg[1]);\n"
" Rz(qreg[2], phi);\n"
"}\n");
auto qir = compiler->compile(src);
auto gateqir = std::dynamic_pointer_cast<GateQIR>(qir);
auto f = gateqir->getKernel("gateWithParam");
BOOST_VERIFY(f->nInstructions() == 3);
gateqir->persist(std::cout);
}
/*
BOOST_AUTO_TEST_CASE(checkMultipleFunction) {
const std::string src(
......
......@@ -31,10 +31,7 @@
#include "GateQIR.hpp"
#include <boost/algorithm/string.hpp>
#include <regex>
#include "rapidjson/prettywriter.h"
using namespace rapidjson;
#include "JsonVisitor.hpp"
namespace xacc {
namespace quantum {
......@@ -268,12 +265,16 @@ std::string GateQIR::toAssemblyString(const std::string& kernelName, const std::
}
void GateQIR::persist(std::ostream& outStream) {
StringBuffer sb;
PrettyWriter<StringBuffer> writer(sb);
serializeJson(writer);
JsonVisitor visitor(kernels[0]);
outStream << visitor.write();
outStream << sb.GetString();
// StringBuffer sb;
// PrettyWriter<StringBuffer> writer(sb);
//
// serializeJson(writer);
//
// outStream << sb.GetString();
return;
}
......
......@@ -45,8 +45,6 @@
#include "Rz.hpp"
#include "Measure.hpp"
#define RAPIDJSON_HAS_STDSTRING 1
namespace xacc {
namespace quantum {
......@@ -76,111 +74,6 @@ public:
}
};
template<typename Writer>
class JsonSerializerGateVisitor:
public BaseInstructionVisitor,
public InstructionVisitor<GateFunction>,
public InstructionVisitor<Hadamard>,
public InstructionVisitor<CNOT>,
public InstructionVisitor<Rz>,
public InstructionVisitor<ConditionalFunction>,
public InstructionVisitor<X>,
public InstructionVisitor<Z>,
public InstructionVisitor<Measure> {
protected:
Writer& writer;
std::string currentFuncName;
std::string previousFuncName;
std::map<std::string, int> subInstMap;
public:
JsonSerializerGateVisitor(Writer& w) : writer(w) {}
void baseGateInst(GateInstruction& inst, bool endObject = true) {
writer.StartObject();
writer.String("gate");
writer.String(inst.getName().c_str());
writer.String("enabled");
writer.Bool(inst.isEnabled());
writer.String("qubits");
writer.StartArray();
for (auto qi : inst.bits()) {
writer.Int(qi);
}
writer.EndArray();
if (endObject) {
writer.EndObject();
}
subInstMap[currentFuncName]--;
if (subInstMap[currentFuncName] == 0) {
endFunction();
currentFuncName = previousFuncName;
}
}
void visit(Hadamard& h) {
baseGateInst(dynamic_cast<GateInstruction&>(h));
}
void visit(CNOT& cn) {
baseGateInst(dynamic_cast<GateInstruction&>(cn));
}
void visit(Rz& rz) {
baseGateInst(dynamic_cast<GateInstruction&>(rz), false);
writer.String("angle");
writer.Double(rz.getParameter(0));
writer.EndObject();
}
void visit(ConditionalFunction& cn) {
writer.StartObject();
writer.String("conditional_function");
writer.String(cn.getName());
writer.String("conditional_qubit");
writer.Int(cn.getConditionalQubit());
writer.String("instructions");
writer.StartArray();
subInstMap.insert(std::make_pair(cn.getName(), cn.nInstructions()));
previousFuncName = currentFuncName;
currentFuncName = cn.getName();
}
void visit(Measure& cn) {
baseGateInst(dynamic_cast<GateInstruction&>(cn), false);
writer.String("classicalBitIdx");
writer.Int(cn.getParameter(0));
writer.EndObject();
}
void visit(X& cn) {
baseGateInst(dynamic_cast<GateInstruction&>(cn));
}
void visit(Z& cn) {
baseGateInst(dynamic_cast<GateInstruction&>(cn));
}
void visit(GateFunction& function) {
writer.StartObject();
writer.String("function");
writer.String(function.getName());
writer.String("instructions");
writer.StartArray();
subInstMap.insert(std::make_pair(function.getName(), function.nInstructions()));
currentFuncName = function.getName();
}
private:
void endFunction() {
writer.EndArray();
writer.EndObject();
}
};
/**
* The GateQIR is an implementation of the QIR for gate model quantum
* computing. It provides a Graph node type that models a quantum
......@@ -273,27 +166,6 @@ protected:
private:
template<typename Writer>
void serializeJson(Writer& writer) {
std::string retStr = "";
auto visitor = std::make_shared<JsonSerializerGateVisitor<Writer>>(
writer);
writer.StartArray();
for (auto kernel : kernels) {
InstructionIterator it(kernel);
while (it.hasNext()) {
// Get the next node in the tree
auto nextInst = it.next();
nextInst->accept(visitor);
}
writer.EndArray();
writer.EndObject();
}
writer.EndArray();
}
/**
* This method determines if a new layer should be added to the circuit.
*
......
/***********************************************************************************
* Copyright (c) 2017, UT-Battelle
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the xacc nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* Contributors:
* Initial API and implementation - Alex McCaskey
*
**********************************************************************************/
#ifndef QUANTUM_GATE_ALLGATEVISITOR_HPP_
#define QUANTUM_GATE_ALLGATEVISITOR_HPP_
#include "InstructionIterator.hpp"
#include "Hadamard.hpp"
#include "CNOT.hpp"
#include "X.hpp"
#include "Z.hpp"
#include "ConditionalFunction.hpp"
#include "Rz.hpp"
#include "Measure.hpp"
namespace xacc {
namespace quantum {
/**
* FIXME write this
*/
class AllGateVisitor:
public BaseInstructionVisitor,
public InstructionVisitor<GateFunction>,
public InstructionVisitor<Hadamard>,
public InstructionVisitor<CNOT>,
public InstructionVisitor<Rz>,
public InstructionVisitor<ConditionalFunction>,
public InstructionVisitor<X>,
public InstructionVisitor<Z>,
public InstructionVisitor<Measure> {
};
}
}
#endif
......@@ -35,7 +35,7 @@ file (GLOB HEADERS *.hpp)
# Gather tests
file (GLOB test_files tests/*.cpp)
add_tests("${test_files}" "${CMAKE_CURRENT_SOURCE_DIR}" "${Boost_LIBRARIES}")
add_tests("${test_files}" "${CMAKE_CURRENT_SOURCE_DIR}" "${Boost_LIBRARIES};xacc-gate-ir")
install(FILES ${HEADERS} DESTINATION include/quantum/gate)
/***********************************************************************************
* Copyright (c) 2017, UT-Battelle
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the xacc nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* Contributors:
* Initial API and implementation - Alex McCaskey
*
**********************************************************************************/
#ifndef QUANTUM_GATE_JSONVISITOR_HPP_
#define QUANTUM_GATE_JSONVISITOR_HPP_
#include <memory>
#include "AllGateVisitor.hpp"
#define RAPIDJSON_HAS_STDSTRING 1
#include "rapidjson/prettywriter.h"
using namespace rapidjson;
namespace xacc {
namespace quantum {
using Writer = PrettyWriter<StringBuffer>;
/**
* FIXME write this
*/
class JsonVisitor: public AllGateVisitor {
protected:
std::shared_ptr<StringBuffer> buffer;
std::shared_ptr<Writer> writer;
std::shared_ptr<Function> function;
std::shared_ptr<InstructionIterator> topLevelInstructionIterator;
public:
JsonVisitor(std::shared_ptr<xacc::Function> f) : buffer(std::make_shared<StringBuffer>()),
writer(std::make_shared<PrettyWriter<StringBuffer>>(*buffer.get())), function(f) {
}
std::string write() {
// This is a Function, start it as an Object
writer->StartObject();
writer->String("function");
writer->String(function->getName());
// All functions have instructions, start
// that array here.
writer->String("instructions");
writer->StartArray();
topLevelInstructionIterator = std::make_shared<xacc::InstructionIterator>(function);
while (topLevelInstructionIterator->hasNext()) {
// Get the next node in the tree
auto nextInst = topLevelInstructionIterator->next();
nextInst->accept(this);
}
// End Instructions
writer->EndArray();
// End Function
writer->EndObject();
return buffer->GetString();
}
void visit(Hadamard& h) {
baseGateInst(dynamic_cast<GateInstruction&>(h));
}
void visit(CNOT& cn) {
baseGateInst(dynamic_cast<GateInstruction&>(cn));
}
void visit(Rz& rz) {
baseGateInst(dynamic_cast<GateInstruction&>(rz), false);
writer->String("angle");
writer->Double(rz.getParameter(0));
writer->EndObject();
}
void visit(ConditionalFunction& cn) {
writer->StartObject();
writer->String("conditional_function");
writer->Strin