Commit e21fb98f authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Complete QPE JIT unit test



iqft is a valid XACC circuit hence caused problems in the pyxasm_visitor.

Hence, fixed accordingly.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 67094264
Loading
Loading
Loading
Loading
+23 −0
Original line number Diff line number Diff line
@@ -142,6 +142,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
               << context->trailer()[0]->arglist()->argument(2)->getText()
               << ");\n";
            result.first = ss.str();
          }
          // Handle potential name collision: user-defined kernel having the
          // same name as an XACC circuit: e.g. common names such as qft, iqft
          // Note: these circuits (except exp_i_theta) don't have QRT
          // equivalents.
          // Condition: first argument is a qubit register
          else if (!context->trailer()[0]->arglist()->argument().empty() &&
                   xacc::container::contains(bufferNames, context->trailer()[0]
                                                              ->arglist()
                                                              ->argument(0)
                                                              ->getText())) {
            std::stringstream ss;
            // Use the kernel call with a parent kernel arg.
            ss << inst_name << "(parent_kernel, ";
            const auto &argList = context->trailer()[0]->arglist()->argument();
            for (size_t i = 0; i < argList.size(); ++i) {
              ss << argList[i]->getText();
              if (i != argList.size() - 1) {
                ss << ", ";
              }
            }
            ss << ");\n";
            result.first = ss.str();
          } else {
            xacc::error("Composite instruction '" + inst_name +
                        "' is not currently supported.");
+4 −0
Original line number Diff line number Diff line
@@ -314,6 +314,10 @@ class qjit(object):
        print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
        exit(1)

    def adjoint(self, *args):
        print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
        exit(1)

    def __call__(self, *args):
        """
        Execute the decorated quantum kernel. This will directly 
+8 −5
Original line number Diff line number Diff line
@@ -221,7 +221,7 @@ class TestSimpleKernelJIT(unittest.TestCase):

    def test_iqft_kernel(self):
        @qjit
        def iqft(q : qreg, startIdx : int, nbQubits : int):
        def inverse_qft(q : qreg, startIdx : int, nbQubits : int):
            for i in range(nbQubits/2):
                Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
            
@@ -235,7 +235,7 @@ class TestSimpleKernelJIT(unittest.TestCase):
            H(q[startIdx+nbQubits-1])
        
        q = qalloc(5)
        comp = iqft.extract_composite(q, 0, 5)
        comp = inverse_qft.extract_composite(q, 0, 5)
        print(comp.toString())
        self.assertEqual(comp.nInstructions(), 17)   
        self.assertEqual(comp.getInstruction(0).name(), "Swap") 
@@ -279,24 +279,27 @@ class TestSimpleKernelJIT(unittest.TestCase):
        @qjit
        def qpe(q : qreg):
            nq = q.size()

            X(q[nq - 1])
            for i in range(q.size()-1):
                H(q[i])
            
            bitPrecision = nq-1
            for i in range(bitPrecision):
                nbCalls = 1 << i
                nbCalls = 2**i
                for j in range(nbCalls):
                    ctrl_bit = i
                    oracle.ctrl(ctrl_bit, q)
            
            # Inverse QFT on the counting qubits
            iqft(q, 0, bitPrecision)
            
            for i in range(bitPrecision):
                Measure(q[i])
        
        q = qalloc(4)
        qpe(q)
        print(q.counts())

        self.assertEqual(q.counts()['100'], 1024)


if __name__ == '__main__':