Unverified Commit 1fe64ee6 authored by Thien Nguyen's avatar Thien Nguyen Committed by GitHub
Browse files

Merge pull request #241 from tnguyen-ornl/tnguyen/fix-qcor-240

Fixed nested range-based for loops
parents 356a3e2f 75484499
Loading
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -55,6 +55,14 @@ for i in [0:4] {
}
QCOR_EXPECT_TRUE(loop_count == 12);

loop_count = 0;
for i in [0:4] {
 for j in [0:3] {
     print(i,j);
     loop_count += 1;
 }
}
QCOR_EXPECT_TRUE(loop_count == 12);

)#";
  auto mlir = qcor::mlir_compile(for_stmt, "for_stmt",
+12 −7
Original line number Diff line number Diff line
@@ -149,11 +149,16 @@ mlir::Value get_or_create_constant_index_value(const std::size_t idx,
                                               int width,
                                               ScopedSymbolTable &symbol_table,
                                               mlir::OpBuilder &builder) {
  auto type = mlir::IntegerType::get(builder.getContext(), width);
  auto constant_int = get_or_create_constant_integer_value(
      idx, location, type, symbol_table, builder);
  if (symbol_table.has_constant_integer(idx, width)) {
    // If there is a cached constant integer value, cast and return it:
    auto constant_int = symbol_table.get_constant_integer(idx, width);
    return builder.create<mlir::IndexCastOp>(location, constant_int,
                                             builder.getIndexType());
  } else {
    // Otherwise, create a new constant index value
    auto integer_attr = mlir::IntegerAttr::get(builder.getIndexType(), idx);
    return builder.create<mlir::ConstantOp>(location, integer_attr);
  }
}

mlir::Type convertQasm3Type(qasm3::qasm3Parser::ClassicalTypeContext* ctx,
+12 −34
Original line number Diff line number Diff line
@@ -156,34 +156,12 @@ void qasm3_visitor::createRangeBasedForLoop(
  auto n_expr = range->expression().size();
  int a, b, c;

  // First question what type should we use?
  mlir::Type int_type = builder.getI64Type();
  if (symbol_table.has_symbol(range->expression(0)->getText())) {
    int_type =
        symbol_table.get_symbol(range->expression(0)->getText()).getType();
  }
  if (n_expr == 3) {
    if (symbol_table.has_symbol(range->expression(1)->getText())) {
      int_type =
          symbol_table.get_symbol(range->expression(1)->getText()).getType();
    } else if (symbol_table.has_symbol(range->expression(2)->getText())) {
      int_type =
          symbol_table.get_symbol(range->expression(2)->getText()).getType();
    }
  } else {
    if (symbol_table.has_symbol(range->expression(1)->getText())) {
      int_type =
          symbol_table.get_symbol(range->expression(1)->getText()).getType();
    }
  }

  if (int_type.isa<mlir::MemRefType>()) {
    int_type = int_type.cast<mlir::MemRefType>().getElementType();
  }
  // For loop variables will be index type (casting will be done if needed)
  mlir::Type index_type = builder.getIndexType();

  c = 1;
  mlir::Value a_value, b_value,
      c_value = get_or_create_constant_integer_value(c, location, int_type,
      c_value = get_or_create_constant_index_value(c, location, 64,
                                                   symbol_table, builder);

  const auto const_eval_a_val =
@@ -191,8 +169,8 @@ void qasm3_visitor::createRangeBasedForLoop(
          range->expression(0)->getText());
  if (const_eval_a_val.has_value()) {
    // std::cout << "A val = " << const_eval_a_val.value() << "\n";
    a_value = get_or_create_constant_integer_value(
        const_eval_a_val.value(), location, int_type, symbol_table, builder);
    a_value = get_or_create_constant_index_value(
        const_eval_a_val.value(), location, 64, symbol_table, builder);
  } else {
    qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
    exp_generator.visit(range->expression(0));
@@ -208,8 +186,8 @@ void qasm3_visitor::createRangeBasedForLoop(
            range->expression(2)->getText());
    if (const_eval_b_val.has_value()) {
      // std::cout << "B val = " << const_eval_b_val.value() << "\n";
      b_value = get_or_create_constant_integer_value(
          const_eval_b_val.value(), location, int_type, symbol_table, builder);
      b_value = get_or_create_constant_index_value(
          const_eval_b_val.value(), location, 64, symbol_table, builder);
    } else {
      qasm3_expression_generator exp_generator(builder, symbol_table,
                                               file_name);
@@ -231,8 +209,8 @@ void qasm3_visitor::createRangeBasedForLoop(
    } else {
      c = symbol_table.evaluate_constant_integer_expression(
          range->expression(1)->getText());
      c_value = get_or_create_constant_integer_value(
          c, location, a_value.getType(), symbol_table, builder);
      c_value = get_or_create_constant_index_value(
          c, location, 64, symbol_table, builder);
    }

  } else {
@@ -241,8 +219,8 @@ void qasm3_visitor::createRangeBasedForLoop(
            range->expression(1)->getText());
    if (const_eval_b_val.has_value()) {
      // std::cout << "B val = " << const_eval_b_val.value() << "\n";
      b_value = get_or_create_constant_integer_value(
          const_eval_b_val.value(), location, int_type, symbol_table, builder);
      b_value = get_or_create_constant_index_value(
          const_eval_b_val.value(), location, 64, symbol_table, builder);
    } else {
      qasm3_expression_generator exp_generator(builder, symbol_table,
                                               file_name);