Commit b759c28e authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Added test for Affine For loop with If (SCF)



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent fac11e47
Loading
Loading
Loading
Loading
+51 −1
Original line number Diff line number Diff line
#include "gtest/gtest.h"
#include "qcor_mlir_api.hpp"
#include "gtest/gtest.h"

namespace {
// returns count of non-overlapping occurrences of 'sub' in 'str'
int countSubstring(const std::string &str, const std::string &sub) {
  if (sub.length() == 0)
    return 0;
  int count = 0;
  for (size_t offset = str.find(sub); offset != std::string::npos;
       offset = str.find(sub, offset + sub.length())) {
    ++count;
  }
  return count;
}
} // namespace

// Check Affine-SCF constructs
TEST(qasm3VisitorTester, checkCFG_AffineScf) {
  const std::string qasm_code = R"#(OPENQASM 3;
include "qelib1.inc";

int[64] iterate_value = 0;
int[64] value_5 = 0;
int[64] value_2 = 0;
for i in [0:10] {
    iterate_value = i;
    if (i == 5) {
        print("Iterate over 5");
        value_5 = 5;
    }
    if (i == 2) {
        print("Iterate over 2");
        value_2 = 2;
       
    }
    print("i = ", i);
}

QCOR_EXPECT_TRUE(iterate_value == 9);
QCOR_EXPECT_TRUE(value_5 == 5);
QCOR_EXPECT_TRUE(value_2 == 2);
print("made it out of the loop");
)#";
  auto mlir = qcor::mlir_compile(qasm_code, "affine_scf",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  // 1 for loop, 2 if blocks
  EXPECT_EQ(countSubstring(mlir, "affine.for"), 1);
  EXPECT_EQ(countSubstring(mlir, "scf.if"), 2);
  EXPECT_FALSE(qcor::execute(qasm_code, "affine_scf"));
}

TEST(qasm3VisitorTester, checkCtrlDirectives) {
  const std::string uint_index = R"#(OPENQASM 3;
+3 −1
Original line number Diff line number Diff line
@@ -324,7 +324,9 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
            [&](mlir::Value loop_var) {
              // Create a new scope for the for loop
              symbol_table.enter_new_scope();
              symbol_table.add_symbol(idx_var_name, loop_var, {}, true);
              auto loop_var_cast = builder.create<mlir::IndexCastOp>(
                  location, builder.getI64Type(), loop_var);
              symbol_table.add_symbol(idx_var_name, loop_var_cast, {}, true);
              visitChildren(program_block);
              symbol_table.exit_scope();
            },