Unverified Commit 75d9e8ad authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #208 from tnguyen-ornl/tnguyen/update-qasm3

Support MLIR codegen with scf.if
parents 207fd011 4de3a61b
Loading
Loading
Loading
Loading
Loading
+31 −0
Original line number Diff line number Diff line
OPENQASM 3;
include "stdgates.inc";
const n_iters = 100;
int count1 = 0;
qubit q, a;
 
for i in [0:n_iters] {
    // Generate |+> eigenstate
    x q;

    // apply hadamard on ancilla
    h a;

    // Ctrl-U, U == H
    cx a, q;

    // apply hadamard again
    h a;

    // measure and reset
    bit c;
    c = measure a;
    reset a;

    // Store up the observed bits
    if (c == 1) {
      count1 = count1 + 1;
    } 
}

print("one count = ", count1);
 No newline at end of file
+11 −0
Original line number Diff line number Diff line
@@ -267,6 +267,17 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {

    return allocation;
  }

  template <class NodeType>
  bool hasChildNodeOfType(antlr4::tree::ParseTree &in_node) {
    for (auto &child_node : in_node.children) {
      if (dynamic_cast<NodeType *>(child_node) ||
          hasChildNodeOfType<NodeType>(*child_node)) {
        return true;
      }
    }
    return false;
  }
};

}  // namespace qcor
+137 −1
Original line number Diff line number Diff line
#include "gtest/gtest.h"
#include "qcor_mlir_api.hpp"
#include "gtest/gtest.h"

namespace {
// returns count of non-overlapping occurrences of 'sub' in 'str'
int countSubstring(const std::string &str, const std::string &sub) {
  if (sub.length() == 0)
    return 0;
  int count = 0;
  for (size_t offset = str.find(sub); offset != std::string::npos;
       offset = str.find(sub, offset + sub.length())) {
    ++count;
  }
  return count;
}
} // namespace

// Check Affine-SCF constructs
TEST(qasm3VisitorTester, checkCFG_AffineScf) {
  const std::string qasm_code = R"#(OPENQASM 3;
include "qelib1.inc";

int[64] iterate_value = 0;
int[64] value_5 = 0;
int[64] value_2 = 0;
for i in [0:10] {
    iterate_value = i;
    if (i == 5) {
        print("Iterate over 5");
        value_5 = 5;
    }
    if (i == 2) {
        print("Iterate over 2");
        value_2 = 2;
       
    }
    print("i = ", i);
}

QCOR_EXPECT_TRUE(iterate_value == 9);
QCOR_EXPECT_TRUE(value_5 == 5);
QCOR_EXPECT_TRUE(value_2 == 2);
print("made it out of the loop");
)#";
  auto mlir = qcor::mlir_compile(qasm_code, "affine_scf",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  // 1 for loop, 2 if blocks
  EXPECT_EQ(countSubstring(mlir, "affine.for"), 1);
  EXPECT_EQ(countSubstring(mlir, "scf.if"), 2);
  EXPECT_FALSE(qcor::execute(qasm_code, "affine_scf"));
}

TEST(qasm3VisitorTester, checkCtrlDirectives) {
  const std::string uint_index = R"#(OPENQASM 3;
@@ -33,6 +83,92 @@ print("made it out of the loop");
  EXPECT_FALSE(qcor::execute(uint_index, "uint_index"));
}

TEST(qasm3VisitorTester, checkIqpewithIf) {
  const std::string qasm_code = R"#(OPENQASM 3;
include "qelib1.inc";

// Expected to get 4 bits (iteratively) of 1011 (or 1101 LSB) = 11(decimal):
// phi_est = 11/16 (denom = 16 since we have 4 bits)
// => phi = 2pi * 11/16 = 11pi/8 = 2pi - 5pi/8
// i.e. we estimate the -5*pi/8 angle...
qubit q[2];
const bits_precision = 4;
bit c[bits_precision];

// Prepare the eigen-state: |1>
x q[1];

// First bit
h q[0];
// Controlled rotation: CU^k
for i in [0:8] {
  cphase(-5*pi/8) q[0], q[1];
}
h q[0];
// Measure and reset
measure q[0] -> c[0];
reset q[0];

// Second bit
h q[0];
for i in [0:4] {
  cphase(-5*pi/8) q[0], q[1];
}
// Conditional rotation
if (c[0] == 1) {
  rz(-pi/2) q[0];
}
h q[0];
// Measure and reset
measure q[0] -> c[1];
reset q[0];

// Third bit
h q[0];
for i in [0:2] {
  cphase(-5*pi/8) q[0], q[1];
}
// Conditional rotation
if (c[0] == 1) {
  rz(-pi/4) q[0];
}
if (c[1] == 1) {
  rz(-pi/2) q[0];
}
h q[0];
// Measure and reset
measure q[0] -> c[2];
reset q[0];

// Fourth bit
h q[0];
cphase(-5*pi/8) q[0], q[1];
// Conditional rotation
if (c[0] == 1) {
  rz(-pi/8) q[0];
}
if (c[1] == 1) {
  rz(-pi/4) q[0];
}
if (c[2] == 1) {
  rz(-pi/2) q[0];
}
h q[0];
measure q[0] -> c[3];

print(c[0], c[1], c[2], c[3]);
QCOR_EXPECT_TRUE(c[0] == 1);
QCOR_EXPECT_TRUE(c[1] == 1);
QCOR_EXPECT_TRUE(c[2] == 0);
QCOR_EXPECT_TRUE(c[3] == 1);
)#";
  // Make sure we can compile this in FTQC.
  // i.e., usual if ...
  auto mlir = qcor::mlir_compile(qasm_code, "iqpe",
                                 qcor::OutputType::LLVMIR, false);
  std::cout << mlir << "\n";
}


