Commit 6b2833b1 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

add extra_code to cached kernel code for dependent kernels

parent 1345cc04
Loading
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -83,7 +83,7 @@ cpp_matrix_gen_code = '''#include <pybind11/embed.h>
#include <pybind11/complex.h>
namespace py = pybind11;
// returns 1d data as vector and matrix size (assume square)
auto __internal__qcor_pyjit_gen_{}_unitary_matrix({}) {{
auto __internal__qcor_pyjit_{}_gen_{}_unitary_matrix({}) {{
  auto py_src = R"#({})#";
  auto locals = py::dict();
  {}
@@ -309,13 +309,13 @@ class qjit(object):
                        [s for s in analyzer.depends_on])
                    locals_code = '\n'.join(
                        ['locals["{}"] = {};'.format(n, n) for n in arg_var_names])
                    self.extra_cpp_code = cpp_matrix_gen_code.format(
                    self.extra_cpp_code = cpp_matrix_gen_code.format(self.kernel_name(),
                        with_decomp_matrix_names[i], arg_struct, code_to_exec, locals_code)

                    col_skip = ' '*with_decomp_lines_col_starts[i]
                    new_src = col_skip + 'decompose {\n'
                    new_src += col_skip + ' '*4 + \
                        'auto [mat_data, mat_size] = __internal__qcor_pyjit_gen_{}_unitary_matrix({});\n'.format(
                        'auto [mat_data, mat_size] = __internal__qcor_pyjit_{}_gen_{}_unitary_matrix({});\n'.format(self.kernel_name(),
                            with_decomp_matrix_names[i], arg_var_names)
                    new_src += col_skip+' '*4 + \
                        'UnitaryMatrix {} = Eigen::Map<UnitaryMatrix>(mat_data.data(), mat_size, mat_size);\n'.format(
+3 −3
Original line number Diff line number Diff line
@@ -492,6 +492,9 @@ void QJIT::jit_compile(const std::string &code,
      run_syntax_handler(code, add_het_map_kernel_ctor);

  static std::unordered_map<std::string, std::string> JIT_KERNEL_RUNTIME_CACHE;
  
  // Add any extra functions to be compiled
  new_code = extra_functions_src + "\n" + new_code;
  JIT_KERNEL_RUNTIME_CACHE[kernel_name] = new_code;

  // Add dependency code if necessary:
@@ -507,9 +510,6 @@ void QJIT::jit_compile(const std::string &code,
  // Add dependency before JIT compile:
  new_code = dependencyCode + new_code;

  // Add any extra functions to be compiled
  new_code = extra_functions_src + "\n" + new_code;

  // std::cout << "New code:\n" << new_code << "\n";
  // Hash the new code
  std::hash<std::string> hasher;