Commit 02ea30a7 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Add ctrl and adjoint for KernelSignature



This turns out to be quite simple since we know all the local variables (including those from kernel arguments).
Hence, using that to choose b/w '.' and '::'.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 6bae7628
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -178,7 +178,15 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      std::stringstream ss;
      // Remove the first '.' character
      const std::string methodName = context->trailer()[0]->getText().substr(1);
      ss << context->atom()->getText() << "::" << methodName
      // If this is a *variable*, then using '.' for control/adjoint.
      // Otherwise, use '::' (global scope kernel names)
      const std::string separator =
          (xacc::container::contains(declared_var_names,
                                     context->atom()->getText()))
              ? "."
              : "::";

      ss << context->atom()->getText() << separator << methodName
         << "(parent_kernel";
      for (int i = 0; i < arg_list->argument().size(); i++) {
        ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i)));
+27 −0
Original line number Diff line number Diff line
@@ -259,6 +259,33 @@ auto index = std::pow(2, n);
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkKernelSignature) {
  LexerHelper helper;

  auto [tokens, PP] = helper.Lex(R"(
    # fake local var creation 
    # (should be from the function args if compiling full source.)
    callable = createCallable(a,b,c)
    callable.ctrl([q[1], q[2]], q[0])
    callable.adjoint(q)
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string code_gen_str = ss.str();
  // Rewrite to '.'
  EXPECT_TRUE(code_gen_str.find("callable.ctrl(parent_kernel, {q[1], q[2]}, q[0]);") != std::string::npos);
  EXPECT_TRUE(code_gen_str.find("callable.adjoint(parent_kernel, q);") != std::string::npos);
}


int main(int argc, char **argv) {
  std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR);
+33 −0
Original line number Diff line number Diff line
@@ -38,5 +38,38 @@ class TestKernelJIT(unittest.TestCase):
            self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), i + 1.0)
            counter+=1

    def test_kernel_signature_ctrl_adj(self):
        set_qpu('qpp', {'shots':1024})
        
        @qjit
        def test_kernel1(q: qreg, call_var1: KernelSignature(qreg, int, float), call_var2: KernelSignature(qubit)):
            call_var1.adjoint(q, 0, 1.0)
            call_var1.adjoint(q, 1, 2.0)
            call_var2.ctrl(q[1], q[0])

        # These kernels are unknown to test_kernel 
        @qjit
        def rz_kernel(q: qreg, idx: int, theta: float):
            Rz(q[idx], theta)

        @qjit
        def x_kernel(q: qubit):
            X(q)

        q = qalloc(2)
        comp = test_kernel1.extract_composite(q, rz_kernel, x_kernel)
        print(comp)
        self.assertEqual(comp.nInstructions(), 3)   
        counter = 0
        for i in range(2):
            self.assertEqual(comp.getInstruction(counter).name(), "Rz") 
            # Minus due to adjoint
            self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), -(i + 1.0))
            counter+=1
        self.assertEqual(comp.getInstruction(2).name(), "CNOT")
        self.assertEqual(comp.getInstruction(2).bits()[0], 1)
        self.assertEqual(comp.getInstruction(2).bits()[1], 0)


if __name__ == '__main__':
  unittest.main()
 No newline at end of file