Commit 0230a552 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Handle the last loop-control directive: early return



When porting  QCOR_EXPECT_TRUE to region-based construct, I realized there is another directive which hasn't been handled: early return.

Adding a test for early returns.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent f354f4f2
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
@@ -166,6 +166,14 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  /// the remaining ops in the body.
  /// We use a stack to handle nested loops, which are all break-able.
  std::stack<std::pair<mlir::Value, mlir::Value>> loop_control_directive_bool_vars;

  // Early return loop control directive: return statement in the loop body.
  // This will escape all loops until the *FuncOp* body and return.
  // Note: MLIR validation will require ReturnOp in the **Region** of a FuncOp.
  // First value: the boolean to control the early return (if true)
  // Second value: the return value.
  std::optional<std::pair<mlir::Value, std::optional<mlir::Value>>>
      region_early_return_vars;
  // This method will add correct number of InstOps
  // based on quantum gate broadcasting
  void createInstOps_HandleBroadcast(std::string name,
@@ -183,6 +191,17 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  void createSetBasedForLoop(qasm3Parser::LoopStatementContext *context);
  // While loop
  void createWhileLoop(qasm3Parser::LoopStatementContext *context);
  // Insert MLIR loop break
  void insertLoopBreak(mlir::Location &location,
                       mlir::OpBuilder *optional_builder = nullptr);
  void insertLoopContinue(mlir::Location &location,
                       mlir::OpBuilder *optional_builder = nullptr);
  // Insert a conditional return.
  // Assert that the insert location is *returnable*
  // i.e., in the FuncOp region.
  void conditionalReturn(mlir::Location &location, mlir::Value cond,
                         mlir::Value returnVal,
                         mlir::OpBuilder *optional_builder = nullptr);

  // This function serves as a utility for creating a MemRef and
  // corresponding AllocOp of a given 1d shape. It will also store
+31 −0
Original line number Diff line number Diff line
@@ -159,6 +159,37 @@ QCOR_EXPECT_TRUE(j == 7);
  EXPECT_FALSE(qcor::execute(uint_index, "uint_index"));
}

TEST(qasm3VisitorTester, checkEarlyReturn) {
  const std::string uint_index = R"#(OPENQASM 3;
include "qelib1.inc";

def generate_number(int[64]: count) -> int[64] {
  for i in [0:count] {
    if (i > 10) {
      print("Return at ", i);
      return 5;
    }
    print("i =", i);
  }

  print("make it to the end");
  return 1;  
}

int[64] val1 = generate_number(4);
print("Result 1 =", val1);
QCOR_EXPECT_TRUE(val1 == 1);
// Call it with 20 -> activate the early return
int[64] val2 = generate_number(20);
print("Result 2 =", val2);
QCOR_EXPECT_TRUE(val2 == 5);
)#";
  auto mlir = qcor::mlir_compile(uint_index, "uint_index",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  EXPECT_FALSE(qcor::execute(uint_index, "uint_index"));
}

