Commit ae1a6c2b authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Handle multiple kernels more robustly



- Make sure that the callable is a quantum kernel call (first arg is a qreg) in the pyxasm visitor.

- Prevent redefinition errors.

- Add tests for nested JIT kernels.

Signed-off-by: Nguyen, Thien Minh's avatarThien Nguyen <nguyentm@ornl.gov>
parent 5c4651f2
......@@ -96,7 +96,7 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
// << ": " << line.first << ", " << line.second << std::boolalpha
// << ", " << is_in_for_loop << "\n";
pyxasm_visitor visitor;
pyxasm_visitor visitor(bufferNames);
if (line.first.find("for ") != std::string::npos) {
is_in_for_loop = true;
......
......@@ -17,9 +17,12 @@ using pyxasm_result_type =
class pyxasm_visitor : public pyxasmBaseVisitor {
protected:
std::shared_ptr<xacc::IRProvider> provider;
// List of buffers in the *context* of this XASM visitor
std::vector<std::string> bufferNames;
public:
pyxasm_visitor() : provider(xacc::getIRProvider("quantum")) {}
pyxasm_visitor(const std::vector<std::string> &buffers = {})
: provider(xacc::getIRProvider("quantum")), bufferNames(buffers) {}
pyxasm_result_type result;
bool in_for_loop = false;
......@@ -118,12 +121,17 @@ public:
}
}
} else {
// This *callable* is not an intrinsic instruction, just reassemble the
// call:
// TODO: validate that this is a *previously-defined* kernel (i.e. not a
// classical function call)
if (!context->trailer().empty()) {
// This kernel *callable* is not an intrinsic instruction, just
// reassemble the call:
// Check that the *first* argument is a *qreg* in the current context of
// *this* kernel.
if (!context->trailer().empty() &&
!context->trailer()[0]->arglist()->argument().empty() &&
xacc::container::contains(
bufferNames,
context->trailer()[0]->arglist()->argument(0)->getText())) {
std::stringstream ss;
// Use the kernel call with a parent kernel arg.
ss << inst_name << "(parent_kernel, ";
// TODO: We potentially need to handle *inline* expressions in the
// function call.
......
......@@ -130,6 +130,40 @@ class TestSimpleKernelJIT(unittest.TestCase):
q = qalloc(5)
comp = testFor.extract_composite(q)
self.assertEqual(comp.nInstructions(), 5)
def test_multiple_kernels(self):
@qjit
def apply_H(q : qreg):
for i in range(q.size()):
H(q[i])
@qjit
def apply_Rx(q : qreg, theta: float):
for i in range(q.size()):
Rx(q[i], theta)
@qjit
def measure_all(q : qreg):
for i in range(q.size()):
Measure(q[i])
@qjit
def entry_kernel(q : qreg, theta: float):
apply_H(q)
apply_Rx(q, theta)
measure_all(q)
q = qalloc(5)
angle = 1.234
comp = entry_kernel.extract_composite(q, angle)
self.assertEqual(comp.nInstructions(), 15)
for i in range(5):
self.assertEqual(comp.getInstruction(i).name(), "H")
for i in range(5, 10):
self.assertEqual(comp.getInstruction(i).name(), "Rx")
self.assertAlmostEqual((float)(comp.getInstruction(i).getParameter(0)), angle)
for i in range(10, 15):
self.assertEqual(comp.getInstruction(i).name(), "Measure")
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
......@@ -224,9 +224,15 @@ const std::pair<std::string, std::string> QJIT::run_syntax_handler(
preamble += ", " + arg_types[j] + " " + arg_vars[j];
}
return std::make_pair(kernel_name, Replacement +
"\n// Fix for __dso_handle symbol not "
"found\nint __dso_handle = 1;\n");
const std::string fix_dso_str = R"(
// Fix for __dso_handle symbol not found
#ifndef __FIX__DSO__HANDLE__
#define __FIX__DSO__HANDLE__
int __dso_handle = 1;
#endif
)";
return std::make_pair(kernel_name, Replacement + "\n" + fix_dso_str + "\n");
}
class LLVMJIT {
......@@ -359,15 +365,12 @@ void QJIT::jit_compile(const std::string &code,
// this kernel before compilation.
std::string dependencyCode;
if (!kernel_dependency.empty()) {
// Put the code in an anonymous namespace
dependencyCode += "namespace { \n";
for (const auto &dep : kernel_dependency) {
const auto depIter = JIT_KERNEL_RUNTIME_CACHE.find(dep);
if (depIter != JIT_KERNEL_RUNTIME_CACHE.end()) {
dependencyCode += JIT_KERNEL_RUNTIME_CACHE[dep];
}
}
dependencyCode += "}\n";
}
// Add dependency before JIT compile:
new_code = dependencyCode + new_code;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment