Commit f511bd37 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

Updates for qcor - added Compiler::canParse, updated xasm and staq to support it



Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent f4a9e04a
...@@ -14,6 +14,8 @@ set(LIBRARY_NAME xacc-staq-compiler) ...@@ -14,6 +14,8 @@ set(LIBRARY_NAME xacc-staq-compiler)
file(GLOB SRC file(GLOB SRC
compiler/*.cpp *.cpp transformations/*.cpp utils/*.cpp) compiler/*.cpp *.cpp transformations/*.cpp utils/*.cpp)
message(STATUS "${BoldGreen}Building Staq Compiler.${ColorReset}")
usfunctiongetresourcesource(TARGET ${LIBRARY_NAME} OUT SRC) usfunctiongetresourcesource(TARGET ${LIBRARY_NAME} OUT SRC)
usfunctiongeneratebundleinit(TARGET ${LIBRARY_NAME} OUT SRC) usfunctiongeneratebundleinit(TARGET ${LIBRARY_NAME} OUT SRC)
......
...@@ -27,12 +27,106 @@ ...@@ -27,12 +27,106 @@
#include "output/cirq.hpp" #include "output/cirq.hpp"
#include <streambuf>
using namespace staq::ast; using namespace staq::ast;
namespace xacc { namespace xacc {
StaqCompiler::StaqCompiler() {} StaqCompiler::StaqCompiler() {}
bool StaqCompiler::canParse(const std::string &src) {
std::string _src = src;
bool isXaccKernel = false;
std::string prototype;
auto xasm = xacc::getCompiler("xasm");
/** staq prints to cerr, override this **/
class buffer_override {
std::ostream &os;
std::streambuf *buf;
public:
buffer_override(std::ostream &os) : os(os), buf(os.rdbuf()) {}
~buffer_override() { os.rdbuf(buf); }
};
buffer_override b(std::cerr);
std::stringstream xx;
std::cerr.rdbuf(xx.rdbuf());
/** --------------------------------- **/
if (src.find("__qpu__") != std::string::npos) {
prototype = _src.substr(0, _src.find_first_of("{")) + "{}";
auto bufferNames = xasm->getKernelBufferNames(prototype);
isXaccKernel = true;
std::string tmp = "";
auto first = _src.find_first_of("{");
auto last = _src.find_last_of("}");
auto sub = _src.substr(first + 1, last - first - 1);
auto lines = xacc::split(sub, '\n');
bool addedNames = false;
for (auto &l : lines) {
xacc::trim(l);
if (l.find("measure") != std::string::npos) {
// don't add measures
continue;
}
tmp += l + "\n";
if (l.find("include") != std::string::npos) {
for (auto &b : bufferNames) {
auto size = std::numeric_limits<int>::max();
tmp += "qreg " + b + "[" + std::to_string(size) + "];\n";
}
addedNames = true;
}
}
_src = tmp;
if (!addedNames) {
for (auto &b : bufferNames) {
auto size = std::numeric_limits<int>::max();
_src = "qreg " + b + "[" + std::to_string(size) + "];\n" + _src;
}
}
}
// we allow users to leave out OPENQASM 2.0 and include qelib.inc
// If they did we need to add it for them
std::string tmp = "";
auto lines = xacc::split(_src, '\n');
bool foundOpenQasm = false, foundInclude = false;
for (auto &l : lines) {
if (l.find("OPENQASM") != std::string::npos) {
foundOpenQasm = true;
}
if (l.find("include") != std::string::npos) {
foundInclude = true;
}
}
if (!foundInclude) {
_src = "include \"qelib1.inc\";\n" + _src;
}
if (!foundOpenQasm) {
_src = "OPENQASM 2.0;\n" + _src;
}
// std::cout << " HELLO:\n" << _src << "\n";
using namespace staq;
try {
auto prog = parser::parse_string(_src);
return true;
} catch (std::exception &e) {
return false;
}
}
std::shared_ptr<IR> StaqCompiler::compile(const std::string &src, std::shared_ptr<IR> StaqCompiler::compile(const std::string &src,
std::shared_ptr<Accelerator> acc) { std::shared_ptr<Accelerator> acc) {
// IF src contains typical xacc quantum kernel prototype, then take // IF src contains typical xacc quantum kernel prototype, then take
......
...@@ -26,6 +26,7 @@ public: ...@@ -26,6 +26,7 @@ public:
std::shared_ptr<Accelerator> acc) override; std::shared_ptr<Accelerator> acc) override;
std::shared_ptr<xacc::IR> compile(const std::string &src) override; std::shared_ptr<xacc::IR> compile(const std::string &src) override;
bool canParse(const std::string& src) override;
const std::string const std::string
translate(std::shared_ptr<CompositeInstruction> function) override; translate(std::shared_ptr<CompositeInstruction> function) override;
......
...@@ -26,7 +26,6 @@ TEST(StaqCompilerTester, checkSimple) { ...@@ -26,7 +26,6 @@ TEST(StaqCompilerTester, checkSimple) {
measure q -> c; measure q -> c;
)"); )");
auto hello = IR->getComposites()[0]; auto hello = IR->getComposites()[0];
std::cout << "HELLO:\n" << hello->toString() << "\n"; std::cout << "HELLO:\n" << hello->toString() << "\n";
...@@ -48,6 +47,40 @@ TEST(StaqCompilerTester, checkSimple) { ...@@ -48,6 +47,40 @@ TEST(StaqCompilerTester, checkSimple) {
std::cout << "HELLO:\n" << hello->toString() << "\n"; std::cout << "HELLO:\n" << hello->toString() << "\n";
} }
TEST(StaqCompilerTester, checkCanParse) {
auto a = xacc::qalloc(2);
a->setName("q");
xacc::storeBuffer(a);
auto compiler = xacc::getCompiler("staq");
EXPECT_FALSE(compiler->canParse(
R"(__qpu__ void bell_test_can_parse(qbit q, double t0) {
H(q[0]);
CX(q[0], q[1]);
Ry(q[0], t0);
Measure(q[0]);
Measure(q[1]);
})"));
EXPECT_FALSE(compiler->canParse(
R"(__qpu__ void bell_test_can_parse(qbit q, double t0) {
H(q[0]);
CX(q[0], [1]);
Ry(q[0], t0);
Measure(q[0]);
Measure(q[1]);
})"));
EXPECT_TRUE(compiler->canParse(R"(
qreg q[2];
creg c[2];
U(0,0,0) q[0];
CX q[0],q[1];
rx(3.3) q[0];
measure q -> c;
)"));
}
TEST(StaqCompilerTester, checkTranslate) { TEST(StaqCompilerTester, checkTranslate) {
auto compiler = xacc::getCompiler("staq"); auto compiler = xacc::getCompiler("staq");
...@@ -61,27 +94,26 @@ TEST(StaqCompilerTester, checkTranslate) { ...@@ -61,27 +94,26 @@ TEST(StaqCompilerTester, checkTranslate) {
measure q -> c; measure q -> c;
)"); )");
auto hello = IR->getComposites()[0]; auto hello = IR->getComposites()[0];
std::cout << "HELLO:\n" << hello->toString() << "\n"; std::cout << "HELLO:\n" << hello->toString() << "\n";
std::cout << "TRANSLATED: " << compiler->translate(hello) << "\n"; std::cout << "TRANSLATED: " << compiler->translate(hello) << "\n";
// auto q = xacc::qalloc(2); // auto q = xacc::qalloc(2);
// q->setName("q"); // q->setName("q");
// xacc::storeBuffer(q); // xacc::storeBuffer(q);
// IR = compiler->compile(R"(__qpu__ void f(qreg q) { // IR = compiler->compile(R"(__qpu__ void f(qreg q) {
// OPENQASM 2.0; // OPENQASM 2.0;
// include "qelib1.inc"; // include "qelib1.inc";
// creg c[2]; // creg c[2];
// U(0,0,0) q[0]; // U(0,0,0) q[0];
// CX q[0],q[1]; // CX q[0],q[1];
// rx(3.3) q[0]; // rx(3.3) q[0];
// measure q -> c; // measure q -> c;
// })"); // })");
// hello = IR->getComposites()[0]; // hello = IR->getComposites()[0];
// std::cout << "HELLO:\n" << hello->toString() << "\n"; // std::cout << "HELLO:\n" << hello->toString() << "\n";
} }
TEST(StaqCompilerTester, checkOracle) { TEST(StaqCompilerTester, checkOracle) {
...@@ -144,13 +176,12 @@ adder a[0],a[1],a[2],a[3],b[0],b[1],b[2],b[3],c[0],c[1],c[2],c[3]; ...@@ -144,13 +176,12 @@ adder a[0],a[1],a[2],a[3],b[0],b[1],b[2],b[3],c[0],c[1],c[2],c[3];
// measure // measure
measure c -> result; measure c -> result;
)"; )";
// IR = compiler->compile(src2); // IR = compiler->compile(src2);
// hello = IR->getComposites()[0]; // hello = IR->getComposites()[0];
// std::cout << hello->toString() << "\n"; // std::cout << hello->toString() << "\n";
} }
TEST(StaqCompilerTester, checkCirq) { TEST(StaqCompilerTester, checkCirq) {
auto compiler = xacc::getCompiler("staq"); auto compiler = xacc::getCompiler("staq");
auto IR = compiler->compile(R"( auto IR = compiler->compile(R"(
......
...@@ -32,6 +32,37 @@ TEST(XASMCompilerTester, checkTranslate) { ...@@ -32,6 +32,37 @@ TEST(XASMCompilerTester, checkTranslate) {
std::cout << "HELLO:\n" << translated << "\n"; std::cout << "HELLO:\n" << translated << "\n";
} }
TEST(XASMCompilerTester, checkCanParse) {
auto compiler = xacc::getCompiler("xasm");
EXPECT_TRUE(compiler->canParse(
R"(__qpu__ void bell_test_can_parse(qbit q, double t0) {
H(q[0]);
CX(q[0], q[1]);
Ry(q[0], t0);
Measure(q[0]);
Measure(q[1]);
})"));
EXPECT_FALSE(compiler->canParse(
R"(__qpu__ void bell_test_can_parse(qbit q, double t0) {
H(q[0]);
CX(q[0], [1]);
Ry(q[0], t0);
Measure(q[0]);
Measure(q[1]);
})"));
EXPECT_FALSE(compiler->canParse(R"(
qreg q[2];
creg c[2];
U(0,0,0) q[0];
CX q[0],q[1];
rx(3.3) q[0];
measure q -> c;
)"));
}
TEST(XASMCompilerTester, checkSimple) { TEST(XASMCompilerTester, checkSimple) {
auto compiler = xacc::getCompiler("xasm"); auto compiler = xacc::getCompiler("xasm");
...@@ -195,14 +226,12 @@ TEST(XASMCompilerTester, checkIfStmt) { ...@@ -195,14 +226,12 @@ TEST(XASMCompilerTester, checkIfStmt) {
q->reset_single_measurements(); q->reset_single_measurements();
IR->getComposites()[0]->getInstruction(2)->disable(); IR->getComposites()[0]->getInstruction(2)->disable();
std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n"; std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n";
q->measure(0, 1); q->measure(0, 1);
IR->getComposites()[0]->expand({}); IR->getComposites()[0]->expand({});
std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n"; std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n";
} }
TEST(XASMCompilerTester, checkApplyAll) { TEST(XASMCompilerTester, checkApplyAll) {
...@@ -245,20 +274,18 @@ TEST(XASMCompilerTester, checkGateOnAll) { ...@@ -245,20 +274,18 @@ TEST(XASMCompilerTester, checkGateOnAll) {
xacc::storeBuffer(q); xacc::storeBuffer(q);
auto compiler = xacc::getCompiler("xasm"); auto compiler = xacc::getCompiler("xasm");
auto IR = auto IR = compiler->compile(R"(__qpu__ void on_all(qbit qqq) {
compiler->compile(R"(__qpu__ void on_all(qbit qqq) {
H(qqq); H(qqq);
Measure(qqq); Measure(qqq);
})"); })");
auto f = IR->getComposite("on_all"); auto f = IR->getComposite("on_all");
std::cout << "F:\n" << f->toString() << "\n";
std::cout << "F:\n" << f->toString() << "\n";
} }
TEST(XASMCompilerTester, checkBugBug) { TEST(XASMCompilerTester, checkBugBug) {
auto a = xacc::qalloc(4); auto a = xacc::qalloc(4);
a->setName("a"); a->setName("a");
xacc::storeBuffer(a); xacc::storeBuffer(a);
...@@ -270,12 +297,12 @@ auto a = xacc::qalloc(4); ...@@ -270,12 +297,12 @@ auto a = xacc::qalloc(4);
c->setName("c"); c->setName("c");
xacc::storeBuffer(c); xacc::storeBuffer(c);
auto anc = xacc::qalloc(4); auto anc = xacc::qalloc(4);
anc->setName("anc"); anc->setName("anc");
xacc::storeBuffer(anc); xacc::storeBuffer(anc);
auto compiler = xacc::getCompiler("xasm"); auto compiler = xacc::getCompiler("xasm");
auto IR = auto IR = compiler->compile(
compiler->compile(R"(__qpu__ void bugbug(qbit a, qbit b, qbit c, qbit anc) { R"(__qpu__ void bugbug(qbit a, qbit b, qbit c, qbit anc) {
X(a[0]); X(a[0]);
X(a[1]); X(a[1]);
X(b[0]); X(b[0]);
...@@ -287,15 +314,15 @@ CX(b[0], anc[0]); ...@@ -287,15 +314,15 @@ CX(b[0], anc[0]);
Tdg(anc[0]); Tdg(anc[0]);
})"); })");
std::cout << IR->getComposites()[0]->toString() << "\n"; std::cout << IR->getComposites()[0]->toString() << "\n";
std::cout << IR->getComposites()[0]->getInstruction(4)->toString() << "\n"; std::cout << IR->getComposites()[0]->getInstruction(4)->toString() << "\n";
for (auto& b :IR->getComposites()[0]->getInstruction(4)->getBufferNames()) {std::cout << b << "\n";} for (auto &b : IR->getComposites()[0]->getInstruction(4)->getBufferNames()) {
std::cout << b << "\n";
}
} }
TEST(XASMCompilerTester, checkCallingPreviousKernel) { TEST(XASMCompilerTester, checkCallingPreviousKernel) {
auto compiler = xacc::getCompiler("xasm"); auto compiler = xacc::getCompiler("xasm");
auto IR = auto IR = compiler->compile(R"(__qpu__ void bell_call(qbit q) {
compiler->compile(R"(__qpu__ void bell_call(qbit q) {
H(q[0]); H(q[0]);
CX(q[0], q[1]); CX(q[0], q[1]);
Measure(q[0]); Measure(q[0]);
......
...@@ -29,6 +29,32 @@ namespace xacc { ...@@ -29,6 +29,32 @@ namespace xacc {
XASMCompiler::XASMCompiler() = default; XASMCompiler::XASMCompiler() = default;
bool XASMCompiler::canParse(const std::string &src) {
class XASMThrowExceptionErrorListener : public BaseErrorListener {
public:
void syntaxError(Recognizer *recognizer, Token *offendingSymbol,
size_t line, size_t charPositionInLine,
const std::string &msg, std::exception_ptr e) override {
throw std::runtime_error("Cannot parse this XASM source string.");
}
};
ANTLRInputStream input(src);
xasmLexer lexer(&input);
CommonTokenStream tokens(&lexer);
xasmParser parser(&tokens);
parser.removeErrorListeners();
parser.addErrorListener(new XASMThrowExceptionErrorListener());
try {
tree::ParseTree *tree = parser.xaccsrc();
return true;
} catch (std::exception &e) {
return false;
}
}
std::shared_ptr<IR> XASMCompiler::compile(const std::string &src, std::shared_ptr<IR> XASMCompiler::compile(const std::string &src,
std::shared_ptr<Accelerator> acc) { std::shared_ptr<Accelerator> acc) {
ANTLRInputStream input(src); ANTLRInputStream input(src);
...@@ -60,7 +86,8 @@ std::shared_ptr<IR> XASMCompiler::compile(const std::string &src) { ...@@ -60,7 +86,8 @@ std::shared_ptr<IR> XASMCompiler::compile(const std::string &src) {
return compile(src, nullptr); return compile(src, nullptr);
} }
std::vector<std::string> XASMCompiler::getKernelBufferNames(const std::string& src) { std::vector<std::string>
XASMCompiler::getKernelBufferNames(const std::string &src) {
ANTLRInputStream input(src); ANTLRInputStream input(src);
xasmLexer lexer(&input); xasmLexer lexer(&input);
CommonTokenStream tokens(&lexer); CommonTokenStream tokens(&lexer);
......
...@@ -26,7 +26,8 @@ public: ...@@ -26,7 +26,8 @@ public:
std::shared_ptr<Accelerator> acc) override; std::shared_ptr<Accelerator> acc) override;
std::shared_ptr<xacc::IR> compile(const std::string &src) override; std::shared_ptr<xacc::IR> compile(const std::string &src) override;
bool canParse(const std::string& src) override;
const std::string translate(std::shared_ptr<CompositeInstruction> function) override; const std::string translate(std::shared_ptr<CompositeInstruction> function) override;
std::vector<std::string> getKernelBufferNames(const std::string& src) override; std::vector<std::string> getKernelBufferNames(const std::string& src) override;
......
...@@ -27,6 +27,13 @@ public: ...@@ -27,6 +27,13 @@ public:
virtual std::shared_ptr<IR> compile(const std::string &src, virtual std::shared_ptr<IR> compile(const std::string &src,
std::shared_ptr<Accelerator> acc) = 0; std::shared_ptr<Accelerator> acc) = 0;
virtual std::shared_ptr<IR> compile(const std::string &src) = 0; virtual std::shared_ptr<IR> compile(const std::string &src) = 0;
// By default, we assume this compiler can not parse a given
// source string. Subtypes implement this to indicate if
// they can or not
virtual bool canParse(const std::string& src) {
return false;
}
virtual const std::string virtual const std::string
translate(std::shared_ptr<CompositeInstruction> program) = 0; translate(std::shared_ptr<CompositeInstruction> program) = 0;
......
...@@ -16,6 +16,7 @@ std::map<std::string, int> qreg::counts() { ...@@ -16,6 +16,7 @@ std::map<std::string, int> qreg::counts() {
double qreg::exp_val_z() { return buffer->getExpectationValueZ(); } double qreg::exp_val_z() { return buffer->getExpectationValueZ(); }
void qreg::reset() { buffer->resetBuffer(); } void qreg::reset() { buffer->resetBuffer(); }
void qreg::setName(const char *name) { buffer->setName(name); } void qreg::setName(const char *name) { buffer->setName(name); }
void qreg::setNameAndStore(const char *name) { setName(name); store(); }
void qreg::store() { void qreg::store() {
auto buffer_as_shared = std::shared_ptr<AcceleratorBuffer>( auto buffer_as_shared = std::shared_ptr<AcceleratorBuffer>(
buffer, empty_delete<AcceleratorBuffer>()); buffer, empty_delete<AcceleratorBuffer>());
......
...@@ -32,6 +32,7 @@ public: ...@@ -32,6 +32,7 @@ public:
double exp_val_z(); double exp_val_z();
void reset(); void reset();
void setName(const char *name); void setName(const char *name);
void setNameAndStore(const char *name);
void store(); void store();
}; };
...@@ -42,4 +43,6 @@ xacc::internal_compiler::qreg qalloc(const int n) { ...@@ -42,4 +43,6 @@ xacc::internal_compiler::qreg qalloc(const int n) {
return xacc::internal_compiler::qreg(n); return xacc::internal_compiler::qreg(n);
} }
#define __qpu__ [[clang::syntax(qcor)]]
#endif #endif
\ No newline at end of file
...@@ -6,19 +6,24 @@ ...@@ -6,19 +6,24 @@
namespace xacc { namespace xacc {
namespace internal_compiler { namespace internal_compiler {
Accelerator *qpu = nullptr; Accelerator *qpu = nullptr;
CompositeInstruction* lastCompiled = nullptr; CompositeInstruction *lastCompiled = nullptr;
bool __execute = true; bool __execute = true;
void __set_verbose(bool v) { void __set_verbose(bool v) { xacc::set_verbose(v); }
xacc::set_verbose(v);
}
void compiler_InitializeXACC(const char *qpu_backend) { void compiler_InitializeXACC(const char *qpu_backend) {
if (!xacc::isInitialized()) if (!xacc::isInitialized())
xacc::Initialize(); xacc::Initialize();
xacc::external::load_external_language_plugins(); xacc::external::load_external_language_plugins();
setAccelerator(qpu_backend); setAccelerator(qpu_backend);
}
void compiler_InitializeXACC(const char *qpu_backend, const int shots) {
if (!xacc::isInitialized())
xacc::Initialize();
xacc::external::load_external_language_plugins();
setAccelerator(qpu_backend, shots);
} }
void setAccelerator(const char *qpu_backend) { void setAccelerator(const char *qpu_backend) {
...@@ -31,18 +36,21 @@ void setAccelerator(const char *qpu_backend) { ...@@ -31,18 +36,21 @@ void setAccelerator(const char *qpu_backend) {
}