Commit b8db73c6 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Speed-up OpenQASM compilation for long circuits



- Add a visitor to translate Staq's AST to XACC's IR directly.

- Add ability to add a vector of instructions to Composite without validation. These instructions are translated from a valid AST.

Signed-off-by: Nguyen, Thien Minh's avatarThien Nguyen <nguyentm@ornl.gov>
parent 6c250e44
......@@ -224,6 +224,21 @@ public:
for (auto &i : insts)
addInstruction(i);
}
void addInstructions(const std::vector<InstPtr> &&insts,
bool shouldValidate = true) override {
if (shouldValidate) {
for (auto &i : insts) {
addInstruction(i);
}
} else {
// Bypass instruction validation, append all the instructions directly.
instructions.insert(instructions.end(),
std::make_move_iterator(insts.begin()),
std::make_move_iterator(insts.end()));
}
}
void clear() override { instructions.clear(); }
bool hasChildren() const override { return !instructions.empty(); }
......
......@@ -226,9 +226,37 @@ std::shared_ptr<IR> StaqCompiler::compile(const std::string &src,
*prog, {false, transformations::default_overrides, "anc"});
// std::cout <<"PROG: " << *prog << "\n";
// Visit Program to find out how many qreg there are and
// use that to build up openqasm xacc function prototype
// Determine the number of qreqs
internal_staq::CountQregs countQreq;
dynamic_cast<ast::Traverse &>(countQreq).visit(*prog);
const auto nbQreqs = countQreq.qregs.size() + ancillas.ancillas.size();
// Direct Staq's AST -> XACC's IR translation:
// This can only be used (reliably) for simple QASM source,
// that uses a single qreg (we can use simple qubit indexing to construct IR)
// Note: we don't handle *embedded* QASM source in this direct translate mode.
if (!isXaccKernel && nbQreqs == 1) {
// Create a temporary kernel name:
std::string name = "tmp";
if (xacc::hasCompiled(name)) {
int counter = 0;
while (true) {
name = "tmp" + std::to_string(counter);
if (!xacc::hasCompiled(name)) {
break;
}
counter++;
}
}
// Direct translation
internal_staq::StaqToIr translate(name);
translate.visit(*prog);
return translate.getIr();
}
// Otherwise, translate Staq's AST to XASM source string
// then recompile.
internal_staq::StaqToXasm translate;
translate.visit(*prog);
......
......@@ -15,6 +15,11 @@
#include "xacc.hpp"
#include "xacc_service.hpp"
namespace {
// CU1 decompose: 5 instructions
constexpr int CU1_DECOMP_N_INSTS = 5;
}
TEST(StaqCompilerTester, checkCphase) {
auto src = R"#(
......@@ -34,9 +39,10 @@ measure q[2] -> c[2];)#";
auto hello = IR->getComposites()[0];
std::cout << "HELLO:\n" << hello->toString() << "\n";
const int expectedNinsts = 7 + CU1_DECOMP_N_INSTS;
EXPECT_EQ(hello->nInstructions(), expectedNinsts);
}
TEST(StaqCompilerTester, checkSimple) {
auto compiler = xacc::getCompiler("staq");
auto IR = compiler->compile(R"(
......@@ -50,7 +56,8 @@ TEST(StaqCompilerTester, checkSimple) {
auto hello = IR->getComposites()[0];
std::cout << "HELLO:\n" << hello->toString() << "\n";
EXPECT_EQ(hello->nInstructions(), 5);
auto q = xacc::qalloc(2);
q->setName("q");
xacc::storeBuffer(q);
......@@ -67,6 +74,7 @@ TEST(StaqCompilerTester, checkSimple) {
hello = IR->getComposites()[0];
std::cout << "HELLO:\n" << hello->toString() << "\n";
EXPECT_EQ(hello->nInstructions(), 5);
}
TEST(StaqCompilerTester, checkCanParse) {
......
......@@ -18,7 +18,8 @@
#include "ast/traversal.hpp"
#include <map>
#include <iomanip>
#include "xacc.hpp"
#include "xacc_service.hpp"
#include "AllGateVisitor.hpp"
using namespace staq::ast;
......@@ -110,6 +111,94 @@ public:
}
};
// Staq AST to XACC IR
class StaqToIr : public staq::ast::Visitor {
public:
StaqToIr(const std::string &in_kernelName)
: m_kernelName(in_kernelName), m_provider(xacc::getIRProvider("quantum")) {}
void visit(VarAccess &) override {}
// Expressions
void visit(BExpr &) override {}
void visit(UExpr &) override {}
void visit(PiExpr &) override {}
void visit(IntExpr &) override {}
void visit(RealExpr &r) override {}
void visit(VarExpr &v) override {}
void visit(ResetStmt &) override {}
void visit(IfStmt &) override {}
void visit(BarrierGate &) override {}
void visit(GateDecl &) override {}
void visit(OracleDecl &) override {}
void visit(RegisterDecl &) override {}
void visit(AncillaDecl &) override {}
void visit(Program &prog) override {
// Program body
m_runtimeInsts.clear();
m_runtimeInsts.reserve(prog.body().size());
prog.foreach_stmt([this](auto &stmt) { stmt.accept(*this); });
}
void visit(MeasureStmt &m) override {
m_runtimeInsts.emplace_back(
std::make_shared<xacc::quantum::Measure>(m.q_arg().offset().value()));
}
void visit(UGate &u) override {
m_runtimeInsts.emplace_back(std::make_shared<xacc::quantum::U>(
u.arg().offset().value(), u.theta().constant_eval().value(),
u.phi().constant_eval().value(), u.lambda().constant_eval().value()));
}
void visit(CNOTGate &cx) override {
m_runtimeInsts.emplace_back(std::make_shared<xacc::quantum::CNOT>(
cx.ctrl().offset().value(), cx.tgt().offset().value()));
}
void visit(DeclaredGate &g) override {
auto xacc_name = staq_to_xacc.at(g.name());
// Handle common gates:
if (xacc_name == "Rx") {
m_runtimeInsts.emplace_back(std::make_shared<xacc::quantum::Rx>(
g.qarg(0).offset().value(), g.carg(0).constant_eval().value()));
} else if (xacc_name == "Ry") {
m_runtimeInsts.emplace_back(std::make_shared<xacc::quantum::Ry>(
g.qarg(0).offset().value(), g.carg(0).constant_eval().value()));
} else if (xacc_name == "Rz") {
m_runtimeInsts.emplace_back(std::make_shared<xacc::quantum::Rz>(
g.qarg(0).offset().value(), g.carg(0).constant_eval().value()));
} else {
// Otherwise, just do generic construction
std::vector<std::size_t> gate_bits;
std::vector<InstructionParameter> gate_params;
for (int i = 0; i < g.num_qargs(); i++) {
gate_bits.emplace_back(g.qarg(i).offset().value());
}
for (int i = 0; i < g.num_cargs(); i++) {
gate_params.emplace_back(g.carg(i).constant_eval().value());
}
m_runtimeInsts.emplace_back(
m_provider->createInstruction(xacc_name, gate_bits, gate_params));
}
}
std::shared_ptr<IR> getIr() {
auto composite =
xacc::getService<IRProvider>("quantum")->createComposite(m_kernelName);
// Since the instructions were *compiled* by staq (valid AST),
// hence, we skip all validation.
composite->addInstructions(std::move(m_runtimeInsts), false);
auto ir = xacc::getService<IRProvider>("quantum")->createIR();
ir->addComposite(composite);
return ir;
}
private:
std::vector<InstPtr> m_runtimeInsts;
std::shared_ptr<IRProvider> m_provider;
std::string m_kernelName;
};
using namespace xacc::quantum;
class XACCToStaqOpenQasm : public AllGateVisitor {
......
......@@ -176,7 +176,7 @@ public:
virtual void addInstruction(InstPtr instruction) = 0;
virtual void addInstructions(std::vector<InstPtr> &instruction) = 0;
virtual void addInstructions(const std::vector<InstPtr> &instruction) = 0;
virtual void addInstructions(const std::vector<InstPtr> &&insts) {
virtual void addInstructions(const std::vector<InstPtr> &&insts, bool shouldValidate = true) {
addInstructions(insts);
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment