Commit 75484499 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Fixed nested range-based for loops

The unnecessary casting ops between integer and index types cause the Affine for validation failed. This only occurs when the constant values are created inside another affine scope.

Fixing the const index value builder to create the index const directly and the loop statement handler in the range-based case to create index type.

Added the test case (https://github.com/ORNL-QCI/qcor/issues/240

)

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 356a3e2f
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -55,6 +55,14 @@ for i in [0:4] {
}
QCOR_EXPECT_TRUE(loop_count == 12);

loop_count = 0;
for i in [0:4] {
 for j in [0:3] {
     print(i,j);
     loop_count += 1;
 }
}
QCOR_EXPECT_TRUE(loop_count == 12);

)#";
  auto mlir = qcor::mlir_compile(for_stmt, "for_stmt",
+12 −7
Original line number Diff line number Diff line
@@ -149,11 +149,16 @@ mlir::Value get_or_create_constant_index_value(const std::size_t idx,
                                               int width,
                                               ScopedSymbolTable &symbol_table,
                                               mlir::OpBuilder &builder) {
  auto type = mlir::IntegerType::get(builder.getContext(), width);
  auto constant_int = get_or_create_constant_integer_value(
      idx, location, type, symbol_table, builder);
  if (symbol_table.has_constant_integer(idx, width)) {
    // If there is a cached constant integer value, cast and return it:
    auto constant_int = symbol_table.get_constant_integer(idx, width);
    return builder.create<mlir::IndexCastOp>(location, constant_int,
                                             builder.getIndexType());
  } else {
    // Otherwise, create a new constant index value
    auto integer_attr = mlir::IntegerAttr::get(builder.getIndexType(), idx);
    return builder.create<mlir::ConstantOp>(location, integer_attr);
  }
}

mlir::Type convertQasm3Type(qasm3::qasm3Parser::ClassicalTypeContext* ctx,
+12 −34
Original line number Diff line number Diff line
@@ -156,34 +156,12 @@ void qasm3_visitor::createRangeBasedForLoop(
  auto n_expr = range->expression().size();
  int a, b, c;

  // First question what type should we use?
  mlir::Type int_type = builder.getI64Type();
  if (symbol_table.has_symbol(range->expression(0)->getText())) {
    int_type =
        symbol_table.get_symbol(range->expression(0)->getText()).getType();
  }
  if (n_expr == 3) {
    if (symbol_table.has_symbol(range->expression(1)->getText())) {
      int_type =
          symbol_table.get_symbol(range->expression(1)->getText()).getType();
    } else if (symbol_table.has_symbol(range->expression(2)->getText())) {
      int_type =
          symbol_table.get_symbol(range->expression(2)->getText()).getType();
    }
  } else {
    if (symbol_table.has_symbol(range->expression(1)->getText())) {
      int_type =
          symbol_table.get_symbol(range->expression(1)->getText()).getType();
    }
  }

  if (int_type.isa<mlir::MemRefType>()) {
    int_type = int_type.cast<mlir::MemRefType>().getElementType();
  }
  // For loop variables will be index type (casting will be done if needed)
  mlir::Type index_type = builder.getIndexType();

  c = 1;
  mlir::Value a_value, b_value,
      c_value = get_or_create_constant_integer_value(c, location, int_type,
      c_value = get_or_create_constant_index_value(c, location, 64,
                                                   symbol_table, builder);

  const auto const_eval_a_val =
@@ -191,8 +169,8 @@ void qasm3_visitor::createRangeBasedForLoop(
          range->expression(0)->getText());
  if (const_eval_a_val.has_value()) {
    // std::cout << "A val = " << const_eval_a_val.value() << "\n";
    a_value = get_or_create_constant_integer_value(
        const_eval_a_val.value(), location, int_type, symbol_table, builder);
    a_value = get_or_create_constant_index_value(
        const_eval_a_val.value(), location, 64, symbol_table, builder);
  } else {
    qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
    exp_generator.visit(range->expression(0));
@@ -208,8 +186,8 @@ void qasm3_visitor::createRangeBasedForLoop(
            range->expression(2)->getText());
    if (const_eval_b_val.has_value()) {
      // std::cout << "B val = " << const_eval_b_val.value() << "\n";
      b_value = get_or_create_constant_integer_value(
          const_eval_b_val.value(), location, int_type, symbol_table, builder);
      b_value = get_or_create_constant_index_value(
          const_eval_b_val.value(), location, 64, symbol_table, builder);
    } else {
      qasm3_expression_generator exp_generator(builder, symbol_table,
                                               file_name);
@@ -231,8 +209,8 @@ void qasm3_visitor::createRangeBasedForLoop(
    } else {
      c = symbol_table.evaluate_constant_integer_expression(
          range->expression(1)->getText());
      c_value = get_or_create_constant_integer_value(
          c, location, a_value.getType(), symbol_table, builder);
      c_value = get_or_create_constant_index_value(
          c, location, 64, symbol_table, builder);
    }

  } else {
@@ -241,8 +219,8 @@ void qasm3_visitor::createRangeBasedForLoop(
            range->expression(1)->getText());
    if (const_eval_b_val.has_value()) {
      // std::cout << "B val = " << const_eval_b_val.value() << "\n";
      b_value = get_or_create_constant_integer_value(
          const_eval_b_val.value(), location, int_type, symbol_table, builder);
      b_value = get_or_create_constant_index_value(
          const_eval_b_val.value(), location, 64, symbol_table, builder);
    } else {
      qasm3_expression_generator exp_generator(builder, symbol_table,
                                               file_name);