Unverified Commit 8844bf8b authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #120 from tnguyen-ornl/tnguyen/ctrl-adj-callable

Add ctrl and adjoint for KernelSignature
parents 09e1c239 02ea30a7
Loading
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -178,7 +178,15 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      std::stringstream ss;
      // Remove the first '.' character
      const std::string methodName = context->trailer()[0]->getText().substr(1);
      ss << context->atom()->getText() << "::" << methodName
      // If this is a *variable*, then using '.' for control/adjoint.
      // Otherwise, use '::' (global scope kernel names)
      const std::string separator =
          (xacc::container::contains(declared_var_names,
                                     context->atom()->getText()))
              ? "."
              : "::";

      ss << context->atom()->getText() << separator << methodName
         << "(parent_kernel";
      for (int i = 0; i < arg_list->argument().size(); i++) {
        ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i)));
+27 −0
Original line number Diff line number Diff line
@@ -259,6 +259,33 @@ auto index = std::pow(2, n);
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkKernelSignature) {
  LexerHelper helper;

  auto [tokens, PP] = helper.Lex(R"(
    # fake local var creation 
    # (should be from the function args if compiling full source.)
    callable = createCallable(a,b,c)
    callable.ctrl([q[1], q[2]], q[0])
    callable.adjoint(q)
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string code_gen_str = ss.str();
  // Rewrite to '.'
  EXPECT_TRUE(code_gen_str.find("callable.ctrl(parent_kernel, {q[1], q[2]}, q[0]);") != std::string::npos);
  EXPECT_TRUE(code_gen_str.find("callable.adjoint(parent_kernel, q);") != std::string::npos);
}


int main(int argc, char **argv) {
  std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR);
+33 −0
Original line number Diff line number Diff line
@@ -38,5 +38,38 @@ class TestKernelJIT(unittest.TestCase):
            self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), i + 1.0)
            counter+=1

    def test_kernel_signature_ctrl_adj(self):
        set_qpu('qpp', {'shots':1024})
        
        @qjit
        def test_kernel1(q: qreg, call_var1: KernelSignature(qreg, int, float), call_var2: KernelSignature(qubit)):
            call_var1.adjoint(q, 0, 1.0)
            call_var1.adjoint(q, 1, 2.0)
            call_var2.ctrl(q[1], q[0])

        # These kernels are unknown to test_kernel 
        @qjit
        def rz_kernel(q: qreg, idx: int, theta: float):
            Rz(q[idx], theta)

        @qjit
        def x_kernel(q: qubit):
            X(q)

        q = qalloc(2)
        comp = test_kernel1.extract_composite(q, rz_kernel, x_kernel)
        print(comp)
        self.assertEqual(comp.nInstructions(), 3)   
        counter = 0
        for i in range(2):
            self.assertEqual(comp.getInstruction(counter).name(), "Rz") 
            # Minus due to adjoint
            self.assertAlmostEqual((float)(comp.getInstruction(counter).getParameter(0)), -(i + 1.0))
            counter+=1
        self.assertEqual(comp.getInstruction(2).name(), "CNOT")
        self.assertEqual(comp.getInstruction(2).bits()[0], 1)
        self.assertEqual(comp.getInstruction(2).bits()[1], 0)


if __name__ == '__main__':
  unittest.main()
 No newline at end of file