Commit a6938257 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Refactor arg rewrite into a helper



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 888041d5
Loading
Loading
Loading
Loading
+32 −24
Original line number Diff line number Diff line
@@ -135,7 +135,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3);

      for (int i = 0; i < subscriptTerms.size(); ++i) {
        sub_node_translation << subscriptTerms[i];
        // Need to cast to prevent compiler errors,
        // e.g. when using q.size() which returns an int.
        sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i] << ")";
        if (i != subscriptTerms.size() - 1) {
          sub_node_translation << ", ";
        }
@@ -167,28 +169,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      ss << context->atom()->getText() << "::" << methodName
         << "(parent_kernel";
      for (int i = 0; i < arg_list->argument().size(); i++) {
        // Strategy:
        // Traverse down the tree to see if the there is a potential translation:
        // i.e. it will populate sub_node_translation stream.
        // Otherwise, output the argument *as-is*

        // clear the sub_node_translation  
        sub_node_translation.str(std::string());

        // visit arg sub-node:
        visitChildren(arg_list->argument(i));

        // Check if there is a rewrite:
        if (!sub_node_translation.str().empty()) {
          const auto arg_new_str = sub_node_translation.str();
          std::cout << arg_list->argument(i)->getText() << " --> " << arg_new_str << "\n";
          sub_node_translation.str(std::string());
          ss << ", " << arg_new_str;
        }
        else {
          // Use the arg as is:
          ss << ", " << arg_list->argument(i)->getText();
        }
        ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i)));
      }
      ss << ");\n";

@@ -231,7 +212,7 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
            std::vector<std::string> buffer_names;
            for (int i = 0; i < required_bits; i++) {
              auto bit_expr = context->trailer()[0]->arglist()->argument()[i];
              auto bit_expr_str = bit_expr->getText();
              auto bit_expr_str = rewriteFunctionArgument(*bit_expr);

              auto found_bracket = bit_expr_str.find_first_of("[");
              if (found_bracket != std::string::npos) {
@@ -556,4 +537,31 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
      return in_expr;
    }
  }

  // A helper to rewrite function argument by traversing the node to see
  // if there is a potential rewrite.
  // Use case: inline expressions
  // e.g. X(q[0:3])
  // slicing of the qreg 'q' then call the broadcast X op.
  // i.e., we need to rewrite the arg to q.extract_range({0, 3}).
  std::string
  rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) {
    // Strategy: try to traverse the argument context to see if there is a
    // possible rewrite; i.e. it may be another atom_expression that we have a
    // handler for. Otherwise, use the text as is.

    // clear the sub_node_translation
    sub_node_translation.str(std::string());

    // visit arg sub-node:
    visitChildren(&in_argContext);

    // Check if there is a rewrite:
    if (!sub_node_translation.str().empty()) {
      // Update RHS
      return sub_node_translation.str();
    }
    // Returns the string as is
    return in_argContext.getText();
  }
};
 No newline at end of file
+40 −0
Original line number Diff line number Diff line
@@ -170,6 +170,46 @@ auto slice2 = q.extract_range({0, 5, 2});
  EXPECT_EQ(expectedCodeGen, ss.str());
}

TEST(PyXASMTokenCollectorTester, checkBroadCastWithSlice) {
  LexerHelper helper;
  auto [tokens, PP] = helper.Lex(R"(
    X(q.head(q.size()-1))
    X(q[0])
    X(q)
    X(q[0:2])
    X(q[0:5:2])
    Measure(q.head(q.size()-1))
    Measure(q[0])
    Measure(q)
    Measure(q[0:2])
    Measure(q[0:5:2])
)");

  clang::CachedTokens cached;
  for (auto &t : tokens) {
    cached.push_back(t);
  }

  std::stringstream ss;
  auto xasm_tc = xacc::getService<qcor::TokenCollector>("pyxasm");
  xasm_tc->collect(*PP.get(), cached, {"q"}, ss);
  std::cout << "heres the test\n";
  std::cout << ss.str() << "\n";
  const std::string expectedCodeGen =
      R"#(quantum::x(q.head(q.size()-1));
quantum::x(q[0]);
quantum::x(q);
quantum::x(q.extract_range({0, 2}));
quantum::x(q.extract_range({0, 5, 2}));
quantum::mz(q.head(q.size()-1));
quantum::mz(q[0]);
quantum::mz(q);
quantum::mz(q.extract_range({0, 2}));
quantum::mz(q.extract_range({0, 5, 2}));
)#";
  EXPECT_EQ(expectedCodeGen, ss.str());
}

int main(int argc, char **argv) {
  std::string xacc_config_install_dir = std::string(XACC_INSTALL_DIR);
  std::string qcor_root = std::string(QCOR_INSTALL_DIR);
+86 −0
Original line number Diff line number Diff line
@@ -33,5 +33,91 @@ class TestKernelJIT(unittest.TestCase):
        # q0: 1 --> 0
        self.assertTrue('0111' in counts)

    def test_qreg_head_tail(self):
        set_qpu('qpp', {'shots':1024})

        @qjit
        def test_cccx_qreg(q : qreg):
            # Broadcast
            X(q)
            # 3 control bits
            ctrl_qubits = q.tail(q.size() - 1)
            first_qubit = q.head()
            X.ctrl(ctrl_qubits, first_qubit)
            # # Broadcast
            Measure(q)
        
        q = qalloc(4)
        comp = test_cccx_qreg.extract_composite(q)
        print(comp)

        # Run experiment
        test_cccx_qreg(q)

        # Print the results
        q.print()
        counts = q.counts()
        print(counts)
        self.assertEqual(len(counts), 1)
        # q0: 1 --> 0
        self.assertTrue('0111' in counts)
    
    def test_qreg_slicing(self):
        set_qpu('qpp', {'shots':1024})

        @qjit
        def test_cccx_qreg_slice(q : qreg):
            # Broadcast
            X(q)
            # 3 control bits:
            # q[0], q[1], q[2]
            ctrl_qubits = q[0:3]
            last_qubit = q.tail()
            X.ctrl(ctrl_qubits, last_qubit)
            # Broadcast
            Measure(q)
        
        q = qalloc(4)
        comp = test_cccx_qreg_slice.extract_composite(q)
        print(comp)

        # Run experiment
        test_cccx_qreg_slice(q)

        # Print the results
        q.print()
        counts = q.counts()
        print(counts)
        self.assertEqual(len(counts), 1)
        # q3: 1 --> 0
        self.assertTrue('1110' in counts)
    
    def test_qreg_slicing_inline(self):
        set_qpu('qpp', {'shots':1024})

        @qjit
        def test_cccx_qreg_slice_inline(q : qreg):
            # Broadcast via a slice
            X(q)
            # Control with slicing inline
            X.ctrl(q[0:3], q.tail())
            # Broadcast
            Measure(q)
        
        q = qalloc(4)
        comp = test_cccx_qreg_slice_inline.extract_composite(q)
        print(comp)

        # Run experiment
        test_cccx_qreg_slice_inline(q)

        # Print the results
        q.print()
        counts = q.counts()
        print(counts)
        self.assertEqual(len(counts), 1)
        # q3: 1 --> 0
        self.assertTrue('1110' in counts)

if __name__ == '__main__':
  unittest.main()
 No newline at end of file