Commit 3c86792b authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

started on supporting python qjit kernel ctrl, adjoint, etc. implement 2**x in...


started on supporting python qjit kernel ctrl, adjoint, etc. implement 2**x in pyxasm visitor, added qft example to tests

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent e2d51688
Loading
Loading
Loading
Loading
Loading
+39 −9
Original line number Diff line number Diff line
@@ -27,8 +27,8 @@ public:

  bool in_for_loop = false;

  antlrcpp::Any
  visitAtom_expr(pyxasmParser::Atom_exprContext *context) override {
  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
    if (context->atom()->NAME() != nullptr) {
      auto inst_name = context->atom()->NAME()->getText();

@@ -69,9 +69,9 @@ public:
              auto found_bracket = bit_expr_str.find_first_of("[");
              if (found_bracket != std::string::npos) {
                auto buffer_name = bit_expr_str.substr(0, found_bracket);
                auto bit_idx_expr = bit_expr_str.substr(found_bracket + 1,
                                                        bit_expr_str.length() -
                                                            found_bracket - 2);
                auto bit_idx_expr = bit_expr_str.substr(
                    found_bracket + 1,
                    bit_expr_str.length() - found_bracket - 2);
                buffer_names.push_back(buffer_name);
                inst->setBitExpression(i, bit_idx_expr);
              } else {
@@ -173,12 +173,42 @@ public:
      const std::string rhs = ctx->testlist_star_expr(1)->getText();
      ss << "auto " << lhs << " = " << rhs << "; \n";
      result.first = ss.str();
      if (rhs.find("**") != std::string::npos) {
        // keep processing
        return visitChildren(ctx);
      } else {
        return 0;
      }
    } else {
      return visitChildren(ctx);
    }
  }

  antlrcpp::Any visitPower(pyxasmParser::PowerContext *context) override {
    if (context->getText().find("**") != std::string::npos &&
        context->factor() != nullptr) {
      // Here we handle x**y from parent assignment expression
      auto replaceAll = [](std::string &s, const std::string &search,
                           const std::string &replace) {
        for (std::size_t pos = 0;; pos += replace.length()) {
          // Locate the substring to replace
          pos = s.find(search, pos);
          if (pos == std::string::npos) break;
          // Replace by erasing and inserting
          s.erase(pos, search.length());
          s.insert(pos, replace);
        }
      };
      auto factor = context->factor();
      auto atom_expr = context->atom_expr();
      std::string s =
          "std::pow(" + atom_expr->getText() + ", " + factor->getText() + ")";
      replaceAll(result.first, context->getText(), s);
      return 0;
    }
    return visitChildren(context);
  }

 private:
  // Replaces common Python constants, e.g. 'math.pi' or 'numpy.pi'.
  // Note: the library names have been resolved to their original names.
+16 −0
Original line number Diff line number Diff line
@@ -288,6 +288,22 @@ class qjit(object):
        staq = xacc.getCompiler('staq')
        return staq.translate(kernel)

    def print_kernel(self, *args):
        """
        Print the QJIT kernel as a QASM-like string
        """
        print(self.extract_composite(*args).toString())
    
    def n_instructions(self, *args):
        """
        Return the number of quantum instructions in this kernel. 
        """
        return self.extract_composite(*args).nInstructions()

    # def ctrl(self, *args):



    def __call__(self, *args):
        """
        Execute the decorated quantum kernel. This will directly 
+94 −1
Original line number Diff line number Diff line
@@ -219,5 +219,98 @@ class TestSimpleKernelJIT(unittest.TestCase):
        for i in range(q.size() * len(list1) * len(list2), comp.nInstructions()):
            self.assertEqual(comp.getInstruction(i).name(), "Measure") 

    def test_iqft_kernel(self):
        import numpy as np
        @qjit
        def iqft(q : qreg, startIdx : int, nbQubits : int):
            for i in range(nbQubits/2):
                Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
            
            for i in range(nbQubits-1):
                H(q[startIdx+i])
                j = i +1
                for y in range(i, -1, -1):
                    theta = -MY_PI / 2**(j-y)
                    CPhase(q[startIdx+j], q[startIdx + y], theta)
            
            H(q[startIdx+nbQubits-1])
        
        q = qalloc(5)
        comp = iqft.extract_composite(q, 0, 5)
        print(comp.toString())
        self.assertEqual(comp.nInstructions(), 17)   
        self.assertEqual(comp.getInstruction(0).name(), "Swap") 
        self.assertEqual(comp.getInstruction(1).name(), "Swap") 
        self.assertEqual(comp.getInstruction(2).name(), "H") 
        self.assertEqual(comp.getInstruction(3).name(), "CPhase") 
        self.assertEqual(comp.getInstruction(4).name(), "H") 
        for i in range(5, 7):
            self.assertEqual(comp.getInstruction(i).name(), "CPhase") 
        self.assertEqual(comp.getInstruction(7).name(), "H") 
        for i in range(8, 11):
            self.assertEqual(comp.getInstruction(i).name(), "CPhase") 
        self.assertEqual(comp.getInstruction(11).name(), "H") 
        for i in range(12, 16):
            self.assertEqual(comp.getInstruction(i).name(), "CPhase")
        self.assertEqual(comp.getInstruction(16).name(), "H") 
        
    # def test_ctrl_kernel(self):
    #     @qjit
    #     def qft(q : qreg, startIdx : int, nbQubits : int): # with swap
    #         for i in range(nbQubits - 1, -1, -1):
    #             shiftedBitIdx = i + startIdx
    #             H(q[shiftedBitIdx])

    #             for j in range(i-1, -1, -1):
    #                 theta = np.pi / 2**(i-j)
    #                 tIdx = j + i
    #                 CPhase(q[shiftedBitIdx], q[tIdx], theta)

    #         swapCount = 0 if shouldSwap == 0 else 1
    #         for i in range(nbQubits/2):
    #             Swap(q[startIdx+i], q[startIdx+nbQubits-i-1])
        
    #     @qjit
    #     def iqft(q : qreg, startIdx : int, nbQubits : int):
    #         for i in range(nbQubits/2):
    #             Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
            
    #         for i in range(nbQubits-1):
    #             H(q[startIdx+i])
    #             j = i +1
    #             for y in range(i, -1, -1):
    #                 theta = -np.pi / 2**(j-y)
    #                 CPhase(q[startIdx+j], q[startIdx + y], theta)
            
    #         H(q[startIdx+nbQubits-1])

    #     @qjit
    #     def oracle(q : qreg):
    #         bit = q.size()-1
    #         T(q[bit])

    #     def qpe(q : qreg):
    #         nq = q.size()

    #         for i in range(q.size()-1):
    #             H(q[i])
            
    #         bitPrecision = nq-1
    #         for i in range(bitPrecision):
    #             nbCalls = 1 << i
    #             for j in range(nbCalls):
    #                 ctrl_bit = i
    #                 oracle.ctrl(ctrl_bit, q)
            
    #         iqft(q, 0, bitPrecision)
    #         for i in range(bitPrecision):
    #             Measure(q[i])
        
    #     q = qalloc(4)
    #     qpe(q)
    #     print(q.counts())



if __name__ == '__main__':
  unittest.main()
 No newline at end of file
+23 −3
Original line number Diff line number Diff line
@@ -47,6 +47,10 @@ using Handle = std::future<ResultsBuffer>;
// Sync up a Handle
ResultsBuffer sync(Handle &handle);

// Indicate we have an error with the given message.
// This should abort execution
void error(const std::string &msg);

template <typename T> std::vector<T> linspace(T a, T b, size_t N) {
  T h = (b - a) / static_cast<T>(N - 1);
  std::vector<T> xs(N);
@@ -63,6 +67,25 @@ inline std::vector<int> range(int N) {
  return vec;
}

inline std::vector<int> range(int start, int stop, int step) {
  if (step == 0) {
    error("step for range must be non-zero.");
  }

  int i = start;
  std::vector<int> vec;
  while ((step > 0) ? (i < stop) : (i > stop)) {
    vec.push_back(i);
    i+=step;
  }
  return vec;
}


inline std::vector<int> range(int start, int stop) {
  return range(start, stop, 1);
}

// Get size() of any types that have size() implemented.
template <typename T> int len(const T &countable) { return countable.size(); }
template <typename T> int len(T &countable) { return countable.size(); }
@@ -272,8 +295,5 @@ bool get_verbose();
// Set the shots for a given quantum kernel execution
void set_shots(const int shots);

// Indicate we have an error with the given message.
// This should abort execution
void error(const std::string &msg);

} // namespace qcor
 No newline at end of file