Commit 9ad2d539 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

qrt broadcast API to take qreg by value (copies)



This allows passing a qreg slice directly to broadcast instructions.

Refactor the PyXASM token collector impl.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent a6938257
Loading
Loading
Loading
Loading
+97 −100
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
  // Var to keep track of sub-node rewrite:
  // e.g., traverse down the AST recursively.
  std::stringstream sub_node_translation;
  bool is_processing_sub_expr = false;

  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
@@ -53,7 +54,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
       '{' (dictorsetmaker)? '}' |
       NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False');
    */

    // Only processes these for sub-expressesions, 
    // e.g. re-entries to this function
    if (is_processing_sub_expr) {
      if (context->atom() && context->atom()->testlist_comp()) {
        // Array type expression:
        std::cout << "Array atom expression: "
@@ -92,7 +95,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
            cppStrLiteral.back() = '"';
          }
          sub_node_translation << cppStrLiteral;
        std::cout << "String expression: " << strNode->getText() << " --> " << cppStrLiteral << "\n";
          std::cout << "String expression: " << strNode->getText() << " --> "
                    << cppStrLiteral << "\n";
        }
        return 0;
      }
@@ -112,13 +116,14 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
        return false;
      };

    // Handle slicing operations (multiple array subscriptions separated by ':')
    // on a qreg.
      // Handle slicing operations (multiple array subscriptions separated by
      // ':') on a qreg.
      if (context->atom() &&
          xacc::container::contains(bufferNames, context->atom()->getText()) &&
          isSliceOp(context)) {
        std::cout << "Slice op: " << context->getText() << "\n";
      sub_node_translation << context->atom()->getText() << ".extract_range({";
        sub_node_translation << context->atom()->getText()
                             << ".extract_range({";
        auto subscripts =
            context->trailer(0)->subscriptlist()->subscript(0)->test();
        assert(subscripts.size() > 1);
@@ -137,7 +142,8 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
        for (int i = 0; i < subscriptTerms.size(); ++i) {
          // Need to cast to prevent compiler errors,
          // e.g. when using q.size() which returns an int.
        sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")";
          sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i]
                               << ")";
          if (i != subscriptTerms.size() - 1) {
            sub_node_translation << ", ";
          }
@@ -151,6 +157,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
        return 0;
      }

      return 0;
    }

    // Handle kernel::ctrl(...), kernel::adjoint(...)
    if (!context->trailer().empty() &&
        (context->trailer()[0]->getText() == ".ctrl" ||
@@ -334,22 +343,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
                  context->trailer()[0]->arglist()->argument();
              ss << inst_name << "(";
              for (size_t i = 0; i < argList.size(); ++i) {                
                // Find rewrite for arguments
                sub_node_translation.str(std::string());
                // visit arg sub-node:
                visitChildren(argList[i]);
                // Check if there is a rewrite:
                if (!sub_node_translation.str().empty()) {
                  const auto arg_new_str = sub_node_translation.str();
                  std::cout << argList[i]->getText() << " --> " << arg_new_str << "\n";
                  sub_node_translation.str(std::string());
                  ss << arg_new_str;
                }
                else {
                  // Use the arg as is:
                  ss << argList[i]->getText();
                }
                                
                ss << rewriteFunctionArgument(*(argList[i]));                
                if (i != argList.size() - 1) {
                  ss << ", ";
                }
@@ -417,7 +411,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      } else {
        // Strategy: try to traverse the rhs to see if there is a possible rewrite;
        // Otherwise, use the text as is.
        
        is_processing_sub_expr = true;
        // clear the sub_node_translation  
        sub_node_translation.str(std::string());

@@ -549,7 +543,10 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
    // Strategy: try to traverse the argument context to see if there is a
    // possible rewrite; i.e. it may be another atom_expression that we have a
    // handler for. Otherwise, use the text as is.

    // We need this flag to prevent parsing quantum instructions as sub-expressions.
    // e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be 
    // processed as instructions.
    is_processing_sub_expr = true;
    // clear the sub_node_translation
    sub_node_translation.str(std::string());

+30 −6
Original line number Diff line number Diff line
@@ -164,8 +164,8 @@ auto last_qubit = q.tail();
Z::ctrl(parent_kernel, ctrl_qubits, last_qubit);
X::ctrl(parent_kernel, q.head(q.size()-1), q.tail());
auto r = q.extract_range(0,bitPrecision); 
auto slice1 = q.extract_range({0, 3}); 
auto slice2 = q.extract_range({0, 5, 2}); 
auto slice1 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(3)}); 
auto slice2 = q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}); 
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}
@@ -199,13 +199,37 @@ TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) {
      R"#(quantum::x(q.head(q.size()-1));
quantum::x(q[0]);
quantum::x(q);
quantum::x(q.extract_range({0, 2}));
quantum::x(q.extract_range({0, 5, 2}));
quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)}));
quantum::x(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}));
quantum::mz(q.head(q.size()-1));
quantum::mz(q[0]);
quantum::mz(q);
quantum::mz(q.extract_range({0, 2}));
quantum::mz(q.extract_range({0, 5, 2}));
quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(2)}));
quantum::mz(q.extract_range({static_cast<size_t>(0), static_cast<size_t>(5), static_cast<size_t>(2)}));
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkQcorOperators) {
  LexerHelper helper;
  auto [tokens, PP] = helper.Lex(R"(
    exponent_op = X(0) * Y(1) - Y(0) * X(1)
    exp_i_theta(q, theta, exponent_op)
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(auto exponent_op = X(0)*Y(1)-Y(0)*X(1); 
quantum::exp(q, theta, exponent_op); 
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}
+33 −0
Original line number Diff line number Diff line
@@ -179,5 +179,38 @@ class TestKernelJIT(unittest.TestCase):
        self.assertEqual(comp2.getInstruction(1).name(), "Z") 
        self.assertEqual(comp3.getInstruction(1).name(), "T") 

    def test_instBroadCast(self):
        set_qpu('qpp', {'shots':1024})
        
        @qjit
        def broadCastTest(q : qreg):
            # Simple broadcast
            X(q)
            # broadcast by slice
            Z(q[0:q.size()])
            # Even qubits
            Y(q[0:q.size():2])

        q = qalloc(6)
        comp = broadCastTest.extract_composite(q)
        counter = 0
        for i in range(q.size()):
            self.assertEqual(comp.getInstruction(counter).name(), "X") 
            self.assertEqual(comp.getInstruction(counter).bits()[0], i) 
            counter += 1
        
        for i in range(q.size()):
            self.assertEqual(comp.getInstruction(counter).name(), "Z") 
            self.assertEqual(comp.getInstruction(counter).bits()[0], i) 
            counter += 1
        
        for i in range(0, q.size(), 2):
            self.assertEqual(comp.getInstruction(counter).name(), "Y") 
            self.assertEqual(comp.getInstruction(counter).bits()[0], i) 
            counter += 1    
        
        self.assertEqual(comp.nInstructions(), counter)
        

if __name__ == '__main__':
  unittest.main()
 No newline at end of file
+15 −15
Original line number Diff line number Diff line
@@ -225,80 +225,80 @@ void persistBitstring(xacc::AcceleratorBuffer *buffer) {
  }
}

void h(qreg &q) {
void h(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    h(q[i]);
  }
}

void x(qreg &q) {
void x(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    x(q[i]);
  }
}
void y(qreg &q) {
void y(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    y(q[i]);
  }
}
void z(qreg &q) {
void z(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    z(q[i]);
  }
}
void t(qreg &q) {
void t(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    t(q[i]);
  }
}
void tdg(qreg &q) {
void tdg(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    tdg(q[i]);
  }
}
void s(qreg &q) {
void s(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    s(q[i]);
  }
}
void sdg(qreg &q) {
void sdg(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    sdg(q[i]);
  }
}
void mz(qreg &q) {
void mz(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    mz(q[i]);
  }
}

void rx(qreg &q, const double theta) {
void rx(qreg q, const double theta) {
  for (int i = 0; i < q.size(); i++) {
    rx(q[i], theta);
  }
}
void ry(qreg &q, const double theta) {
void ry(qreg q, const double theta) {
  for (int i = 0; i < q.size(); i++) {
    ry(q[i], theta);
  }
}
void rz(qreg &q, const double theta) {
void rz(qreg q, const double theta) {
  for (int i = 0; i < q.size(); i++) {
    rz(q[i], theta);
  }
}
// U1(theta) gate
void u1(qreg &q, const double theta) {
void u1(qreg q, const double theta) {
  for (int i = 0; i < q.size(); i++) {
    u1(q[i], theta);
  }
}
void u3(qreg &q, const double theta, const double phi, const double lambda) {
void u3(qreg q, const double theta, const double phi, const double lambda) {
  for (int i = 0; i < q.size(); i++) {
    u3(q[i], theta, phi, lambda);
  }
}
void reset(qreg &q) {
void reset(qreg q) {
  for (int i = 0; i < q.size(); i++) {
    reset(q[i]);
  }
+15 −15
Original line number Diff line number Diff line
@@ -125,15 +125,15 @@ void sdg(const qubit &qidx);
void reset(const qubit &qidx);

// broadcast across qreg
void h(qreg &q);
void x(qreg &q);
void y(qreg &q);
void z(qreg &q);
void t(qreg &q);
void tdg(qreg &q);
void s(qreg &q);
void sdg(qreg &q);
void reset(qreg &qidx);
void h(qreg q);
void x(qreg q);
void y(qreg q);
void z(qreg q);
void t(qreg q);
void tdg(qreg q);
void s(qreg q);
void sdg(qreg q);
void reset(qreg qidx);

// Common single-qubit, parameterized instructions
void rx(const qubit &qidx, const double theta);
@@ -145,17 +145,17 @@ void u3(const qubit &qidx, const double theta, const double phi,
        const double lambda);

// broadcast rotations across qubits
void rx(qreg &qidx, const double theta);
void ry(qreg &qidx, const double theta);
void rz(qreg &qidx, const double theta);
void rx(qreg qidx, const double theta);
void ry(qreg qidx, const double theta);
void rz(qreg qidx, const double theta);
// U1(theta) gate
void u1(qreg &qidx, const double theta);
void u3(qreg &qidx, const double theta, const double phi,
void u1(qreg qidx, const double theta);
void u3(qreg qidx, const double theta, const double phi,
        const double lambda);

// Measure-Z and broadcast mz
bool mz(const qubit &qidx);
void mz(qreg &q);
void mz(qreg q);

// Common two-qubit gates.
void cnot(const qubit &src_idx, const qubit &tgt_idx);