Commit 5c4651f2 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

First pass to support nested JIT kernels



Kernel dependency is injected to the generated source before JIT compilation.

Signed-off-by: Nguyen, Thien Minh's avatarThien Nguyen <nguyentm@ornl.gov>
parent 957d90b7
......@@ -117,6 +117,26 @@ public:
"' is not currently supported.");
}
}
} else {
// This *callable* is not an intrinsic instruction, just reassemble the
// call:
// TODO: validate that this is a *previously-defined* kernel (i.e. not a
// classical function call)
if (!context->trailer().empty()) {
std::stringstream ss;
ss << inst_name << "(parent_kernel, ";
// TODO: We potentially need to handle *inline* expressions in the
// function call.
const auto &argList = context->trailer()[0]->arglist()->argument();
for (size_t i = 0; i < argList.size(); ++i) {
ss << argList[i]->getText();
if (i != argList.size() - 1) {
ss << ", ";
}
}
ss << ");\n";
result.first = ss.str();
}
}
}
return 0;
......
# Run this from the command line like this
#
# python3 multiple_kernels.py -shots 100
from qcor import qjit, qalloc, qreg
# To create QCOR quantum kernels in Python one
# simply creates a Python function, writes Pythonic,
# XASM-like quantum code, and annotates the kernel
# to indicate it is meant for QCOR just in time compilation
# NOTE Programmers must type annotate their function arguments
@qjit
def measure_all_qubits(q : qreg):
for i in range(q.size()):
Measure(q[i])
# Define a Bell kernel
@qjit
def bell_test(q : qreg):
H(q[0])
CX(q[0], q[1])
# Call other kernels
measure_all_qubits(q)
# Allocate 2 qubits
q = qalloc(2)
# Inspect the IR
comp = bell_test.extract_composite(q)
print(comp.toString())
# Run the bell experiment
bell_test(q)
# Print the results
q.print()
\ No newline at end of file
......@@ -287,9 +287,10 @@ PYBIND11_MODULE(_pyqcor, m) {
.def("jit_compile", &qcor::QJIT::jit_compile, "")
.def(
"internal_python_jit_compile",
[](qcor::QJIT &qjit, const std::string src) {
[](qcor::QJIT &qjit, const std::string src,
const std::vector<std::string> &dependency = {}) {
bool turn_on_hetmap_kernel_ctor = true;
qjit.jit_compile(src, turn_on_hetmap_kernel_ctor);
qjit.jit_compile(src, turn_on_hetmap_kernel_ctor, dependency);
},
"")
.def("run_syntax_handler", &qcor::QJIT::run_syntax_handler, "")
......
......@@ -4,6 +4,7 @@ import sys
import inspect
from typing import List
import typing
import re
List = typing.List
......@@ -124,12 +125,24 @@ class qjit(object):
self.src = '__qpu__ void '+self.function.__name__ + \
'('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + globalDeclStr + '\n' + fbody_src +"}\n"
# Handle nested kernels:
dependency = []
for kernelName in self.__compiled__kernels:
kernelCall = kernelName + '('
# Check that this kernel *calls* a previously-compiled kernel:
# pattern: "<white space> kernel("
if re.search(r"\b" + re.escape(kernelCall), self.src):
dependency.append(kernelName)
# Run the QJIT compile step to store function pointers internally
self._qjit.internal_python_jit_compile(self.src)
self._qjit.internal_python_jit_compile(self.src, dependency)
self._qjit.write_cache()
self.__compiled__kernels.append(self.function.__name__)
return
# Static list of all kernels compiled
__compiled__kernels = []
def get_internal_src(self):
"""Return the C++ / embedded python DSL function code that will be passed to QJIT
and the clang syntax handler. This function is primarily to be used for developer purposes. """
......
......@@ -45,7 +45,8 @@ class QJIT {
const std::string &quantum_kernel_src,
const bool add_het_map_kernel_ctor = false);
void jit_compile(const std::string &quantum_kernel_src,
const bool add_het_map_kernel_ctor = false);
const bool add_het_map_kernel_ctor = false,
const std::vector<std::string> &kernel_dependency = {});
void write_cache();
......
......@@ -344,12 +344,34 @@ void QJIT::write_cache() {
QJIT::~QJIT() { write_cache(); }
void QJIT::jit_compile(const std::string &code,
const bool add_het_map_kernel_ctor) {
const bool add_het_map_kernel_ctor,
const std::vector<std::string> &kernel_dependency) {
// Run the Syntax Handler to get the kernel name and
// the kernel code (the QuantumKernel subtype def + utility functions)
auto [kernel_name, new_code] =
run_syntax_handler(code, add_het_map_kernel_ctor);
static std::unordered_map<std::string, std::string> JIT_KERNEL_RUNTIME_CACHE;
JIT_KERNEL_RUNTIME_CACHE[kernel_name] = new_code;
// Add dependency code if necessary:
// Look up the previously-generated for dependency kernels and add them to
// this kernel before compilation.
std::string dependencyCode;
if (!kernel_dependency.empty()) {
// Put the code in an anonymous namespace
dependencyCode += "namespace { \n";
for (const auto &dep : kernel_dependency) {
const auto depIter = JIT_KERNEL_RUNTIME_CACHE.find(dep);
if (depIter != JIT_KERNEL_RUNTIME_CACHE.end()) {
dependencyCode += JIT_KERNEL_RUNTIME_CACHE[dep];
}
}
dependencyCode += "}\n";
}
// Add dependency before JIT compile:
new_code = dependencyCode + new_code;
// std::cout << "New code:\n" << new_code << "\n";
// Hash the new code
std::hash<std::string> hasher;
auto hash = hasher(new_code);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment