Commit dc467701 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

[WIP] get the Affine loop break working



- Simplify the structure: the loop handler detect if it contains break/continue to set up the structure (conditional body). i.e., no need to move code after the body is constructed.

- Properly handle loop continuation directive, i.e., a true break point and ready for continue implementation.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 81764797
Loading
Loading
Loading
Loading
+11 −2
Original line number Diff line number Diff line
@@ -216,7 +216,16 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  mlir::Type array_type;
  mlir::Type result_type;

  std::stack<mlir::Value> loop_break_vars;
  // Loop control vars for break/continue implementation with Region-based
  // Affine/SCF Ops.
  // Strategy:
  /// - A break-able for loop will have a bool (first in the pair) to control
  /// the loop body execution. i.e., bypass the whole loop if the break
  /// condition is triggered.
  /// - The second bool is the continue condition which will bypass all
  /// the remaining ops in the body.
  /// We use a stack to handle nested loops, which are all break-able.
  std::stack<std::pair<mlir::Value, mlir::Value>> for_loop_control_vars;
  // This method will add correct number of InstOps
  // based on quantum gate broadcasting
  void createInstOps_HandleBroadcast(std::string name,
+16 −1
Original line number Diff line number Diff line
@@ -240,6 +240,21 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
  // Restore builder
  builder = cached_builder;

  // Check if this if/else contains loop control directives:
  const bool containsLoopDirectives = scfIfOp->hasAttr("control-directive");
  if (containsLoopDirectives) {
    std::cout << "If op triggers control directive: \n";
    scfIfOp.dump();
    // At this point, wrap the following code in an If (check for loop
    // continuation condition.)
    auto [cond1, cond2] = for_loop_control_vars.top();
    // Wrap/Outline the loop body in an IfOp:
    auto continuationIfOp = builder.create<mlir::scf::IfOp>(
        location, mlir::TypeRange(),
        builder.create<mlir::LoadOp>(location, cond2), false);
    auto continuationThenBodyBuilder = continuationIfOp.getThenBodyBuilder();
    builder = continuationThenBodyBuilder;
  }
  return 0;
}
} // namespace qcor
 No newline at end of file
+14 −23
Original line number Diff line number Diff line
@@ -32,33 +32,24 @@ antlrcpp::Any qasm3_visitor::visitControlDirective(
      printErrorMessage("Illegal break directive: unconditional break.");
    }

    // Strategy: predicating every statement potentially executed after at least
    // one break on the absence of break.
    // Set an attribute so that we can detect this after handling this.
    parentIfOp->setAttr("control-directive",
                       mlir::IntegerAttr::get(builder.getIntegerType(1), 1));
    assert(!for_loop_control_vars.empty());
    auto [cond1, cond2] = for_loop_control_vars.top();

    // Create a 'mustBreak' bool at the outer scope:
    mlir::Value mustBreak;
    {
      mlir::OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPointToStart(
          &(m_module.getRegion().getBlocks().front()));
      mustBreak = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(llvm::ArrayRef<int64_t>{},
                                          builder.getI1Type()));
      // store false (at the outer scope)
    // Store false to both the break and continue:
    // i.e., bypass the whole for loop and the rest of the loop body:
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(0, location, builder.getI1Type(),
                                             symbol_table, builder),
          mustBreak);
    }

    // Store true here:
        cond1);
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
        get_or_create_constant_integer_value(0, location, builder.getI1Type(),
                                             symbol_table, builder),
        mustBreak);
    loop_break_vars.push(mustBreak);
        cond2);
  } else if (stmt == "continue") {
    // TODO: Handle this case.
    if (current_loop_incrementor_block) {
+50 −38
Original line number Diff line number Diff line
@@ -313,6 +313,35 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
      // HACK: Currently, we don't handle 'break', 'continue' or nested loop
      // in the Affine for loop yet.
      if (program_block_str.find("QCOR_EXPECT_TRUE") == std::string::npos) {
        const bool isLoopBreakable =
            hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(*context);
        auto cachedBuilder = builder;
        if (isLoopBreakable) {
          // Add the two loop control bool vars:
          mlir::OpBuilder::InsertionGuard g(builder);
          // Top-level if control (skipping the whole loop if false)
          mlir::Value executeWholeLoop = builder.create<mlir::AllocaOp>(
              location, mlir::MemRefType::get(llvm::ArrayRef<int64_t>{},
                                              builder.getI1Type()));
          // Loop body control: skipping portions of the the body if
          // false: e.g., handle 'continue'-like directive.
          mlir::Value executeThisBlock = builder.create<mlir::AllocaOp>(
              location, mlir::MemRefType::get(llvm::ArrayRef<int64_t>{},
                                              builder.getI1Type()));
          // store true
          builder.create<mlir::StoreOp>(
              location,
              get_or_create_constant_integer_value(
                  1, location, builder.getI1Type(), symbol_table, builder),
              executeWholeLoop);
          builder.create<mlir::StoreOp>(
              location,
              get_or_create_constant_integer_value(
                  1, location, builder.getI1Type(), symbol_table, builder),
              executeThisBlock);
          for_loop_control_vars.push(
              std::make_pair(executeWholeLoop, executeThisBlock));
        }
        // Can use Affine for loop....
        auto forLoop = affineLoopBuilder(
            a_value, b_value, c,
@@ -322,47 +351,30 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
              auto loop_var_cast = builder.create<mlir::IndexCastOp>(
                  location, builder.getI64Type(), loop_var);
              symbol_table.add_symbol(idx_var_name, loop_var_cast, {}, true);
              visitChildren(program_block);
              symbol_table.exit_scope();
            },
            builder, location);
        if (!loop_break_vars.empty()) {
          mlir::OpBuilder::InsertionGuard g(builder);
          builder.setInsertionPointToStart(
              &(forLoop.getLoopBody().getBlocks().front()));
          mlir::Value mustBreak =
              builder.create<mlir::LoadOp>(location, loop_break_vars.top());
          assert(mustBreak.getType().isa<mlir::IntegerType>() &&
                 mustBreak.getType().getIntOrFloatBitWidth() == 1);
          loop_break_vars.pop();

              if (isLoopBreakable) {
                auto [cond1, cond2] = for_loop_control_vars.top();
                // Wrap/Outline the loop body in an IfOp:
                auto scfIfOp = builder.create<mlir::scf::IfOp>(
              location, mlir::TypeRange(), mustBreak, false);
          size_t count = 0;
          std::vector<mlir::Operation *> ops_to_clone;
          for (auto op_iter = forLoop.getLoopBody().op_begin();
               op_iter != forLoop.getLoopBody().op_end(); ++op_iter) {
            mlir::Operation &op = *op_iter;
            count++;
            // The first 2 are the Load and If that we just inserted
            if (count <= 2) {
              continue;
            }
            ops_to_clone.emplace_back(&op);
          }
          // Last one is affine yield
          // should be left outside:
          assert(
              mlir::dyn_cast_or_null<mlir::AffineYieldOp>(ops_to_clone.back()));
          ops_to_clone.pop_back();
          for (auto &op : ops_to_clone) {
            scfIfOp.getThenBodyBuilder().clone(*op);
                    location, mlir::TypeRange(),
                    builder.create<mlir::LoadOp>(location, cond1), false);
                auto thenBodyBuilder = scfIfOp.getThenBodyBuilder();
                auto cached_builder = builder;
                builder = thenBodyBuilder;
                visitChildren(program_block);
                builder = cached_builder;
              } else {
                visitChildren(program_block);
              }

          for (auto &op : ops_to_clone) {
            op->remove();
          }
              symbol_table.exit_scope();

              if (isLoopBreakable) {
                for_loop_control_vars.pop();
              }
            },
            builder, location);
        builder = cachedBuilder;
      } else {
        // TODO: Remove this code path once we convert control flow to Affine/SCF
        // Need to use the legacy for loop construction for now...