Commit 7ae612a7 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Convert set-based for loop to affine



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 2dabec02
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
@@ -94,6 +94,41 @@ print("made it out of the loop");)#";
  EXPECT_FALSE(qcor::execute(uint_index, "uint_index"));
}

TEST(qasm3VisitorTester, checkCtrlDirectivesSetBasedForLoop) {
  const std::string uint_index = R"#(OPENQASM 3;
include "qelib1.inc";

int[64] sum_value = 0;
int[64] break_value = 0;
int[64] loop_count = 0;

for val in {1,3,5,7} {
  print("iter: ", val);
  if (val < 4) {
    sum_value += val;
  } else {
    break_value = val;
    break;
  }

  loop_count += 1;
}

print(sum_value);
print(loop_count);
print(break_value);
QCOR_EXPECT_TRUE(sum_value == 4);
QCOR_EXPECT_TRUE(loop_count == 2);
QCOR_EXPECT_TRUE(break_value == 5);)#";
  auto mlir = qcor::mlir_compile(uint_index, "uint_index",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  // Make sure we're using Affine and SCF
  EXPECT_EQ(countSubstring(mlir, "affine.for"), 1);
  EXPECT_GT(countSubstring(mlir, "scf.if"), 1);
  EXPECT_FALSE(qcor::execute(uint_index, "uint_index"));
}

