Commit 1cdad83f authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

Updating d-wave graph generation, adding analyzeResults to IRGenerator



Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 51733000
......@@ -11,7 +11,9 @@ if (NOT XACC_ROOT)
get_filename_component(XACC_ROOT "${CMAKE_CURRENT_LIST_FILE}" PATH)
endif()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${XACC_ROOT}/share/xacc)
include(format)
if (NOT TARGET format)
include(format)
endif()
set (XACC_LIBRARY_DIR "${XACC_ROOT}/lib")
link_directories("${XACC_ROOT}/lib")
set(XACC_INCLUDE_ROOT "${XACC_ROOT}/include")
......
......@@ -184,7 +184,8 @@ PYBIND11_MODULE(_pyxacc, m) {
(std::shared_ptr<xacc::Function>(xacc::IRGenerator::*)(
std::map<std::string, xacc::InstructionParameter>)) &
xacc::IRGenerator::generate,
py::return_value_policy::reference, "");
py::return_value_policy::reference, "")
.def("analyzeResults", &xacc::IRGenerator::analyzeResults, "");
// Expose the Kernel
py::class_<xacc::Kernel<>, std::shared_ptr<xacc::Kernel<>>>(
......@@ -388,6 +389,7 @@ PYBIND11_MODULE(_pyxacc, m) {
xacc::getService<IRGenerator>,
py::return_value_policy::reference,
"Return the IRGenerator of given name.");
m.def("translate", &xacc::translate, "Translate the provided IR Function to the given language.");
m.def("setOption", [](const std::string s, InstructionParameter p) {
xacc::setOption(s, boost::lexical_cast<std::string>(p));
});
......
......@@ -18,6 +18,8 @@
#include "XACC.hpp"
#include "exprtk.hpp"
#include <boost/math/constants/constants.hpp>
#include "GraphProvider.hpp"
#include "DWGraph.hpp"
static constexpr double pi = boost::math::constants::pi<double>();
......@@ -28,11 +30,8 @@ using parser_t = exprtk::parser<double>;
namespace xacc {
namespace quantum {
/**
* The DWKernel is an XACC Function that contains
* DWQMI Instructions.
*/
class DWKernel : public virtual Function,
class DWKernel : public Function,
public GraphProvider<DWVertex>,
public std::enable_shared_from_this<DWKernel> {
protected:
......@@ -67,9 +66,9 @@ public:
return newF;
}
virtual const int nInstructions() { return instructions.size(); }
const int nInstructions() override { return instructions.size(); }
virtual InstPtr getInstruction(const int idx) {
InstPtr getInstruction(const int idx) override {
InstPtr i;
if (instructions.size() > idx) {
i = *std::next(instructions.begin(), idx);
......@@ -80,15 +79,17 @@ public:
return i;
}
virtual std::list<InstPtr> getInstructions() { return instructions; }
std::list<InstPtr> getInstructions() override { return instructions; }
virtual void removeInstruction(const int idx) {
void removeInstruction(const int idx) override {
instructions.remove(getInstruction(idx));
}
virtual const std::string getTag() { return ""; }
const std::string getTag() override { return ""; }
virtual void mapBits(std::vector<int> bitMap) {}
void mapBits(std::vector<int> bitMap) override {
xacc::error("DWKernel.mapBits not implemented");
}
/**
* Add an instruction to this quantum
......@@ -96,28 +97,52 @@ public:
*
* @param instruction
*/
virtual void addInstruction(InstPtr instruction) {
void addInstruction(InstPtr instruction) override {
instructions.push_back(instruction);
}
const int depth() override {
xacc::error("DWKernel.depth() not implemented.");
const int depth() override { xacc::error("DWKernel graph is undirected, cannot compute depth."); }
const std::string persistGraph() override {
std::stringstream s;
toGraph().write(s);
return s.str();
}
const std::string persistGraph() override { return ""; }
/**
* Replace the given current quantum instruction
* with the new replacingInst quantum Instruction.
*
* @param currentInst
* @param replacingInst
*/
virtual void replaceInstruction(const int idx, InstPtr replacingInst) {
Graph<DWVertex> toGraph() override {
int maxBit = 0;
for (int i = 0; i < nInstructions(); ++i) {
auto inst = getInstruction(i);
auto bits = inst->bits();
if (bits[0] > maxBit) {
maxBit = bits[0];
}
if (bits[1] > maxBit) {
maxBit = bits[1];
}
}
DWGraph graph(maxBit+1);
for (int i = 0; i < nInstructions(); ++i) {
auto inst = getInstruction(i);
auto bits = inst->bits();
if (bits[0] == bits[1]) {
std::get<0>(graph.getVertex(bits[0]).properties) = boost::get<double>(inst->getParameter(0));
} else {
graph.addEdge(bits[0], bits[1], boost::get<double>(inst->getParameter(0)));
}
}
return graph;
}
void replaceInstruction(const int idx, InstPtr replacingInst) override {
std::replace(instructions.begin(), instructions.end(), getInstruction(idx),
replacingInst);
}
virtual void insertInstruction(const int idx, InstPtr newInst) {
void insertInstruction(const int idx, InstPtr newInst) override {
auto iter = std::next(instructions.begin(), idx);
instructions.insert(iter, newInst);
}
......@@ -126,24 +151,16 @@ public:
* Return the name of this function
* @return
*/
virtual const std::string name() const { return _name; }
const std::string name() const override { return _name; }
const std::string description() const override { return ""; }
const std::vector<int> bits() override { return std::vector<int>{}; }
virtual const std::string description() const { return ""; }
/**
* Return the qubits this function acts on.
* @return
*/
virtual const std::vector<int> bits() { return std::vector<int>{}; }
/**
* Return an assembly-like string representation for this function .
* @param bufferVarName
* @return
*/
virtual const std::string toString(const std::string &bufferVarName) {
const std::string toString(const std::string &bufferVarName) override {
std::stringstream ss;
for (auto i : instructions) {
ss << i->toString("") << "\n";
ss << i->toString("") << ";\n";
}
return ss.str();
}
......@@ -170,11 +187,11 @@ public:
return weights;
}
virtual InstructionParameter getParameter(const int idx) const {
InstructionParameter getParameter(const int idx) const override {
return parameters[idx];
}
virtual void setParameter(const int idx, InstructionParameter &p) {
void setParameter(const int idx, InstructionParameter &p) override {
if (idx + 1 > parameters.size()) {
XACCLogger::instance()->error(
"DWKernel.setParameter: Invalid Parameter requested.");
......@@ -183,19 +200,19 @@ public:
parameters[idx] = p;
}
virtual std::vector<InstructionParameter> getParameters() {
std::vector<InstructionParameter> getParameters() override {
return parameters;
}
virtual void addParameter(InstructionParameter instParam) {
void addParameter(InstructionParameter instParam) override {
parameters.push_back(instParam);
}
virtual bool isParameterized() { return nParameters() > 0; }
bool isParameterized() override { return nParameters() > 0; }
virtual const int nParameters() { return parameters.size(); }
const int nParameters() override { return parameters.size(); }
virtual std::shared_ptr<Function> operator()(const Eigen::VectorXd &params) {
std::shared_ptr<Function> operator()(const Eigen::VectorXd &params) override {
if (params.size() != nParameters()) {
xacc::error("Invalid DWKernel evaluation: number "
"of parameters don't match. " +
......
......@@ -36,6 +36,23 @@ TEST(DWKernelTester, checkDWKernelConstruction) {
EXPECT_TRUE(kernel.toString("") == expected);
}
TEST(DWKernelTester, checkGraph) {
auto qmi = std::make_shared<DWQMI>(0, 1, 2.2);
auto qmi2 = std::make_shared<DWQMI>(0,1.2);
auto qmi3 = std::make_shared<DWQMI>(1, 3.3);
DWKernel kernel("foo");
kernel.addInstruction(qmi);
kernel.addInstruction(qmi2);
kernel.addInstruction(qmi3);
auto graph = kernel.toGraph();
graph.write(std::cout);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
......
......@@ -290,20 +290,9 @@ getIRTransformations(const std::string &name) {
return t;
}
const std::string translate(const std::string &original,
const std::string &originalLanguageName,
const std::string &newLanguageName,
const std::string bufferName) {
auto originalCompiler = getCompiler(originalLanguageName);
auto newCompiler = getCompiler(newLanguageName);
auto ir = originalCompiler->compile(original);
std::string newSrc = "";
for (auto k : ir->getKernels()) {
newSrc += newCompiler->translate(bufferName, k) + "\n";
}
return newSrc;
const std::string translate(std::shared_ptr<Function> function, const std::string toLanguage);{
auto toLanguageCompiler = getCompiler(toLanguage);
return toLanguageCompiler->translate(nullptr, function);
}
const std::string translateWithVisitor(const std::string &originalSource,
......
......@@ -282,10 +282,7 @@ std::shared_ptr<Function> optimizeFunction(const std::string optimizer,
std::shared_ptr<IRTransformation> getIRTransformation(const std::string &name);
const std::string translate(const std::string &original,
const std::string &originalLanguageName,
const std::string &newLanguageName,
const std::string bufferName);
const std::string translate(std::shared_ptr<Function> function, const std::string toLanguage);
const std::string translateWithVisitor(const std::string &originalSource,
const std::string &originalLanguage,
......
......@@ -60,6 +60,10 @@ public:
return generate(nullptr, temp);
}
virtual std::vector<InstructionParameter> analyzeResults(std::shared_ptr<AcceleratorBuffer> buffer) {
return std::vector<InstructionParameter>{};
}
/**
* The destructor
*/
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment