Unverified Commit d110cb18 authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #50 from tnguyen-ornl/pyxasm_ctrl_ajoint

More work on ctrl and adjoint
parents 6105e4ae dcdd6c93
Loading
Loading
Loading
Loading
Loading
+30 −3
Original line number Diff line number Diff line
@@ -31,7 +31,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      pyxasmParser::Atom_exprContext *context) override {
      
    // Handle kernel::ctrl(...), kernel::adjoint(...)
    if (!context->trailer().empty() && context->trailer()[0]->getText() == ".ctrl") {
    if (!context->trailer().empty() &&
        (context->trailer()[0]->getText() == ".ctrl" ||
         context->trailer()[0]->getText() == ".adjoint")) {
      std::cout << "HELLO: " << context->getText() << "\n";
      std::cout << context->trailer()[0]->getText() << "\n";
      std::cout << context->atom()->getText() << "\n";
@@ -41,7 +43,10 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      auto arg_list = context->trailer()[1]->arglist();

      std::stringstream ss;
      ss << context->atom()->getText() << "::ctrl(parent_kernel";
      // Remove the first '.' character
      const std::string methodName = context->trailer()[0]->getText().substr(1);
      ss << context->atom()->getText() << "::" << methodName
         << "(parent_kernel";
      for (int i = 0; i < arg_list->argument().size(); i++) {
        ss << ", " << arg_list->argument(i)->getText();
      }
@@ -50,7 +55,6 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      std::cout << "HELLO SS: " << ss.str() << "\n";
      result.first = ss.str();
      return 0;

    }
    if (context->atom()->NAME() != nullptr) {
      auto inst_name = context->atom()->NAME()->getText();
@@ -138,6 +142,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
               << context->trailer()[0]->arglist()->argument(2)->getText()
               << ");\n";
            result.first = ss.str();
          }
          // Handle potential name collision: user-defined kernel having the
          // same name as an XACC circuit: e.g. common names such as qft, iqft
          // Note: these circuits (except exp_i_theta) don't have QRT
          // equivalents.
          // Condition: first argument is a qubit register
          else if (!context->trailer()[0]->arglist()->argument().empty() &&
                   xacc::container::contains(bufferNames, context->trailer()[0]
                                                              ->arglist()
                                                              ->argument(0)
                                                              ->getText())) {
            std::stringstream ss;
            // Use the kernel call with a parent kernel arg.
            ss << inst_name << "(parent_kernel, ";
            const auto &argList = context->trailer()[0]->arglist()->argument();
            for (size_t i = 0; i < argList.size(); ++i) {
              ss << argList[i]->getText();
              if (i != argList.size() - 1) {
                ss << ", ";
              }
            }
            ss << ");\n";
            result.first = ss.str();
          } else {
            xacc::error("Composite instruction '" + inst_name +
                        "' is not currently supported.");
+9 −3
Original line number Diff line number Diff line
@@ -200,10 +200,12 @@ class qjit(object):
        # Handle nested kernels:
        dependency = []
        for kernelName in self.__compiled__kernels:
            kernelCall = kernelName + '('
            # Check that this kernel *calls* a previously-compiled kernel:
            # pattern: "<white space> kernel("
            if re.search(r"\b" + re.escape(kernelCall), self.src):
            # pattern: "<white space> kernel(" OR "kernel.adjoint(" OR "kernel.ctrl("
            kernelCall = kernelName + '('
            kernelAdjCall = kernelName + '.adjoint('
            kernelCtrlCall = kernelName + '.ctrl('
            if re.search(r"\b" + re.escape(kernelCall) + '|' + re.escape(kernelAdjCall) + '|' + re.escape(kernelCtrlCall), self.src):
                dependency.append(kernelName)
                

@@ -312,6 +314,10 @@ class qjit(object):
        print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
        exit(1)

    def adjoint(self, *args):
        print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
        exit(1)

    def __call__(self, *args):
        """
        Execute the decorated quantum kernel. This will directly 
+32 −4
Original line number Diff line number Diff line
@@ -221,7 +221,7 @@ class TestSimpleKernelJIT(unittest.TestCase):

    def test_iqft_kernel(self):
        @qjit
        def iqft(q : qreg, startIdx : int, nbQubits : int):
        def inverse_qft(q : qreg, startIdx : int, nbQubits : int):
            for i in range(nbQubits/2):
                Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
            
@@ -235,7 +235,7 @@ class TestSimpleKernelJIT(unittest.TestCase):
            H(q[startIdx+nbQubits-1])
        
        q = qalloc(5)
        comp = iqft.extract_composite(q, 0, 5)
        comp = inverse_qft.extract_composite(q, 0, 5)
        print(comp.toString())
        self.assertEqual(comp.nInstructions(), 17)   
        self.assertEqual(comp.getInstruction(0).name(), "Swap") 
@@ -279,24 +279,52 @@ class TestSimpleKernelJIT(unittest.TestCase):
        @qjit
        def qpe(q : qreg):
            nq = q.size()

            X(q[nq - 1])
            for i in range(q.size()-1):
                H(q[i])
            
            bitPrecision = nq-1
            for i in range(bitPrecision):
                nbCalls = 1 << i
                nbCalls = 2**i
                for j in range(nbCalls):
                    ctrl_bit = i
                    oracle.ctrl(ctrl_bit, q)
            
            # Inverse QFT on the counting qubits
            iqft(q, 0, bitPrecision)
            
            for i in range(bitPrecision):
                Measure(q[i])
        
        q = qalloc(4)
        qpe(q)
        print(q.counts())
        self.assertEqual(q.counts()['100'], 1024)

    def test_adjoint_kernel(self):
        @qjit
        def test_kernel(q : qreg):
            CX(q[0], q[1])
            Rx(q[0], 1.234)
            T(q[0])
            X(q[0])

        @qjit
        def check_adjoint(q : qreg):
            test_kernel.adjoint(q)
        
        q = qalloc(2)
        comp = check_adjoint.extract_composite(q)
        print(comp.toString())
        self.assertEqual(comp.nInstructions(), 4)   
        # Reverse
        self.assertEqual(comp.getInstruction(0).name(), "X") 
        # Check T -> Tdg
        self.assertEqual(comp.getInstruction(1).name(), "Tdg") 
        self.assertEqual(comp.getInstruction(2).name(), "Rx") 
        # Check angle -> -angle
        self.assertAlmostEqual((float)(comp.getInstruction(2).getParameter(0)), -1.234)
        self.assertEqual(comp.getInstruction(3).name(), "CNOT") 


if __name__ == '__main__':
+6 −2
Original line number Diff line number Diff line
@@ -108,7 +108,6 @@ public:
    }

    auto provider = qcor::__internal__::get_provider();
    std::reverse(instructions.begin(), instructions.end());
    for (int i = 0; i < instructions.size(); i++) {
      auto inst = derived.parent_kernel->getInstruction(i);
      // Parametric gates:
@@ -127,8 +126,13 @@ public:
      }
    }

    // We update/replace instructions in the derived.parent_kernel composite,
    // hence collecting these new instructions and reversing the sequence.
    auto new_instructions = derived.parent_kernel->getInstructions();
    std::reverse(new_instructions.begin(), new_instructions.end());

    // add the instructions to the current parent kernel
    parent_kernel->addInstructions(instructions);
    parent_kernel->addInstructions(new_instructions);

    // no measures, so no execute
  }
+5 −0
Original line number Diff line number Diff line
@@ -38,6 +38,11 @@ get_transformation(const std::string &transform_type) {
    xacc::internal_compiler::compiler_InitializeXACC();
  return xacc::getService<xacc::IRTransformation>(transform_type);
}

std::shared_ptr<qcor::IRProvider> get_provider() {
  return xacc::getIRProvider("quantum");
}

std::shared_ptr<qcor::CompositeInstruction>
decompose_unitary(const std::string algorithm, UnitaryMatrix &mat,
                  const std::string buffer_name) {