Unverified Commit d0176843 authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #167 from tnguyen-ornl/tnguyen/pyxasm-while

While loop and op-assign syntax for PyXasm
parents 19ee8d86 ee57082e
Loading
Loading
Loading
Loading
Loading
+9 −2
Original line number Diff line number Diff line
@@ -79,10 +79,12 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
      line += for_stmt;
    }

    // If statement:
    // If statement or while statement:
    // Add a space b/w tokens.
    // Note: Python has an "elif" token, which doesn't have a C++ equiv.
    if (Toks[i].is(clang::tok::TokenKind::kw_if) ||
        PP.getSpelling(Toks[i]) == "elif") {
        PP.getSpelling(Toks[i]) == "elif" ||
        Toks[i].is(clang::tok::TokenKind::kw_while)) {
      line += " ";
      i += 1;
      line += PP.getSpelling(Toks[i]);
@@ -171,6 +173,11 @@ void PyXasmTokenCollector::collect(clang::Preprocessor &PP,
      // Remove the first two characters ("el")
      // hence this line will be parsed as an idependent C++ if block:
      lineText.erase(0, 2);
    } else if (line.first.rfind("while ", 0) == 0) {
      // rewrite to 
      // while (condition) {}
      // Just capture the indent level to close the scope properly
      scope_block_indent.push(line.second);
    }
    // is_in_for_loop = line.first.find("for ") != std::string::npos &&
    // line.second >= previous_col;
+44 −1
Original line number Diff line number Diff line
@@ -457,7 +457,17 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
        return 0;
      }
    } else {
      return visitChildren(ctx);
      // Visit child node:
      auto child_result = visitChildren(ctx);
      const auto translated_src = sub_node_translation.str();
      sub_node_translation.str(std::string());
      // If no child nodes, perform the codegen (result.first is not set)
      // but just appending the incremental translation collector;
      // return the collected C++ statement.
      if (result.first.empty() && !translated_src.empty()) {
        result.first = translated_src + ";\n";
      }
      return child_result;
    }
  }

@@ -501,6 +511,39 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
    return visitChildren(ctx);
  }

  virtual antlrcpp::Any
  visitWhile_stmt(pyxasmParser::While_stmtContext *ctx) override {
    std::stringstream ss;
    ss << "while (" << ctx->test()->getText() << ") {\n";
    result.first = ss.str();
    return 0;
  }

  virtual antlrcpp::Any visitTestlist_star_expr(
      pyxasmParser::Testlist_star_exprContext *context) override {
    // std::cout << "Testlist_star_exprContext:" << context->getText() << "\n";
    const auto var_name = context->getText();
    if (xacc::container::contains(declared_var_names, var_name)) {
      sub_node_translation << var_name << " ";
      return 0;
    }
    return visitChildren(context);
  }

  virtual antlrcpp::Any
  visitAugassign(pyxasmParser::AugassignContext *context) override {
    // std::cout << "Augassign:" << context->getText() << "\n";
    sub_node_translation << context->getText() << " ";
    return 0;
  }

  virtual antlrcpp::Any
  visitTestlist(pyxasmParser::TestlistContext *context) override {
    // std::cout << "visitTestlist:" << context->getText() << "\n";
    sub_node_translation << context->getText() << " ";
    return 0;
  }

 private:
  // Replaces common Python constants, e.g. 'math.pi' or 'numpy.pi'.
  // Note: the library names have been resolved to their original names.
+29 −0
Original line number Diff line number Diff line
@@ -232,5 +232,34 @@ class TestKernelJIT(unittest.TestCase):
        self.assertEqual(q.counts()["0"], r.counts()["0"])
        self.assertEqual(q.counts()["1"], r.counts()["1"])
    
    def test_while_loop(self):
        @qjit 
        def while_loop_kernel(q: qubit, x: int, exp_inv: int):
            rev = 0
            while x:
                rev <<= 1
                rev += x & 1
                x >>= 1

            if rev == exp_inv:
                print("Success")
                X(q)

        # Reference: pure python to check
        def reverse_bit(num):
            result = 0
            while num:
                result = (result << 1) + (num & 1)
                num >>= 1
            return result
        
        q = qalloc(1)
        test_val = 123
        expected = reverse_bit(test_val)
        comp0 = while_loop_kernel.extract_composite(q[0], test_val, expected)    
        print("Comp:", comp0)
        # Has X applied.
        self.assertEqual(comp0.nInstructions(), 1) 

if __name__ == '__main__':
  unittest.main()
 No newline at end of file