Commit ae1a6c2b authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Handle multiple kernels more robustly



- Make sure that the callable is a quantum kernel call (first arg is a qreg) in the pyxasm visitor.

- Prevent redefinition errors.

- Add tests for nested JIT kernels.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 5c4651f2
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -96,7 +96,7 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
    //           << ": " << line.first << ", " << line.second << std::boolalpha
    //           << ", " << is_in_for_loop << "\n";

    pyxasm_visitor visitor;
    pyxasm_visitor visitor(bufferNames);

    if (line.first.find("for ") != std::string::npos) {
      is_in_for_loop = true;
+14 −6
Original line number Diff line number Diff line
@@ -17,9 +17,12 @@ using pyxasm_result_type =
class pyxasm_visitor : public pyxasmBaseVisitor {
protected:
  std::shared_ptr<xacc::IRProvider> provider;
  // List of buffers in the *context* of this XASM visitor
  std::vector<std::string> bufferNames;

public:
  pyxasm_visitor() : provider(xacc::getIRProvider("quantum")) {}
  pyxasm_visitor(const std::vector<std::string> &buffers = {})
      : provider(xacc::getIRProvider("quantum")), bufferNames(buffers) {}
  pyxasm_result_type result;

  bool in_for_loop = false;
@@ -118,12 +121,17 @@ public:
          }
        }
      } else {
        // This *callable* is not an intrinsic instruction, just reassemble the
        // call:
        // TODO: validate that this is a *previously-defined* kernel (i.e. not a
        // classical function call)
        if (!context->trailer().empty()) {
        // This kernel *callable* is not an intrinsic instruction, just
        // reassemble the call:
        // Check that the *first* argument is a *qreg* in the current context of
        // *this* kernel.
        if (!context->trailer().empty() &&
            !context->trailer()[0]->arglist()->argument().empty() &&
            xacc::container::contains(
                bufferNames,
                context->trailer()[0]->arglist()->argument(0)->getText())) {
          std::stringstream ss;
          // Use the kernel call with a parent kernel arg.
          ss << inst_name << "(parent_kernel, ";
          // TODO: We potentially need to handle *inline* expressions in the
          // function call.
+34 −0
Original line number Diff line number Diff line
@@ -131,5 +131,39 @@ class TestSimpleKernelJIT(unittest.TestCase):
        comp = testFor.extract_composite(q)
        self.assertEqual(comp.nInstructions(), 5)   
    
    def test_multiple_kernels(self):
        @qjit
        def apply_H(q : qreg):
            for i in range(q.size()):
                H(q[i])
        
        @qjit
        def apply_Rx(q : qreg, theta: float):
            for i in range(q.size()):
                Rx(q[i], theta)

        @qjit
        def measure_all(q : qreg):
            for i in range(q.size()):
                Measure(q[i])

        @qjit
        def entry_kernel(q : qreg, theta: float):
           apply_H(q)
           apply_Rx(q, theta)
           measure_all(q) 
        
        q = qalloc(5)
        angle = 1.234
        comp = entry_kernel.extract_composite(q, angle)
        self.assertEqual(comp.nInstructions(), 15)   
        for i in range(5):
            self.assertEqual(comp.getInstruction(i).name(), "H") 
        for i in range(5, 10):
            self.assertEqual(comp.getInstruction(i).name(), "Rx") 
            self.assertAlmostEqual((float)(comp.getInstruction(i).getParameter(0)), angle)
        for i in range(10, 15):
            self.assertEqual(comp.getInstruction(i).name(), "Measure") 

if __name__ == '__main__':
  unittest.main()
 No newline at end of file
+9 −6
Original line number Diff line number Diff line
@@ -224,9 +224,15 @@ const std::pair<std::string, std::string> QJIT::run_syntax_handler(
    preamble += ", " + arg_types[j] + " " + arg_vars[j];
  }

  return std::make_pair(kernel_name, Replacement +
                                         "\n// Fix for __dso_handle symbol not "
                                         "found\nint __dso_handle = 1;\n");
  const std::string fix_dso_str = R"(
// Fix for __dso_handle symbol not found
#ifndef __FIX__DSO__HANDLE__
#define __FIX__DSO__HANDLE__ 
int __dso_handle = 1;
#endif
)";

  return std::make_pair(kernel_name, Replacement + "\n" + fix_dso_str + "\n");
}

class LLVMJIT {
@@ -359,15 +365,12 @@ void QJIT::jit_compile(const std::string &code,
  // this kernel before compilation.
  std::string dependencyCode;
  if (!kernel_dependency.empty()) {
    // Put the code in an anonymous namespace
    dependencyCode += "namespace { \n";
    for (const auto &dep : kernel_dependency) {
      const auto depIter = JIT_KERNEL_RUNTIME_CACHE.find(dep);
      if (depIter != JIT_KERNEL_RUNTIME_CACHE.end()) {
        dependencyCode += JIT_KERNEL_RUNTIME_CACHE[dep];
      }
    }
    dependencyCode += "}\n";
  }
  // Add dependency before JIT compile:
  new_code = dependencyCode + new_code;