Commit 4d8dee7f authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Fixed for unit tests



Affine for loops cannot be mixed with other manual branching structures yet. Need to debug further; disable for now.

The affine for validation failure seems to be related to IndexType, not memref.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 60634256
Loading
Loading
Loading
Loading
+18 −12
Original line number Diff line number Diff line
@@ -12,6 +12,14 @@ namespace {
void affineLoopBuilder(mlir::Value lbs_val, mlir::Value ubs_val, int64_t step,
                       std::function<void(mlir::Value)> bodyBuilderFn,
                       mlir::OpBuilder &builder, mlir::Location &loc) {
  if (!ubs_val.getType().isa<mlir::IndexType>()) {
    ubs_val =
        builder.create<mlir::IndexCastOp>(loc, builder.getIndexType(), ubs_val);
  }
  if (!lbs_val.getType().isa<mlir::IndexType>()) {
    lbs_val =
        builder.create<mlir::IndexCastOp>(loc, builder.getIndexType(), lbs_val);
  }
  // Note: Affine for loop only accepts **positive** step:
  // The stride, represented by step, is a positive constant integer which
  // defaults to “1” if not present.
@@ -222,18 +230,12 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
          c_value = get_or_create_constant_integer_value(c, location, int_type,
                                                         symbol_table, builder);

      // Either a_value or b_value (loop bounds) is a memref
      // (For some reason, affine loop inliner doesn't work in this case, 
      // causing some validation errors)
      bool loop_bounds_are_memref = false;
      
      qasm3_expression_generator exp_generator(builder, symbol_table,
                                               file_name);
      exp_generator.visit(range->expression(0));
      a_value = exp_generator.current_value;
      if (a_value.getType().isa<mlir::MemRefType>()) {
        a_value = builder.create<mlir::LoadOp>(location, a_value);
        loop_bounds_are_memref = true;
      }

      if (n_expr == 3) {
@@ -243,7 +245,6 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        b_value = exp_generator.current_value;
        if (b_value.getType().isa<mlir::MemRefType>()) {
          b_value = builder.create<mlir::LoadOp>(location, b_value);
          loop_bounds_are_memref = true;
        }

        if (symbol_table.has_symbol(range->expression(1)->getText())) {
@@ -269,7 +270,6 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        b_value = exp_generator.current_value;
        if (b_value.getType().isa<mlir::MemRefType>()) {
          b_value = builder.create<mlir::LoadOp>(location, b_value);
          loop_bounds_are_memref = true;
        }
      }

@@ -278,10 +278,16 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

      // HACK: Currently, we don't handle 'if', 'break', 'continue'
      // in the Affine for loop yet.
      if (!loop_bounds_are_memref &&
          program_block_str.find("if") == std::string::npos &&
      if (program_block_str.find("if") == std::string::npos &&
          program_block_str.find("break") == std::string::npos &&
          program_block_str.find("continue") == std::string::npos) {
          program_block_str.find("continue") == std::string::npos &&
          // This is equivalent to an "if"
          program_block_str.find("QCOR_EXPECT_TRUE") == std::string::npos &&
          // We can only handle nested for loops if the inner one is also an
          // affine one For now, don't do that since we're not sure.
          program_block_str.find("for") == std::string::npos &&
          // While loop is not converted to affine yet.
          program_block_str.find("while") == std::string::npos) {
        // Can use Affine for loop....
        affineLoopBuilder(
            a_value, b_value, c,