Commit e9615ea1 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Added a handler for bubbling return statements in nested loops



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 1dae38ee
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -196,6 +196,7 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
                       mlir::OpBuilder *optional_builder = nullptr);
  void insertLoopContinue(mlir::Location &location,
                       mlir::OpBuilder *optional_builder = nullptr);
  void handleReturnInLoop(mlir::Location &location);
  // Insert a conditional return.
  // Assert that the insert location is *returnable*
  // i.e., in the FuncOp region.
+54 −0
Original line number Diff line number Diff line
@@ -224,6 +224,60 @@ QCOR_EXPECT_TRUE(val2 == 3);
}


TEST(qasm3VisitorTester, checkEarlyReturnNestedLoop) {
  const std::string uint_index = R"#(OPENQASM 3;
include "qelib1.inc";

def generate_number(int[64]: break_value, int[64]: max_run) -> int[64] {
  int[64] run_count = 0;
  for i in [0:10] {
    for j in [0:10] {
      run_count += 1;
      if (run_count > max_run) {
        print("Exceeding max run count of", max_run);
        return 3;
      }
      
      if (i == j && i > break_value) {
        print("Return at i = ", i);
        print("Return at j = ", j);
        return run_count;
      }

      print("i =", i);
      print("j =", j);
    }
    print("Out of inner loop");
  }

  print("make it to the end");
  return 0;  
}

// Case 1: run to the end.
int[64] val = generate_number(10, 100);
print("Result =", val);
QCOR_EXPECT_TRUE(val == 0);

// Case 2: Return @ (i == j && i > break_value) 
// i = 0: 10; i = 1: 10; i = 2: j = 0, 1, 2 
// => 23 runs (return run_count in this path)
val = generate_number(1, 100);
print("Result =", val);
QCOR_EXPECT_TRUE(val == 23);

// Case 3: return due to max_run limit
// limit to 20 (less than 23) => return value 3
val = generate_number(1, 20);
print("Result =", val);
QCOR_EXPECT_TRUE(val == 3);
)#";
  auto mlir = qcor::mlir_compile(uint_index, "uint_index",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  EXPECT_FALSE(qcor::execute(uint_index, "uint_index"));
}

TEST(qasm3VisitorTester, checkIqpewithIf) {
  const std::string qasm_code = R"#(OPENQASM 3;
include "qelib1.inc";
+117 −18
Original line number Diff line number Diff line
@@ -93,6 +93,41 @@ qasm3_visitor::visitLoopStatement(qasm3Parser::LoopStatementContext *context) {
  return 0;
}

void qasm3_visitor::handleReturnInLoop(mlir::Location &location) {
  if (region_early_return_vars.has_value()) {
    auto parentOp = builder.getBlock()->getParent()->getParentOp();
    // Make it out to the Function scope:
    if (parentOp && mlir::dyn_cast_or_null<mlir::FuncOp>(parentOp)) {
      auto &[boolVar, returnVar] = region_early_return_vars.value();
      mlir::Value returnedValue;
      if (returnVar.has_value()) {
        assert(returnVar.value().getType().isa<mlir::MemRefType>());
        returnedValue =
            builder.create<mlir::LoadOp>(location, returnVar.value());
      }
      conditionalReturn(location,
                        builder.create<mlir::LoadOp>(location, boolVar),
                        returnedValue);
      region_early_return_vars.reset();
      assert(symbol_table.get_last_created_block());
    } else if (!loop_control_directive_bool_vars.empty()) {
      // The outer loop needs to set-up as a breakable loop as well.
      auto &[boolVar, returnVar] = region_early_return_vars.value();
      auto returnIfOp = builder.create<mlir::scf::IfOp>(
          location, mlir::TypeRange(),
          builder.create<mlir::LoadOp>(location, boolVar), false);
      // Break the outer loop if the return flag has been set
      auto opBuilder = returnIfOp.getThenBodyBuilder();
      insertLoopBreak(location, &opBuilder);
      // Treating the remaining code in the outer loop after the nested loop
      // as conditional (i.e., could be bypassed if the continue condition set).
      insertLoopContinue(location);
    } else {
      printErrorMessage("Internal error: Unable to handle return statement in a loop.");
    }
  }
}

void qasm3_visitor::createRangeBasedForLoop(
    qasm3Parser::LoopStatementContext *context) {
  auto location = get_location(builder, file_name, context);
@@ -309,24 +344,7 @@ void qasm3_visitor::createRangeBasedForLoop(
      },
      builder, location);
  builder = cachedBuilder;

  auto parentOp = builder.getBlock()->getParent()->getParentOp();
  if (parentOp && mlir::dyn_cast_or_null<mlir::FuncOp>(parentOp)) {
    if (region_early_return_vars.has_value()) {
      auto &[boolVar, returnVar] = region_early_return_vars.value();
      mlir::Value returnedValue;
      if (returnVar.has_value()) {
        assert(returnVar.value().getType().isa<mlir::MemRefType>());
        returnedValue =
            builder.create<mlir::LoadOp>(location, returnVar.value());
      }

      conditionalReturn(location,
                        builder.create<mlir::LoadOp>(location, boolVar),
                        returnedValue);
      region_early_return_vars.reset();
    }
  }
  handleReturnInLoop(location);
}

