Commit 461a5946 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Support control directive in While loop as well



Also, rename the loop control tracking var name to reflect that it supports all loop types.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent f0997107
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -225,7 +225,7 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  /// - The second bool is the continue condition which will bypass all
  /// the remaining ops in the body.
  /// We use a stack to handle nested loops, which are all break-able.
  std::stack<std::pair<mlir::Value, mlir::Value>> for_loop_control_vars;
  std::stack<std::pair<mlir::Value, mlir::Value>> loop_control_directive_bool_vars;
  // This method will add correct number of InstOps
  // based on quantum gate broadcasting
  void createInstOps_HandleBroadcast(std::string name,
+1 −1
Original line number Diff line number Diff line
@@ -245,7 +245,7 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
  if (containsLoopDirectives) {
    // At this point, wrap the following code in an If (check for loop
    // continuation condition.)
    auto [cond1, cond2] = for_loop_control_vars.top();
    auto [cond1, cond2] = loop_control_directive_bool_vars.top();
    // Wrap/Outline the loop body in an IfOp:
    auto continuationIfOp = builder.create<mlir::scf::IfOp>(
        location, mlir::TypeRange(),
+4 −4
Original line number Diff line number Diff line
@@ -33,8 +33,8 @@ antlrcpp::Any qasm3_visitor::visitControlDirective(
    // Set an attribute so that we can detect this after handling this.
    parentIfOp->setAttr("control-directive",
                       mlir::IntegerAttr::get(builder.getIntegerType(1), 1));
    assert(!for_loop_control_vars.empty());
    auto [cond1, cond2] = for_loop_control_vars.top();
    assert(!loop_control_directive_bool_vars.empty());
    auto [cond1, cond2] = loop_control_directive_bool_vars.top();

    // Store false to both the break and continue:
    // i.e., bypass the whole for loop and the rest of the loop body:
@@ -64,8 +64,8 @@ antlrcpp::Any qasm3_visitor::visitControlDirective(
    // Set an attribute so that we can detect this after handling this.
    parentIfOp->setAttr("control-directive",
                        mlir::IntegerAttr::get(builder.getIntegerType(1), 1));
    assert(!for_loop_control_vars.empty());
    auto [cond1, cond2] = for_loop_control_vars.top();
    assert(!loop_control_directive_bool_vars.empty());
    auto [cond1, cond2] = loop_control_directive_bool_vars.top();

    // Just bypass rest of the loop body (after this point)
    // i.e., not disable the whol loop.
+56 −10
Original line number Diff line number Diff line
@@ -247,7 +247,7 @@ void qasm3_visitor::createRangeBasedForLoop(
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                             symbol_table, builder),
        executeThisBlock);
    for_loop_control_vars.push(
    loop_control_directive_bool_vars.push(
        std::make_pair(executeWholeLoop, executeThisBlock));
  }
  // Can use Affine for loop....
@@ -261,7 +261,7 @@ void qasm3_visitor::createRangeBasedForLoop(
        symbol_table.add_symbol(idx_var_name, loop_var_cast, {}, true);

        if (isLoopBreakable) {
          auto [cond1, cond2] = for_loop_control_vars.top();
          auto [cond1, cond2] = loop_control_directive_bool_vars.top();
          // Wrap/Outline the loop body in an IfOp:
          auto scfIfOp = builder.create<mlir::scf::IfOp>(
              location, mlir::TypeRange(),
@@ -278,7 +278,7 @@ void qasm3_visitor::createRangeBasedForLoop(
        symbol_table.exit_scope();

        if (isLoopBreakable) {
          for_loop_control_vars.pop();
          loop_control_directive_bool_vars.pop();
        }
      },
      builder, location);
@@ -361,7 +361,7 @@ void qasm3_visitor::createSetBasedForLoop(
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                             symbol_table, builder),
        executeThisBlock);
    for_loop_control_vars.push(
    loop_control_directive_bool_vars.push(
        std::make_pair(executeWholeLoop, executeThisBlock));
  }

@@ -379,7 +379,7 @@ void qasm3_visitor::createSetBasedForLoop(
        symbol_table.add_symbol(idx_var_name, loop_var, {}, true);

        if (isLoopBreakable) {
          auto [cond1, cond2] = for_loop_control_vars.top();
          auto [cond1, cond2] = loop_control_directive_bool_vars.top();
          // Wrap/Outline the loop body in an IfOp:
          auto scfIfOp = builder.create<mlir::scf::IfOp>(
              location, mlir::TypeRange(),
@@ -395,7 +395,7 @@ void qasm3_visitor::createSetBasedForLoop(
        symbol_table.exit_scope();

        if (isLoopBreakable) {
          for_loop_control_vars.pop();
          loop_control_directive_bool_vars.pop();
        }
      },
      builder, location);
@@ -410,6 +410,36 @@ void qasm3_visitor::createWhileLoop(
  assert(loop_signature->booleanExpression());
  auto main_block = builder.saveInsertionPoint();
  auto cachedBuilder = builder;
  const bool isLoopBreakable =
      hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(*context);

  if (isLoopBreakable) {
    // Add the two loop control bool vars:
    mlir::OpBuilder::InsertionGuard g(builder);
    // Top-level if control (skipping the whole loop if false)
    mlir::Value executeWholeLoop = builder.create<mlir::AllocaOp>(
        location,
        mlir::MemRefType::get(llvm::ArrayRef<int64_t>{}, builder.getI1Type()));
    // Loop body control: skipping portions of the the body if
    // false: e.g., handle 'continue'-like directive.
    mlir::Value executeThisBlock = builder.create<mlir::AllocaOp>(
        location,
        mlir::MemRefType::get(llvm::ArrayRef<int64_t>{}, builder.getI1Type()));
    // store true
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                             symbol_table, builder),
        executeWholeLoop);
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                             symbol_table, builder),
        executeThisBlock);
    loop_control_directive_bool_vars.push(
        std::make_pair(executeWholeLoop, executeThisBlock));
  }

  mlir::scf::WhileOp whileOp = builder.create<mlir::scf::WhileOp>(
      location, mlir::TypeRange() /*resultTypes*/,
      mlir::ValueRange() /*operands*/);
@@ -423,8 +453,18 @@ void qasm3_visitor::createWhileLoop(
  qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
  exp_generator.visit(loop_signature->booleanExpression());
  mlir::Value cond = exp_generator.current_value;
  builder.create<mlir::scf::ConditionOp>(location, cond, before->getArguments());
  builder.setInsertionPointToStart(&whileOp.after().front());

  if (isLoopBreakable) {
    auto [cond1, cond2] = loop_control_directive_bool_vars.top();
    // Do a logical AND (&&) with the while condition.
    mlir::Value extended_cond = builder.create<mlir::AndOp>(
        location, builder.create<mlir::LoadOp>(location, cond1), cond);
    builder.create<mlir::scf::ConditionOp>(location, extended_cond,
                                           before->getArguments());
  } else {
    builder.create<mlir::scf::ConditionOp>(location, cond,
                                           before->getArguments());
  }

  // Build the "after" region:
  // In a "while" loop, this region is the loop body.
@@ -433,11 +473,17 @@ void qasm3_visitor::createWhileLoop(
    symbol_table.enter_new_scope();
    visitChildren(program_block);
    symbol_table.exit_scope();
    // 'After' block must end with a yield op.
    builder.create<mlir::scf::YieldOp>(location);
  }

  if (isLoopBreakable) {
    loop_control_directive_bool_vars.pop();
  }

  builder = cachedBuilder;
  // 'After' block must end with a yield op.
  mlir::Operation &lastOp = whileOp.after().front().getOperations().back();
  builder.setInsertionPointAfter(&lastOp);
  builder.create<mlir::scf::YieldOp>(location);
  builder.restoreInsertionPoint(main_block);
}
} // namespace qcor
 No newline at end of file