Commit 8b4841c8 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Use Affine for loop and add a test



Our Affine for loop cannot handle break inside its body yet...

Fallback to the manual construction

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent b079cf0a
Loading
Loading
Loading
Loading
+25 −0
Original line number Diff line number Diff line
@@ -152,6 +152,31 @@ cx q[0], q[1];
  EXPECT_EQ(countSubstring(llvm, "__quantum__rt__qubit_release_array"), 0);
}

TEST(qasm3PassManagerTester, checkLoopUnroll) {
  // Unroll the loop:
  // cancel all X gates; combine rx
  const std::string src = R"#(OPENQASM 3;
include "stdgates.inc";
qubit q[2];
for i in [0:10] {
    x q[0];
    rx(0.123) q[1];
}
)#";
  auto llvm =
      qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";
  
  // Get the main kernel section only (there is the oracle LLVM section as well)
  llvm = llvm.substr(llvm.find("@test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
  // Only a single Rx remains (combine all angles)
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 1);
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis__rx"), 1);
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+77 −53
Original line number Diff line number Diff line
@@ -234,6 +234,27 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        }
      }

      const std::string program_block_str = program_block->getText();
      // std::cout << "HOWDY:\n" << program_block_str << "\n";

      // HACK: Currently, we don't handle 'if', 'break', 'continue'
      // in the Affine for loop yet.
      if (program_block_str.find("if") == std::string::npos &&
          program_block_str.find("break") == std::string::npos &&
          program_block_str.find("continue") == std::string::npos) {
        // Can use Affine for loop....
        affineLoopBuilder(
            a_value, b_value, c,
            [&](mlir::Value loop_var) {
              // Create a new scope for the for loop
              symbol_table.enter_new_scope();
              symbol_table.add_symbol(idx_var_name, loop_var, {}, true);
              visitChildren(program_block);
              symbol_table.exit_scope();
            },
            builder, location);
      } else {
        // Need to use the legacy for loop construction for now...
        // Create a new scope for the for loop
        symbol_table.enter_new_scope();

@@ -245,22 +266,24 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

        // Save the current builder point
        // auto savept = builder.saveInsertionPoint();
      auto loaded_var = builder.create<mlir::LoadOp>(location, loop_var_memref);
        auto loaded_var =
            builder.create<mlir::LoadOp>(location, loop_var_memref);

        symbol_table.add_symbol(idx_var_name, loaded_var, {}, true);

        // Strategy...

      // We need to create a header block to check that loop var is still valid
      // it will branch at the end to the body or the exit
        // We need to create a header block to check that loop var is still
        // valid it will branch at the end to the body or the exit

        // Then we create the body block, it should branch to the incrementor
        // block

        // Then we create the incrementor block, it should branch back to header

      // Any downstream children that will create blocks will need to know what
      // the fallback block for them is, and it should be the incrementor block
        // Any downstream children that will create blocks will need to know
        // what the fallback block for them is, and it should be the incrementor
        // block
        auto savept = builder.saveInsertionPoint();
        auto currRegion = builder.getBlock()->getParent();
        auto headerBlock = builder.createBlock(currRegion, currRegion->end());
@@ -275,7 +298,9 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

        auto load = builder.create<mlir::LoadOp>(location, loop_var_memref);
        auto cmp = builder.create<mlir::CmpIOp>(
          location, c > 0 ? mlir::CmpIPredicate::slt : mlir::CmpIPredicate::sge, load, b_value);
            location,
            c > 0 ? mlir::CmpIPredicate::slt : mlir::CmpIPredicate::sge, load,
            b_value);
        builder.create<mlir::CondBranchOp>(location, cmp, bodyBlock, exitBlock);

        builder.setInsertionPointToStart(bodyBlock);
@@ -299,7 +324,6 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

        auto add = builder.create<mlir::AddIOp>(location, load_inc, c_value);


        builder.create<mlir::StoreOp>(location, add, loop_var_memref);

        builder.create<mlir::BranchOp>(location, headerBlock);
@@ -309,7 +333,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        symbol_table.set_last_created_block(exitBlock);

        symbol_table.exit_scope();

      }
    } else {
      printErrorMessage(
          "For loops must be of form 'for i in {SET}' or 'for i in [RANGE]'.");