void qasm3_visitor::createSetBasedForLoop(
@@ -379,7 +397,44 @@ void qasm3_visitor::createSetBasedForLoop(
                                                  symbol_table, builder);

  // Check if the loop is break-able (contains control directive node)
  // The loop contains an early return.
  const bool loopEarlyReturn =
      hasChildNodeOfType<qasm3Parser::ReturnStatementContext>(*context) ||
      hasChildNodeOfType<qasm3Parser::Qcor_test_statementContext>(*context);
  // Top-level only
  if (loopEarlyReturn && !region_early_return_vars.has_value()) {
    mlir::OpBuilder::InsertionGuard g(builder);
    mlir::Value shouldReturn = builder.create<mlir::AllocaOp>(
        location,
        mlir::MemRefType::get(llvm::ArrayRef<int64_t>{}, builder.getI1Type()));
    // Store false:
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(0, location, builder.getI1Type(),
                                             symbol_table, builder),
        shouldReturn);

    // Note: we don't know what the return value is yet
    if (current_function_return_type) {
      llvm::ArrayRef<int64_t> shaperef{};
      mlir::Value return_var_memref = builder.create<mlir::AllocaOp>(
          location,
          mlir::MemRefType::get(shaperef, current_function_return_type));
      region_early_return_vars =
          std::make_pair(shouldReturn, return_var_memref);
    } else {
      llvm::ArrayRef<int64_t> shaperef{};
      mlir::Value return_var_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, builder.getI32Type()));
      region_early_return_vars =
          std::make_pair(shouldReturn, return_var_memref);
    }
  }

  // Loop has control directives (break/continue)
  // A loop has return statement must be breakable
  const bool isLoopBreakable =
      loopEarlyReturn ||
      hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(*context);
  auto cachedBuilder = builder;
  if (isLoopBreakable) {
@@ -444,6 +499,7 @@ void qasm3_visitor::createSetBasedForLoop(
      },
      builder, location);
  builder = cachedBuilder;
  handleReturnInLoop(location);
}

void qasm3_visitor::createWhileLoop(
@@ -454,7 +510,46 @@ void qasm3_visitor::createWhileLoop(
  assert(loop_signature->booleanExpression());
  auto main_block = builder.saveInsertionPoint();
  auto cachedBuilder = builder;
  
  // Check if the loop is break-able (contains control directive node)
  // The loop contains an early return.
  const bool loopEarlyReturn =
      hasChildNodeOfType<qasm3Parser::ReturnStatementContext>(*context) ||
      hasChildNodeOfType<qasm3Parser::Qcor_test_statementContext>(*context);
  // Top-level only
  if (loopEarlyReturn && !region_early_return_vars.has_value()) {
    mlir::OpBuilder::InsertionGuard g(builder);
    mlir::Value shouldReturn = builder.create<mlir::AllocaOp>(
        location,
        mlir::MemRefType::get(llvm::ArrayRef<int64_t>{}, builder.getI1Type()));
    // Store false:
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(0, location, builder.getI1Type(),
                                             symbol_table, builder),
        shouldReturn);

    // Note: we don't know what the return value is yet
    if (current_function_return_type) {
      llvm::ArrayRef<int64_t> shaperef{};
      mlir::Value return_var_memref = builder.create<mlir::AllocaOp>(
          location,
          mlir::MemRefType::get(shaperef, current_function_return_type));
      region_early_return_vars =
          std::make_pair(shouldReturn, return_var_memref);
    } else {
      llvm::ArrayRef<int64_t> shaperef{};
      mlir::Value return_var_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, builder.getI32Type()));
      region_early_return_vars =
          std::make_pair(shouldReturn, return_var_memref);
    }
  }

  // Loop has control directives (break/continue)
  // A loop has return statement must be breakable
  const bool isLoopBreakable =
      loopEarlyReturn ||
      hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(*context);

  if (isLoopBreakable) {
@@ -524,6 +619,10 @@ void qasm3_visitor::createWhileLoop(
  }

  builder = cachedBuilder;

  // Handle potential return statement in the loop.
  handleReturnInLoop(location);

  // 'After' block must end with a yield op.
  mlir::Operation &lastOp = whileOp.after().front().getOperations().back();
  builder.setInsertionPointAfter(&lastOp);
+3 −2
Original line number Diff line number Diff line
@@ -43,8 +43,9 @@ antlrcpp::Any qasm3_visitor::visitQcor_test_statement(
    // This is in an affine region (loops)
    auto &[boolVar, returnVar] = region_early_return_vars.value();
    builder.create<mlir::StoreOp>(location, expr_value, boolVar);
    mlir::Value one_i32 = builder.create<mlir::ConstantOp>(
        location, mlir::IntegerAttr::get(thenBodyBuilder.getI32Type(), 1));
    mlir::Value one_i32 = get_or_create_constant_integer_value(
        1, location, thenBodyBuilder.getI32Type(), symbol_table,
        thenBodyBuilder);
    builder.create<mlir::StoreOp>(location, one_i32, returnVar.value());
    auto &[cond1, cond2] = loop_control_directive_bool_vars.top();
    // Wrap/Outline the loop body in an IfOp: