Unverified Commit 1fe64ee6 authored by Thien Nguyen's avatar Thien Nguyen Committed by GitHub
Browse files

Merge pull request #241 from tnguyen-ornl/tnguyen/fix-qcor-240

Fixed nested range-based for loops
parents 356a3e2f 75484499
Pipeline #174881 passed with stage
in 60 minutes and 57 seconds
......@@ -55,6 +55,14 @@ for i in [0:4] {
}
QCOR_EXPECT_TRUE(loop_count == 12);
loop_count = 0;
for i in [0:4] {
for j in [0:3] {
print(i,j);
loop_count += 1;
}
}
QCOR_EXPECT_TRUE(loop_count == 12);
)#";
auto mlir = qcor::mlir_compile(for_stmt, "for_stmt",
......
......@@ -147,13 +147,18 @@ mlir::Value get_or_create_constant_integer_value(
mlir::Value get_or_create_constant_index_value(const std::size_t idx,
mlir::Location location,
int width,
ScopedSymbolTable& symbol_table,
mlir::OpBuilder& builder) {
auto type = mlir::IntegerType::get(builder.getContext(), width);
auto constant_int = get_or_create_constant_integer_value(
idx, location, type, symbol_table, builder);
return builder.create<mlir::IndexCastOp>(location, constant_int,
builder.getIndexType());
ScopedSymbolTable &symbol_table,
mlir::OpBuilder &builder) {
if (symbol_table.has_constant_integer(idx, width)) {
// If there is a cached constant integer value, cast and return it:
auto constant_int = symbol_table.get_constant_integer(idx, width);
return builder.create<mlir::IndexCastOp>(location, constant_int,
builder.getIndexType());
} else {
// Otherwise, create a new constant index value
auto integer_attr = mlir::IntegerAttr::get(builder.getIndexType(), idx);
return builder.create<mlir::ConstantOp>(location, integer_attr);
}
}
mlir::Type convertQasm3Type(qasm3::qasm3Parser::ClassicalTypeContext* ctx,
......
......@@ -156,43 +156,21 @@ void qasm3_visitor::createRangeBasedForLoop(
auto n_expr = range->expression().size();
int a, b, c;
// First question what type should we use?
mlir::Type int_type = builder.getI64Type();
if (symbol_table.has_symbol(range->expression(0)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(0)->getText()).getType();
}
if (n_expr == 3) {
if (symbol_table.has_symbol(range->expression(1)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(1)->getText()).getType();
} else if (symbol_table.has_symbol(range->expression(2)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(2)->getText()).getType();
}
} else {
if (symbol_table.has_symbol(range->expression(1)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(1)->getText()).getType();
}
}
if (int_type.isa<mlir::MemRefType>()) {
int_type = int_type.cast<mlir::MemRefType>().getElementType();
}
// For loop variables will be index type (casting will be done if needed)
mlir::Type index_type = builder.getIndexType();
c = 1;
mlir::Value a_value, b_value,
c_value = get_or_create_constant_integer_value(c, location, int_type,
symbol_table, builder);
c_value = get_or_create_constant_index_value(c, location, 64,
symbol_table, builder);
const auto const_eval_a_val =
symbol_table.try_evaluate_constant_integer_expression(
range->expression(0)->getText());
if (const_eval_a_val.has_value()) {
// std::cout << "A val = " << const_eval_a_val.value() << "\n";
a_value = get_or_create_constant_integer_value(
const_eval_a_val.value(), location, int_type, symbol_table, builder);
a_value = get_or_create_constant_index_value(
const_eval_a_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
exp_generator.visit(range->expression(0));
......@@ -208,8 +186,8 @@ void qasm3_visitor::createRangeBasedForLoop(
range->expression(2)->getText());
if (const_eval_b_val.has_value()) {
// std::cout << "B val = " << const_eval_b_val.value() << "\n";
b_value = get_or_create_constant_integer_value(
const_eval_b_val.value(), location, int_type, symbol_table, builder);
b_value = get_or_create_constant_index_value(
const_eval_b_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table,
file_name);
......@@ -231,8 +209,8 @@ void qasm3_visitor::createRangeBasedForLoop(
} else {
c = symbol_table.evaluate_constant_integer_expression(
range->expression(1)->getText());
c_value = get_or_create_constant_integer_value(
c, location, a_value.getType(), symbol_table, builder);
c_value = get_or_create_constant_index_value(
c, location, 64, symbol_table, builder);
}
} else {
......@@ -241,8 +219,8 @@ void qasm3_visitor::createRangeBasedForLoop(
range->expression(1)->getText());
if (const_eval_b_val.has_value()) {
// std::cout << "B val = " << const_eval_b_val.value() << "\n";
b_value = get_or_create_constant_integer_value(
const_eval_b_val.value(), location, int_type, symbol_table, builder);
b_value = get_or_create_constant_index_value(
const_eval_b_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table,
file_name);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment