Commit 6105e4ae authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

[WIP] work on ctrl / adjoint in python qjit kernels

parent 3c86792b
Loading
Loading
Loading
Loading
+23 −0
Original line number Diff line number Diff line
@@ -29,6 +29,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {

  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
      
    // Handle kernel::ctrl(...), kernel::adjoint(...)
    if (!context->trailer().empty() && context->trailer()[0]->getText() == ".ctrl") {
      std::cout << "HELLO: " << context->getText() << "\n";
      std::cout << context->trailer()[0]->getText() << "\n";
      std::cout << context->atom()->getText() << "\n";

      std::cout << context->trailer()[1]->getText() << "\n";
      std::cout << context->trailer()[1]->arglist() << "\n";
      auto arg_list = context->trailer()[1]->arglist();

      std::stringstream ss;
      ss << context->atom()->getText() << "::ctrl(parent_kernel";
      for (int i = 0; i < arg_list->argument().size(); i++) {
        ss << ", " << arg_list->argument(i)->getText();
      }
      ss << ");\n";

      std::cout << "HELLO SS: " << ss.str() << "\n";
      result.first = ss.str();
      return 0;

    }
    if (context->atom()->NAME() != nullptr) {
      auto inst_name = context->atom()->NAME()->getText();

+60 −52
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ from collections import defaultdict
List = typing.List
PauliOperator = xacc.quantum.PauliOperator


def X(idx):
    return xacc.quantum.PauliOperator({idx: 'X'}, 1.0)

@@ -34,7 +35,8 @@ class KernelGraph(object):
        self.kernel_name_list.append(kernelName)
        self.kernel_idx_dep_map[self.V] = []
        for dep_ker_name in depList:
            self.kernel_idx_dep_map[self.V].append(self.kernel_name_list.index(dep_ker_name))
            self.kernel_idx_dep_map[self.V].append(
                self.kernel_name_list.index(dep_ker_name))
        self.V += 1

    def addEdge(self, u, v):
@@ -86,6 +88,7 @@ class KernelGraph(object):
            else:
                result_dep.append(dep_name)


class qjit(object):
    """
    The qjit class serves a python function decorator that enables 
@@ -186,11 +189,13 @@ class qjit(object):
            if moduleAlias != importedModules[moduleAlias]:
                aliasModuleStr = moduleAlias + '.'
                originalModuleStr = importedModules[moduleAlias] + '.'
                fbody_src = fbody_src.replace(aliasModuleStr, originalModuleStr)
                fbody_src = fbody_src.replace(
                    aliasModuleStr, originalModuleStr)

        # Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
        self.src = '__qpu__ void '+self.function.__name__ + \
            '('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + globalDeclStr + '\n' + fbody_src +"}\n"
            '('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + \
            globalDeclStr + '\n' + fbody_src + "}\n"

        # Handle nested kernels:
        dependency = []
@@ -201,8 +206,11 @@ class qjit(object):
            if re.search(r"\b" + re.escape(kernelCall), self.src):
                dependency.append(kernelName)
                
        self.__kernels__graph.addKernelDependency(self.function.__name__, dependency)
        sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)

        self.__kernels__graph.addKernelDependency(
            self.function.__name__, dependency)
        sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
            self.function.__name__)

        # Run the QJIT compile step to store function pointers internally
        self._qjit.internal_python_jit_compile(self.src, sorted_kernel_dep)
@@ -300,9 +308,9 @@ class qjit(object):
        """
        return self.extract_composite(*args).nInstructions()

    # def ctrl(self, *args):


    def ctrl(self, *args):
        print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
        exit(1)

    def __call__(self, *args):
        """
+34 −47
Original line number Diff line number Diff line
@@ -220,7 +220,6 @@ class TestSimpleKernelJIT(unittest.TestCase):
            self.assertEqual(comp.getInstruction(i).name(), "Measure") 

    def test_iqft_kernel(self):
        import numpy as np
        @qjit
        def iqft(q : qreg, startIdx : int, nbQubits : int):
            for i in range(nbQubits/2):
@@ -254,61 +253,49 @@ class TestSimpleKernelJIT(unittest.TestCase):
            self.assertEqual(comp.getInstruction(i).name(), "CPhase")
        self.assertEqual(comp.getInstruction(16).name(), "H") 
        
    # def test_ctrl_kernel(self):
    #     @qjit
    #     def qft(q : qreg, startIdx : int, nbQubits : int): # with swap
    #         for i in range(nbQubits - 1, -1, -1):
    #             shiftedBitIdx = i + startIdx
    #             H(q[shiftedBitIdx])

    #             for j in range(i-1, -1, -1):
    #                 theta = np.pi / 2**(i-j)
    #                 tIdx = j + i
    #                 CPhase(q[shiftedBitIdx], q[tIdx], theta)

    #         swapCount = 0 if shouldSwap == 0 else 1
    #         for i in range(nbQubits/2):
    #             Swap(q[startIdx+i], q[startIdx+nbQubits-i-1])
        
    #     @qjit
    #     def iqft(q : qreg, startIdx : int, nbQubits : int):
    #         for i in range(nbQubits/2):
    #             Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
            
    #         for i in range(nbQubits-1):
    #             H(q[startIdx+i])
    #             j = i +1
    #             for y in range(i, -1, -1):
    #                 theta = -np.pi / 2**(j-y)
    #                 CPhase(q[startIdx+j], q[startIdx + y], theta)
            
    #         H(q[startIdx+nbQubits-1])

    #     @qjit
    #     def oracle(q : qreg):
    #         bit = q.size()-1
    #         T(q[bit])

    #     def qpe(q : qreg):
    #         nq = q.size()

    #         for i in range(q.size()-1):
    #             H(q[i])
            
    #         bitPrecision = nq-1
    #         for i in range(bitPrecision):
    #             nbCalls = 1 << i
    #             for j in range(nbCalls):
    #                 ctrl_bit = i
    #                 oracle.ctrl(ctrl_bit, q)
            
    #         iqft(q, 0, bitPrecision)
    #         for i in range(bitPrecision):
    #             Measure(q[i])
        
    #     q = qalloc(4)
    #     qpe(q)
    #     print(q.counts())
    def test_ctrl_kernel(self):
        
        set_qpu('qpp', {'shots':1024})

        @qjit
        def iqft(q : qreg, startIdx : int, nbQubits : int):
            for i in range(nbQubits/2):
                Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
            
            for i in range(nbQubits-1):
                H(q[startIdx+i])
                j = i +1
                for y in range(i, -1, -1):
                    theta = -MY_PI / 2**(j-y)
                    CPhase(q[startIdx+j], q[startIdx + y], theta)
            
            H(q[startIdx+nbQubits-1])

        @qjit
        def oracle(q : qreg):
            bit = q.size()-1
            T(q[bit])

        @qjit
        def qpe(q : qreg):
            nq = q.size()

            for i in range(q.size()-1):
                H(q[i])
            
            bitPrecision = nq-1
            for i in range(bitPrecision):
                nbCalls = 1 << i
                for j in range(nbCalls):
                    ctrl_bit = i
                    oracle.ctrl(ctrl_bit, q)
            
            for i in range(bitPrecision):
                Measure(q[i])
        
        q = qalloc(4)
        qpe(q)
        print(q.counts())