TEST(qasm3VisitorTester, checkIqpewithIf) {
  const std::string qasm_code = R"#(OPENQASM 3;
include "qelib1.inc";
+74 −31
Original line number Diff line number Diff line
@@ -3,24 +3,10 @@
#include "mlir/Dialect/SCF/SCF.h"

namespace qcor {
antlrcpp::Any qasm3_visitor::visitControlDirective(
    qasm3Parser::ControlDirectiveContext* context) {
  auto location = get_location(builder, file_name, context);

  auto stmt = context->getText();

  // Strategy:
  // Converting break/continue directives to region-based control-flow (in the
  // Affine/SCF dialects) Following the direction here:
  // https://llvm.discourse.group/t/dynamic-control-flow-break-like-operation/2495/16
  // e.g., predicating every statement potentially executed after at least one
  // break on the absence of break. This doesn’t break our SCF/Affine analyses
  // and transformations, that rely on there being single block and static
  // control flow.
  // For example, with a for loop: we need to wrap the whole body in a break check
  // **and** each subsequent block after the *break/continue* point
  if (stmt == "break") {
    mlir::Region *region = builder.getInsertionBlock()->getParent();
void qasm3_visitor::insertLoopBreak(mlir::Location &location,
                                    mlir::OpBuilder *optional_builder) {
  mlir::OpBuilder &opBuilder = optional_builder ? *optional_builder : builder;
  mlir::Region *region = opBuilder.getInsertionBlock()->getParent();
  auto parent_op = region->getParentOp();
  mlir::scf::IfOp parentIfOp =
      mlir::dyn_cast_or_null<mlir::scf::IfOp>(parent_op);
@@ -32,24 +18,81 @@ antlrcpp::Any qasm3_visitor::visitControlDirective(

  // Set an attribute so that we can detect this after handling this.
  parentIfOp->setAttr("control-directive",
                       mlir::IntegerAttr::get(builder.getIntegerType(1), 1));
                      mlir::IntegerAttr::get(opBuilder.getIntegerType(1), 1));
  assert(!loop_control_directive_bool_vars.empty());
  auto [cond1, cond2] = loop_control_directive_bool_vars.top();

  // Store false to both the break and continue:
  // i.e., bypass the whole for loop and the rest of the loop body:
  // (1) This bool will bypass the whole for loop body:
    builder.create<mlir::StoreOp>(
  opBuilder.create<mlir::StoreOp>(
      location,
        get_or_create_constant_integer_value(0, location, builder.getI1Type(),
                                             symbol_table, builder),
      get_or_create_constant_integer_value(0, location, opBuilder.getI1Type(),
                                           symbol_table, opBuilder),
      cond1);
  // (2) This bool will bypass rest of the loop body (after this point)
    builder.create<mlir::StoreOp>(
  opBuilder.create<mlir::StoreOp>(
      location,
        get_or_create_constant_integer_value(0, location, builder.getI1Type(),
                                             symbol_table, builder),
      get_or_create_constant_integer_value(0, location, opBuilder.getI1Type(),
                                           symbol_table, opBuilder),
      cond2);
}

void qasm3_visitor::insertLoopContinue(mlir::Location &location,
                                       mlir::OpBuilder *optional_builder) {
  mlir::OpBuilder &opBuilder = optional_builder ? *optional_builder : builder;
  assert(!loop_control_directive_bool_vars.empty());
  auto &[cond1, cond2] = loop_control_directive_bool_vars.top();
  // Wrap/Outline the loop body in an IfOp:
  auto continuationIfOp = opBuilder.create<mlir::scf::IfOp>(
      location, mlir::TypeRange(),
      opBuilder.create<mlir::LoadOp>(location, cond2), false);
  auto continuationThenBodyBuilder = continuationIfOp.getThenBodyBuilder();
  opBuilder = continuationThenBodyBuilder;
}

void qasm3_visitor::conditionalReturn(mlir::Location &location,
                                      mlir::Value cond, mlir::Value returnVal,
                                      mlir::OpBuilder *optional_builder) {
  mlir::OpBuilder &opBuilder = optional_builder ? *optional_builder : builder;
  assert(cond.getType() == opBuilder.getI1Type());

  auto savept = opBuilder.saveInsertionPoint();
  auto currRegion = opBuilder.getBlock()->getParent();
  assert(currRegion->getParentOp() &&
         mlir::dyn_cast_or_null<mlir::FuncOp>(currRegion->getParentOp()));

  // Create a CFG branch:
  auto thenBlock = opBuilder.createBlock(currRegion, currRegion->end());
  auto exitBlock = opBuilder.createBlock(currRegion, currRegion->end());
  opBuilder.setInsertionPointToStart(thenBlock);
  opBuilder.create<mlir::ReturnOp>(location,
                                   llvm::ArrayRef<mlir::Value>{returnVal});
  // builder.create<mlir::BranchOp>(location, exitBlock);
  opBuilder.restoreInsertionPoint(savept);
  opBuilder.create<mlir::CondBranchOp>(location, cond, thenBlock, exitBlock);
  opBuilder.setInsertionPointToStart(exitBlock);
  symbol_table.set_last_created_block(exitBlock);
}

antlrcpp::Any qasm3_visitor::visitControlDirective(
    qasm3Parser::ControlDirectiveContext *context) {
  auto location = get_location(builder, file_name, context);

  auto stmt = context->getText();

  // Strategy:
  // Converting break/continue directives to region-based control-flow (in the
  // Affine/SCF dialects) Following the direction here:
  // https://llvm.discourse.group/t/dynamic-control-flow-break-like-operation/2495/16
  // e.g., predicating every statement potentially executed after at least one
  // break on the absence of break. This doesn’t break our SCF/Affine analyses
  // and transformations, that rely on there being single block and static
  // control flow.
  // For example, with a for loop: we need to wrap the whole body in a break check
  // **and** each subsequent block after the *break/continue* point
  if (stmt == "break") {
    insertLoopBreak(location);
  } else if (stmt == "continue") {
    mlir::Region *region = builder.getInsertionBlock()->getParent();
    auto parent_op = region->getParentOp();
+55 −0
Original line number Diff line number Diff line
@@ -210,7 +210,44 @@ void qasm3_visitor::createRangeBasedForLoop(
  }

  // Check if the loop is break-able (contains control directive node)
  // The loop contains an early return.
  const bool loopEarlyReturn =
      hasChildNodeOfType<qasm3Parser::ReturnStatementContext>(*context) ||
      hasChildNodeOfType<qasm3Parser::Qcor_test_statementContext>(*context);
  // Top-level only
  if (loopEarlyReturn && !region_early_return_vars.has_value()) {
    mlir::OpBuilder::InsertionGuard g(builder);
    mlir::Value shouldReturn = builder.create<mlir::AllocaOp>(
        location,
        mlir::MemRefType::get(llvm::ArrayRef<int64_t>{}, builder.getI1Type()));
    // Store false:
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(0, location, builder.getI1Type(),
                                             symbol_table, builder),
        shouldReturn);

    // Note: we don't know what the return value is yet
    if (current_function_return_type) {
      llvm::ArrayRef<int64_t> shaperef{};
      mlir::Value return_var_memref = builder.create<mlir::AllocaOp>(
          location,
          mlir::MemRefType::get(shaperef, current_function_return_type));
      region_early_return_vars =
          std::make_pair(shouldReturn, return_var_memref);
    } else {
      llvm::ArrayRef<int64_t> shaperef{};
      mlir::Value return_var_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, builder.getI32Type()));
      region_early_return_vars =
          std::make_pair(shouldReturn, return_var_memref);
    }
  }

  // Loop has control directives (break/continue)
  // A loop has return statement must be breakable
  const bool isLoopBreakable =
      loopEarlyReturn ||
      hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(*context);
  auto cachedBuilder = builder;
  if (isLoopBreakable) {
@@ -272,6 +309,24 @@ void qasm3_visitor::createRangeBasedForLoop(
      },
      builder, location);
  builder = cachedBuilder;

  auto parentOp = builder.getBlock()->getParent()->getParentOp();
  if (parentOp && mlir::dyn_cast_or_null<mlir::FuncOp>(parentOp)) {
    if (region_early_return_vars.has_value()) {
      auto &[boolVar, returnVar] = region_early_return_vars.value();
      mlir::Value returnedValue;
      if (returnVar.has_value()) {
        assert(returnVar.value().getType().isa<mlir::MemRefType>());
        returnedValue =
            builder.create<mlir::LoadOp>(location, returnVar.value());
      }

      conditionalReturn(location,
                        builder.create<mlir::LoadOp>(location, boolVar),
                        returnedValue);
      region_early_return_vars.reset();
    }
  }
}

void qasm3_visitor::createSetBasedForLoop(
+23 −3
Original line number Diff line number Diff line
@@ -37,9 +37,29 @@ antlrcpp::Any qasm3_visitor::visitQcor_test_statement(
          location, str_type, str_attr, var_name_attr);
  thenBodyBuilder.create<mlir::quantum::PrintOp>(
      location, llvm::makeArrayRef(std::vector<mlir::Value>{string_literal}));
  auto integer_attr = mlir::IntegerAttr::get(thenBodyBuilder.getI32Type(), 1);
  auto ret = builder.create<mlir::ConstantOp>(location, integer_attr);
  thenBodyBuilder.create<mlir::ReturnOp>(location, llvm::ArrayRef<mlir::Value>(ret));

  if (region_early_return_vars.has_value()) {
    insertLoopBreak(location, &thenBodyBuilder);
    // This is in an affine region (loops)
    auto &[boolVar, returnVar] = region_early_return_vars.value();
    builder.create<mlir::StoreOp>(location, expr_value, boolVar);
    mlir::Value one_i32 = builder.create<mlir::ConstantOp>(
        location, mlir::IntegerAttr::get(thenBodyBuilder.getI32Type(), 1));
    builder.create<mlir::StoreOp>(location, one_i32, returnVar.value());
    auto &[cond1, cond2] = loop_control_directive_bool_vars.top();
    // Wrap/Outline the loop body in an IfOp:
    auto continuationIfOp = builder.create<mlir::scf::IfOp>(
        location, mlir::TypeRange(),
        builder.create<mlir::LoadOp>(location, cond2), false);
    auto continuationThenBodyBuilder = continuationIfOp.getThenBodyBuilder();
    builder = continuationThenBodyBuilder;
  } else {
    // Outside scope: just do early return
    conditionalReturn(
        location, expr_value,
        builder.create<mlir::ConstantOp>(
            location, mlir::IntegerAttr::get(thenBodyBuilder.getI32Type(), 1)));
  }

  return 0;
}
Loading