int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
+8 −2
Original line number Diff line number Diff line
@@ -232,8 +232,10 @@ antlrcpp::Any qasm3_expression_generator::visitComparsionExpression(
      auto lhs_bw = lhs.getType().getIntOrFloatBitWidth();
      auto rhs_bw = rhs.getType().getIntOrFloatBitWidth();

      if (lhs.getType().isa<mlir::IntegerType>()) {
        if (!rhs.getType().isa<mlir::IntegerType>()) {
      if (lhs.getType().isa<mlir::IntegerType>() ||
          lhs.getType().isa<mlir::IndexType>()) {
        if (!rhs.getType().isa<mlir::IntegerType>() &&
            !rhs.getType().isa<mlir::IndexType>()) {
          printErrorMessage("for comparison " + op +
                                " lhs was an integer type, but rhs was not.",
                            compare, {lhs, rhs});
@@ -259,6 +261,10 @@ antlrcpp::Any qasm3_expression_generator::visitComparsionExpression(
        }
        update_current_value(builder.create<mlir::CmpFOp>(
            location, antlr_to_mlir_fpredicate[op], lhs, rhs));
      } else {
        // Sth wrong, we cannot handle this atm.
        printErrorMessage("Unhandled comparison " + op + ": ", compare,
                          {lhs, rhs});
      }
      return 0;
    } else {
+108 −62
Original line number Diff line number Diff line

#include "expression_handler.hpp"
#include "mlir/Dialect/SCF/SCF.h"
#include "qasm3_visitor.hpp"

namespace {
// ATM, we don't try to convert everything to the
// special Quantum If-Then-Else Op.
@@ -10,7 +10,8 @@ namespace {
// i.e., this serve mainly as a stop-gap before fully-FTQC runtimes become
// available.

// FIXME: Define a Target Capability setting and make the compiler aware of that.
// FIXME: Define a Target Capability setting and make the compiler aware of
// that.

// Capture binary comparison conditional.
// Note: currently, only bit-level is modeled.
@@ -201,8 +202,15 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
    }
  }

  // TODO: The below code could be rewritten to an AffineIfOp/SCF::IfOp:
  // Map it to a Value
  // Manually write the conditional block:
  // If there is 'break', 'continue' (ControlDirective) in the body.
  // The reason being: these break/continue will be translated to BranchOp
  // which are overlapping with the BranchOp implicitly added at the end of SCF::IfOp.
  // e.g., 
  // br ^bb1 (e.g., out of the outer loop) <-- added by ControlDirectiveContext handler
  // br ^bb2 (e.g., to the end of the if statement) <-- added by the implicit yield op
  // The verify step (MLIR -> LLVM) will complain this....
  if (hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(*context)) {
    qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
    exp_generator.visit(conditional_expr);
    auto expr_value = exp_generator.current_value;
@@ -276,7 +284,45 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
    builder.setInsertionPointToStart(exitBlock);

    symbol_table.set_last_created_block(exitBlock);
  } else {
    // Using SCF::IfOp
    // Map it to a Value
    qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
    exp_generator.visit(conditional_expr);
    // Boolean check value:
    auto expr_value = exp_generator.current_value;
    // Must be an i1 (bool)
    assert(expr_value.getType().isa<mlir::IntegerType>() &&
           expr_value.getType().getIntOrFloatBitWidth() == 1);
    // Create SCF If Op:
    // SCF IfOp (switching on a boolean value) matches what we need here,
    // an AffineIfOp requires an integer set and will be lowered to SCF's IfOp
    // later, hence is not a good solution.
    const bool hasElseBlock = context->programBlock().size() == 2;
    auto scfIfOp = builder.create<mlir::scf::IfOp>(location, mlir::TypeRange(),
                                                   expr_value, hasElseBlock);

    // Build up the 'then' region:
    auto thenBodyBuilder = scfIfOp.getThenBodyBuilder();
    auto cached_builder = builder;
    builder = thenBodyBuilder;
    symbol_table.enter_new_scope();
    // Get the conditional code and visit the nodes
    auto conditional_code = context->programBlock(0);
    visitChildren(conditional_code);
    symbol_table.exit_scope();

    if (hasElseBlock) {
      auto elseBodyBuilder = scfIfOp.getElseBodyBuilder();
      builder = elseBodyBuilder;
      symbol_table.enter_new_scope();
      // Visit the second programBlock
      visitChildren(context->programBlock(1));
      symbol_table.exit_scope();
    }
    // Restore builder
    builder = cached_builder;
  }
  return 0;
}
} // namespace qcor
 No newline at end of file
Loading