Commit 81764797 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

[WIP] Work on supporting break-like statement in for loop



- Moving control directive handling to a separate cpp file for maintainability.

- Use outlining technique to wrap breakable loop body in a scf::if block

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 7735b626
Loading
Loading
Loading
Loading
+1 −11
Original line number Diff line number Diff line
@@ -6,17 +6,7 @@ if (APPLE)
  set(ANTLR_LIB ${XACC_ROOT}/lib/libantlr4-runtime.dylib)
endif()

file(GLOB SRC *.cpp antlr/generated/*.cpp utils/*.cpp 
   visitor_handlers/quantum_types_handler.cpp
   visitor_handlers/quantum_instruction_handler.cpp
   visitor_handlers/classical_types_handler.cpp
   visitor_handlers/measurement_handler.cpp
   visitor_handlers/loop_stmt_handler.cpp
   visitor_handlers/conditional_handler.cpp
   visitor_handlers/subroutine_handler.cpp
   visitor_handlers/alias_handler.cpp
   visitor_handlers/compute_action_handler.cpp
   )
file(GLOB SRC *.cpp antlr/generated/*.cpp utils/*.cpp visitor_handlers/*.cpp)

add_library(${LIBRARY_NAME} SHARED ${SRC})
target_compile_features(${LIBRARY_NAME} 
+2 −1
Original line number Diff line number Diff line
@@ -216,6 +216,7 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  mlir::Type array_type;
  mlir::Type result_type;
  
  std::stack<mlir::Value> loop_break_vars;
  // This method will add correct number of InstOps
  // based on quantum gate broadcasting
  void createInstOps_HandleBroadcast(std::string name,
+36 −119
Original line number Diff line number Diff line
@@ -201,90 +201,6 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
      return 0;
    }
  }

  // Manually write the conditional block:
  // If there is 'break', 'continue' (ControlDirective) in the body.
  // The reason being: these break/continue will be translated to BranchOp
  // which are overlapping with the BranchOp implicitly added at the end of SCF::IfOp.
  // e.g., 
  // br ^bb1 (e.g., out of the outer loop) <-- added by ControlDirectiveContext handler
  // br ^bb2 (e.g., to the end of the if statement) <-- added by the implicit yield op
  // The verify step (MLIR -> LLVM) will complain this....
  if (hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(*context)) {
    qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
    exp_generator.visit(conditional_expr);
    auto expr_value = exp_generator.current_value;

    // build up the program block
    auto currRegion = builder.getBlock()->getParent();
    auto savept = builder.saveInsertionPoint();
    auto thenBlock = builder.createBlock(currRegion, currRegion->end());
    auto elseBlock = builder.createBlock(currRegion, currRegion->end());
    mlir::Block *exitBlock = nullptr;
    // If we have an else block from programBlock,
    // then create a stand alone exit block that both
    // then and else can fall to
    if (context->programBlock().size() == 2) {
      exitBlock = builder.createBlock(currRegion, currRegion->end());
    } else {
      exitBlock = elseBlock;
    }

    // Build up the THEN Block
    builder.setInsertionPointToStart(thenBlock);
    symbol_table.enter_new_scope();
    // Get the conditional code and visit the nodes
    auto conditional_code = context->programBlock(0);
    visitChildren(conditional_code);

    // Need to check if we have a branch out of
    // this thenBlock, if so do not add a branch
    // to the exitblock
    mlir::Operation &last_op = thenBlock->back();
    auto branchOps = thenBlock->getOps<mlir::BranchOp>();
    auto branch_to_exit = true;
    for (auto b : branchOps) {
      if (b.dest() == current_loop_exit_block ||
          b.dest() == current_loop_header_block ||
          b.dest() == current_loop_incrementor_block) {
        branch_to_exit = false;
        break;
      }
    }
    if (branch_to_exit) {
      builder.create<mlir::BranchOp>(location, exitBlock);
    }
    symbol_table.exit_scope();

    // If we have a second program block then we have an else stmt
    builder.setInsertionPointToStart(elseBlock);
    if (context->programBlock().size() == 2) {
      symbol_table.enter_new_scope();
      visitChildren(context->programBlock(1));
      branch_to_exit = true;
      for (auto b : branchOps) {
        if (b.dest() == current_loop_exit_block ||
            b.dest() == current_loop_header_block ||
            b.dest() == current_loop_incrementor_block) {
          branch_to_exit = false;
          break;
        }
      }
      if (branch_to_exit) {
        builder.create<mlir::BranchOp>(location, exitBlock);
      }

      symbol_table.exit_scope();
    }

    // Restore the insertion point and create the conditional statement
    builder.restoreInsertionPoint(savept);
    builder.create<mlir::CondBranchOp>(location, expr_value, thenBlock,
                                       elseBlock);
    builder.setInsertionPointToStart(exitBlock);

    symbol_table.set_last_created_block(exitBlock);
  } else {
  // Using SCF::IfOp
  // Map it to a Value
  qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
@@ -320,9 +236,10 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
    visitChildren(context->programBlock(1));
    symbol_table.exit_scope();
  }

  // Restore builder
  builder = cached_builder;
  }
  
  return 0;
}
} // namespace qcor
 No newline at end of file
+80 −0
Original line number Diff line number Diff line
#include "expression_handler.hpp"
#include "qasm3_visitor.hpp"
#include "mlir/Dialect/SCF/SCF.h"

namespace qcor {
antlrcpp::Any qasm3_visitor::visitControlDirective(
    qasm3Parser::ControlDirectiveContext* context) {
  auto location = get_location(builder, file_name, context);

  auto stmt = context->getText();

  // Strategy:
  // Converting break/continue directives to region-based control-flow (in the
  // Affine/SCF dialects) Following the direction here:
  // https://llvm.discourse.group/t/dynamic-control-flow-break-like-operation/2495/16
  // e.g., predicating every statement potentially executed after at least one
  // break on the absence of break. This doesn’t break our SCF/Affine analyses
  // and transformations, that rely on there being single block and static
  // control flow.
  // For example, with a for loop: we need to wrap the whole body in a break check
  // **and** each subsequent block after the *break/continue* point
  if (stmt == "break") {
    // builder.create<mlir::BranchOp>(location, current_loop_exit_block);

    mlir::Region *region = builder.getInsertionBlock()->getParent();
    auto parent_op = region->getParentOp();
    mlir::scf::IfOp parentIfOp =
        mlir::dyn_cast_or_null<mlir::scf::IfOp>(parent_op);
    if (!parentIfOp) {
      // We can handle this as well, but it's really a programmers' bug.
      // Hence, just let them know.
      printErrorMessage("Illegal break directive: unconditional break.");
    }

    // Strategy: predicating every statement potentially executed after at least
    // one break on the absence of break.

    // Create a 'mustBreak' bool at the outer scope:
    mlir::Value mustBreak;
    {
      mlir::OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPointToStart(
          &(m_module.getRegion().getBlocks().front()));
      mustBreak = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(llvm::ArrayRef<int64_t>{},
                                          builder.getI1Type()));
      // store false (at the outer scope)
      builder.create<mlir::StoreOp>(
          location,
          get_or_create_constant_integer_value(0, location, builder.getI1Type(),
                                               symbol_table, builder),
          mustBreak);
    }

    // Store true here:
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                             symbol_table, builder),
        mustBreak);
    loop_break_vars.push(mustBreak);
  } else if (stmt == "continue") {
    // TODO: Handle this case.
    if (current_loop_incrementor_block) {
      builder.create<mlir::BranchOp>(location, current_loop_incrementor_block);
    } else if (current_loop_header_block) {
      // this is a while loop
      builder.create<mlir::BranchOp>(location, current_loop_header_block);
    } else {
      printErrorMessage(
          "Something went wrong with continue, no valid block to branch to.");
    }
  } else {
    printErrorMessage("we do not yet support the " + stmt +
                      " control directive.");
  }

  return 0;
}
}  // namespace qcor
 No newline at end of file
+45 −36
Original line number Diff line number Diff line
#include "expression_handler.hpp"
#include "exprtk.hpp"
#include "qasm3_visitor.hpp"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"

using symbol_table_t = exprtk::symbol_table<double>;
using expression_t = exprtk::expression<double>;
using parser_t = exprtk::parser<double>;
@@ -9,7 +12,7 @@ namespace {
/// Creates a single affine "for" loop, iterating from lbs to ubs with
/// the given step.
/// to construct the body of the loop and is passed the induction variable.
void affineLoopBuilder(mlir::Value lbs_val, mlir::Value ubs_val, int64_t step,
mlir::AffineForOp affineLoopBuilder(mlir::Value lbs_val, mlir::Value ubs_val, int64_t step,
                       std::function<void(mlir::Value)> bodyBuilderFn,
                       mlir::OpBuilder &builder, mlir::Location &loc) {
  if (!ubs_val.getType().isa<mlir::IndexType>()) {
@@ -28,7 +31,7 @@ void affineLoopBuilder(mlir::Value lbs_val, mlir::Value ubs_val, int64_t step,
    mlir::ValueRange lbs(lbs_val);
    mlir::ValueRange ubs(ubs_val);
    // Create the actual loop
    builder.create<mlir::AffineForOp>(
    return builder.create<mlir::AffineForOp>(
        loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
        builder.getMultiDimIdentityMap(ubs.size()), step, llvm::None,
        [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc,
@@ -47,7 +50,7 @@ void affineLoopBuilder(mlir::Value lbs_val, mlir::Value ubs_val, int64_t step,
    ubs_val = builder.create<mlir::MulIOp>(loc, ubs_val, minus_one).result();
    mlir::ValueRange lbs(lbs_val);
    mlir::ValueRange ubs(ubs_val);
    builder.create<mlir::AffineForOp>(
    return builder.create<mlir::AffineForOp>(
        loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
        builder.getMultiDimIdentityMap(ubs.size()), -step, llvm::None,
        [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc,
@@ -309,13 +312,9 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

      // HACK: Currently, we don't handle 'break', 'continue' or nested loop
      // in the Affine for loop yet.
      if (!hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(
              *program_block) &&
          !hasChildNodeOfType<qasm3Parser::LoopStatementContext>(
              *program_block) &&
          (program_block_str.find("QCOR_EXPECT_TRUE") == std::string::npos)) {
      if (program_block_str.find("QCOR_EXPECT_TRUE") == std::string::npos) {
        // Can use Affine for loop....
        affineLoopBuilder(
        auto forLoop = affineLoopBuilder(
            a_value, b_value, c,
            [&](mlir::Value loop_var) {
              // Create a new scope for the for loop
@@ -327,6 +326,43 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
              symbol_table.exit_scope();
            },
            builder, location);
        if (!loop_break_vars.empty()) {
          mlir::OpBuilder::InsertionGuard g(builder);
          builder.setInsertionPointToStart(
              &(forLoop.getLoopBody().getBlocks().front()));
          mlir::Value mustBreak =
              builder.create<mlir::LoadOp>(location, loop_break_vars.top());
          assert(mustBreak.getType().isa<mlir::IntegerType>() &&
                 mustBreak.getType().getIntOrFloatBitWidth() == 1);
          loop_break_vars.pop();
          // Wrap/Outline the loop body in an IfOp:
          auto scfIfOp = builder.create<mlir::scf::IfOp>(
              location, mlir::TypeRange(), mustBreak, false);
          size_t count = 0;
          std::vector<mlir::Operation *> ops_to_clone;
          for (auto op_iter = forLoop.getLoopBody().op_begin();
               op_iter != forLoop.getLoopBody().op_end(); ++op_iter) {
            mlir::Operation &op = *op_iter;
            count++;
            // The first 2 are the Load and If that we just inserted
            if (count <= 2) {
              continue;
            }
            ops_to_clone.emplace_back(&op);
          }
          // Last one is affine yield
          // should be left outside:
          assert(
              mlir::dyn_cast_or_null<mlir::AffineYieldOp>(ops_to_clone.back()));
          ops_to_clone.pop_back();
          for (auto &op : ops_to_clone) {
            scfIfOp.getThenBodyBuilder().clone(*op);
          }

          for (auto &op : ops_to_clone) {
            op->remove();
          }
        }
      } else {
        // TODO: Remove this code path once we convert control flow to Affine/SCF
        // Need to use the legacy for loop construction for now...
@@ -475,31 +511,4 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

  return 0;
}

antlrcpp::Any qasm3_visitor::visitControlDirective(
    qasm3Parser::ControlDirectiveContext* context) {
  auto location = get_location(builder, file_name, context);

  auto stmt = context->getText();
  // FIXME: using Affine/SCF Ops:
  // affine.yield
  if (stmt == "break") {
    builder.create<mlir::BranchOp>(location, current_loop_exit_block);
  } else if (stmt == "continue") {
    if (current_loop_incrementor_block) {
      builder.create<mlir::BranchOp>(location, current_loop_incrementor_block);
    } else if (current_loop_header_block) {
      // this is a while loop
      builder.create<mlir::BranchOp>(location, current_loop_header_block);
    } else {
      printErrorMessage(
          "Something went wrong with continue, no valid block to branch to.");
    }
  } else {
    printErrorMessage("we do not yet support the " + stmt +
                      " control directive.");
  }

  return 0;
}
}  // namespace qcor
 No newline at end of file