TEST(qasm3VisitorTester, checkIqpewithIf) {
  const std::string qasm_code = R"#(OPENQASM 3;
include "qelib1.inc";
+65 −71
Original line number Diff line number Diff line
@@ -311,19 +311,13 @@ void qasm3_visitor::createSetBasedForLoop(
    qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
    exp_generator.visit(exp);
    auto value = exp_generator.current_value;

    mlir::Value pos = get_or_create_constant_index_value(counter, location, 64,
                                                         symbol_table, builder);

    builder.create<mlir::StoreOp>(
        location, value, allocation,
        llvm::makeArrayRef(std::vector<mlir::Value>{pos}));

    counter++;
  }

  symbol_table.enter_new_scope();

  auto tmp = get_or_create_constant_index_value(0, location, 64, symbol_table,
                                                builder);
  auto tmp2 = get_or_create_constant_index_value(0, location, 64, symbol_table,
@@ -335,77 +329,77 @@ void qasm3_visitor::createSetBasedForLoop(
      location, 1, builder.getIndexType(), std::vector<mlir::Value>{tmp},
      llvm::makeArrayRef(std::vector<mlir::Value>{tmp}));

  auto a_val = get_or_create_constant_index_value(0, location, 64, symbol_table,
                                                  builder);
  auto b_val = get_or_create_constant_index_value(n_expr, location, 64,
                                                  symbol_table, builder);
  auto c_val = get_or_create_constant_index_value(1, location, 64, symbol_table,
                                                  builder);

  // Strategy...

  // We need to create a header block to check that loop var is still valid
  // it will branch at the end to the body or the exit

  // Then we create the body block, it should branch to the incrementor
  // block

  // Then we create the incrementor block, it should branch back to header

  // Any downstream children that will create blocks will need to know what
  // the fallback block for them is, and it should be the incrementor block
  auto savept = builder.saveInsertionPoint();
  auto currRegion = builder.getBlock()->getParent();
  auto headerBlock = builder.createBlock(currRegion, currRegion->end());
  auto bodyBlock = builder.createBlock(currRegion, currRegion->end());
  auto incBlock = builder.createBlock(currRegion, currRegion->end());
  mlir::Block *exitBlock = builder.createBlock(currRegion, currRegion->end());
  builder.restoreInsertionPoint(savept);

  builder.create<mlir::BranchOp>(location, headerBlock);
  builder.setInsertionPointToStart(headerBlock);

  auto load =
      builder.create<mlir::LoadOp>(location, loop_var_memref, zero_index);
  auto cmp = builder.create<mlir::CmpIOp>(location, mlir::CmpIPredicate::slt,
                                          load, b_val);
  builder.create<mlir::CondBranchOp>(location, cmp, bodyBlock, exitBlock);

  builder.setInsertionPointToStart(bodyBlock);
  // Load the loop variable from the memref allocation
  auto load2 =
      builder.create<mlir::LoadOp>(location, allocation, load.result());
  // Check if the loop is break-able (contains control directive node)
  const bool isLoopBreakable =
      hasChildNodeOfType<qasm3Parser::ControlDirectiveContext>(*context);
  auto cachedBuilder = builder;
  if (isLoopBreakable) {
    // Add the two loop control bool vars:
    mlir::OpBuilder::InsertionGuard g(builder);
    // Top-level if control (skipping the whole loop if false)
    mlir::Value executeWholeLoop = builder.create<mlir::AllocaOp>(
        location,
        mlir::MemRefType::get(llvm::ArrayRef<int64_t>{}, builder.getI1Type()));
    // Loop body control: skipping portions of the the body if
    // false: e.g., handle 'continue'-like directive.
    mlir::Value executeThisBlock = builder.create<mlir::AllocaOp>(
        location,
        mlir::MemRefType::get(llvm::ArrayRef<int64_t>{}, builder.getI1Type()));
    // store true
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                             symbol_table, builder),
        executeWholeLoop);
    builder.create<mlir::StoreOp>(
        location,
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                             symbol_table, builder),
        executeThisBlock);
    for_loop_control_vars.push(
        std::make_pair(executeWholeLoop, executeThisBlock));
  }

  // Can use Affine for loop....
  auto forLoop = affineLoopBuilder(
      a_val, b_val, 1,
      [&](mlir::Value loop_index_var) {
        // Create a new scope for the for loop
        symbol_table.enter_new_scope();
        // Load the value at the index from the set
        auto loop_var =
            builder.create<mlir::LoadOp>(location, allocation, loop_index_var)
                .result();
        // Save the loaded value as the loop variable name
  symbol_table.add_symbol(idx_var_name, load2.result(), {}, true);

  current_loop_exit_block = exitBlock;

  current_loop_incrementor_block = incBlock;
        symbol_table.add_symbol(idx_var_name, loop_var, {}, true);

        if (isLoopBreakable) {
          auto [cond1, cond2] = for_loop_control_vars.top();
          // Wrap/Outline the loop body in an IfOp:
          auto scfIfOp = builder.create<mlir::scf::IfOp>(
              location, mlir::TypeRange(),
              builder.create<mlir::LoadOp>(location, cond1), false);
          auto thenBodyBuilder = scfIfOp.getThenBodyBuilder();
          auto cached_builder = builder;
          builder = thenBodyBuilder;
          visitChildren(program_block);

  current_loop_header_block = nullptr;
  current_loop_exit_block = nullptr;

  builder.create<mlir::BranchOp>(location, incBlock);

  builder.setInsertionPointToStart(incBlock);
  auto load_inc =
      builder.create<mlir::LoadOp>(location, loop_var_memref, zero_index);
  auto add = builder.create<mlir::AddIOp>(location, load_inc, c_val);

  assert(tmp2.getType().isa<mlir::IndexType>());
  builder.create<mlir::StoreOp>(
      location, add, loop_var_memref,
      llvm::makeArrayRef(std::vector<mlir::Value>{tmp2}));

  builder.create<mlir::BranchOp>(location, headerBlock);

  builder.setInsertionPointToStart(exitBlock);

  symbol_table.set_last_created_block(exitBlock);

  // Exit scope and restore insertion
          builder = cached_builder;
        } else {
          visitChildren(program_block);
        }
        symbol_table.exit_scope();

        if (isLoopBreakable) {
          for_loop_control_vars.pop();
        }
      },
      builder, location);
  builder = cachedBuilder;
}

void qasm3_visitor::createWhileLoop(