Commit 75484499 authored by Nguyen, Thien's avatar Nguyen, Thien
Browse files

Fixed nested range-based for loops

The unnecessary casting ops between integer and index types cause the Affine for validation failed. This only occurs when the constant values are created inside another affine scope.

Fixing the const index value builder to create the index const directly and the loop statement handler in the range-based case to create index type.

Added the test case (https://github.com/ORNL-QCI/qcor/issues/240

)
Signed-off-by: Nguyen, Thien's avatarThien Nguyen <nguyentm@ornl.gov>
parent 356a3e2f
......@@ -55,6 +55,14 @@ for i in [0:4] {
}
QCOR_EXPECT_TRUE(loop_count == 12);
loop_count = 0;
for i in [0:4] {
for j in [0:3] {
print(i,j);
loop_count += 1;
}
}
QCOR_EXPECT_TRUE(loop_count == 12);
)#";
auto mlir = qcor::mlir_compile(for_stmt, "for_stmt",
......
......@@ -147,13 +147,18 @@ mlir::Value get_or_create_constant_integer_value(
mlir::Value get_or_create_constant_index_value(const std::size_t idx,
mlir::Location location,
int width,
ScopedSymbolTable& symbol_table,
mlir::OpBuilder& builder) {
auto type = mlir::IntegerType::get(builder.getContext(), width);
auto constant_int = get_or_create_constant_integer_value(
idx, location, type, symbol_table, builder);
return builder.create<mlir::IndexCastOp>(location, constant_int,
builder.getIndexType());
ScopedSymbolTable &symbol_table,
mlir::OpBuilder &builder) {
if (symbol_table.has_constant_integer(idx, width)) {
// If there is a cached constant integer value, cast and return it:
auto constant_int = symbol_table.get_constant_integer(idx, width);
return builder.create<mlir::IndexCastOp>(location, constant_int,
builder.getIndexType());
} else {
// Otherwise, create a new constant index value
auto integer_attr = mlir::IntegerAttr::get(builder.getIndexType(), idx);
return builder.create<mlir::ConstantOp>(location, integer_attr);
}
}
mlir::Type convertQasm3Type(qasm3::qasm3Parser::ClassicalTypeContext* ctx,
......
......@@ -156,43 +156,21 @@ void qasm3_visitor::createRangeBasedForLoop(
auto n_expr = range->expression().size();
int a, b, c;
// First question what type should we use?
mlir::Type int_type = builder.getI64Type();
if (symbol_table.has_symbol(range->expression(0)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(0)->getText()).getType();
}
if (n_expr == 3) {
if (symbol_table.has_symbol(range->expression(1)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(1)->getText()).getType();
} else if (symbol_table.has_symbol(range->expression(2)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(2)->getText()).getType();
}
} else {
if (symbol_table.has_symbol(range->expression(1)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(1)->getText()).getType();
}
}
if (int_type.isa<mlir::MemRefType>()) {
int_type = int_type.cast<mlir::MemRefType>().getElementType();
}
// For loop variables will be index type (casting will be done if needed)
mlir::Type index_type = builder.getIndexType();
c = 1;
mlir::Value a_value, b_value,
c_value = get_or_create_constant_integer_value(c, location, int_type,
symbol_table, builder);
c_value = get_or_create_constant_index_value(c, location, 64,
symbol_table, builder);
const auto const_eval_a_val =
symbol_table.try_evaluate_constant_integer_expression(
range->expression(0)->getText());
if (const_eval_a_val.has_value()) {
// std::cout << "A val = " << const_eval_a_val.value() << "\n";
a_value = get_or_create_constant_integer_value(
const_eval_a_val.value(), location, int_type, symbol_table, builder);
a_value = get_or_create_constant_index_value(
const_eval_a_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
exp_generator.visit(range->expression(0));
......@@ -208,8 +186,8 @@ void qasm3_visitor::createRangeBasedForLoop(
range->expression(2)->getText());
if (const_eval_b_val.has_value()) {
// std::cout << "B val = " << const_eval_b_val.value() << "\n";
b_value = get_or_create_constant_integer_value(
const_eval_b_val.value(), location, int_type, symbol_table, builder);
b_value = get_or_create_constant_index_value(
const_eval_b_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table,
file_name);
......@@ -231,8 +209,8 @@ void qasm3_visitor::createRangeBasedForLoop(
} else {
c = symbol_table.evaluate_constant_integer_expression(
range->expression(1)->getText());
c_value = get_or_create_constant_integer_value(
c, location, a_value.getType(), symbol_table, builder);
c_value = get_or_create_constant_index_value(
c, location, 64, symbol_table, builder);
}
} else {
......@@ -241,8 +219,8 @@ void qasm3_visitor::createRangeBasedForLoop(
range->expression(1)->getText());
if (const_eval_b_val.has_value()) {
// std::cout << "B val = " << const_eval_b_val.value() << "\n";
b_value = get_or_create_constant_integer_value(
const_eval_b_val.value(), location, int_type, symbol_table, builder);
b_value = get_or_create_constant_index_value(
const_eval_b_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table,
file_name);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment