Loading mlir/parsers/qasm3/tests/test_optimization.cpp +52 −0 Original line number Diff line number Diff line Loading @@ -385,6 +385,58 @@ cx first_and_last_qubit[0], first_and_last_qubit[1]; } } TEST(qasm3PassManagerTester, checkConstEvalLoopUnroll) { { // Unroll the loop with const vars as loop bounds const std::string src = R"#(OPENQASM 3; include "stdgates.inc"; const n = 3; qubit qb; for i in [0:2*n] { x qb; } )#"; auto llvm = qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false); std::cout << "LLVM:\n" << llvm << "\n"; // Get the main kernel section only llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel")); const auto last = llvm.find_first_of("}"); llvm = llvm.substr(0, last + 1); std::cout << "LLVM:\n" << llvm << "\n"; // Cancel all EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 0); } { // Unroll the loop with const vars as loop bounds const std::string src = R"#(OPENQASM 3; include "stdgates.inc"; const n = 3; qubit qb; for i in [0:2*n + 1] { x qb; } )#"; auto llvm = qcor::mlir_compile(src, "test_kernel1", qcor::OutputType::LLVMIR, false); std::cout << "LLVM:\n" << llvm << "\n"; // Get the main kernel section only llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel1")); const auto last = llvm.find_first_of("}"); llvm = llvm.substr(0, last + 1); std::cout << "LLVM:\n" << llvm << "\n"; // One X gate left EXPECT_EQ(countSubstring(llvm, "__quantum__qis__x"), 1); } } int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); auto ret = RUN_ALL_TESTS(); Loading mlir/parsers/qasm3/utils/symbol_table.cpp +19 −3 Original line number Diff line number Diff line Loading @@ -26,7 +26,8 @@ void ScopedSymbolTable::replace_symbol(mlir::Value old_value, } } int64_t ScopedSymbolTable::evaluate_constant_integer_expression( std::optional<int64_t> ScopedSymbolTable::try_evaluate_constant_integer_expression( const std::string expr_str) { auto all_constants = get_constant_integer_variables(); std::vector<std::string> variable_names; Loading Loading @@ -58,7 +59,22 @@ int64_t ScopedSymbolTable::evaluate_constant_integer_expression( if (parser.compile(expr_str, expr)) { ref = expr.value(); } else { printErrorMessage("Failed to evaluate cnostant integer expression: " + // Cannot eval to a const int. return std::nullopt; } return (int64_t)ref; } int64_t ScopedSymbolTable::evaluate_constant_integer_expression( const std::string expr_str) { const std::optional<int64_t> try_eval = try_evaluate_constant_integer_expression(expr_str); double ref = 0.0; if (try_eval.has_value()) { ref = try_eval.value(); } else { printErrorMessage("Failed to evaluate constant integer expression: " + expr_str + ". Must be a constant integer type."); } Loading mlir/parsers/qasm3/utils/symbol_table.hpp +4 −0 Original line number Diff line number Diff line Loading @@ -221,7 +221,11 @@ class ScopedSymbolTable { return ret; } // Eval a const int expression (throw if failed) int64_t evaluate_constant_integer_expression(const std::string expr); // Returns null if this expression cannot be const-eval to an integer value. std::optional<int64_t> try_evaluate_constant_integer_expression(const std::string expr); mlir::FuncOp get_seen_function(const std::string name) { if (!seen_functions.count(name)) { Loading mlir/parsers/qasm3/visitor_handlers/loop_stmt_handler.cpp +48 −19 Original line number Diff line number Diff line Loading @@ -230,6 +230,15 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( c_value = get_or_create_constant_integer_value(c, location, int_type, symbol_table, builder); const auto const_eval_a_val = symbol_table.try_evaluate_constant_integer_expression( range->expression(0)->getText()); if (const_eval_a_val.has_value()) { // std::cout << "A val = " << const_eval_a_val.value() << "\n"; a_value = get_or_create_constant_integer_value(const_eval_a_val.value(), location, int_type, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name); exp_generator.visit(range->expression(0)); Loading @@ -237,8 +246,18 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( if (a_value.getType().isa<mlir::MemRefType>()) { a_value = builder.create<mlir::LoadOp>(location, a_value); } } if (n_expr == 3) { const auto const_eval_b_val = symbol_table.try_evaluate_constant_integer_expression( range->expression(2)->getText()); if (const_eval_b_val.has_value()) { // std::cout << "B val = " << const_eval_b_val.value() << "\n"; b_value = get_or_create_constant_integer_value( const_eval_b_val.value(), location, int_type, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name); exp_generator.visit(range->expression(2)); Loading @@ -246,7 +265,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( if (b_value.getType().isa<mlir::MemRefType>()) { b_value = builder.create<mlir::LoadOp>(location, b_value); } } if (symbol_table.has_symbol(range->expression(1)->getText())) { printErrorMessage("You must provide loop step as a constant value.", context); Loading @@ -263,6 +282,15 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( c, location, a_value.getType(), symbol_table, builder); } } else { const auto const_eval_b_val = symbol_table.try_evaluate_constant_integer_expression( range->expression(1)->getText()); if (const_eval_b_val.has_value()) { // std::cout << "B val = " << const_eval_b_val.value() << "\n"; b_value = get_or_create_constant_integer_value( const_eval_b_val.value(), location, int_type, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name); Loading @@ -272,6 +300,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( b_value = builder.create<mlir::LoadOp>(location, b_value); } } } const std::string program_block_str = program_block->getText(); // std::cout << "HOWDY:\n" << program_block_str << "\n"; Loading Loading
mlir/parsers/qasm3/tests/test_optimization.cpp +52 −0 Original line number Diff line number Diff line Loading @@ -385,6 +385,58 @@ cx first_and_last_qubit[0], first_and_last_qubit[1]; } } TEST(qasm3PassManagerTester, checkConstEvalLoopUnroll) { { // Unroll the loop with const vars as loop bounds const std::string src = R"#(OPENQASM 3; include "stdgates.inc"; const n = 3; qubit qb; for i in [0:2*n] { x qb; } )#"; auto llvm = qcor::mlir_compile(src, "test_kernel", qcor::OutputType::LLVMIR, false); std::cout << "LLVM:\n" << llvm << "\n"; // Get the main kernel section only llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel")); const auto last = llvm.find_first_of("}"); llvm = llvm.substr(0, last + 1); std::cout << "LLVM:\n" << llvm << "\n"; // Cancel all EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 0); } { // Unroll the loop with const vars as loop bounds const std::string src = R"#(OPENQASM 3; include "stdgates.inc"; const n = 3; qubit qb; for i in [0:2*n + 1] { x qb; } )#"; auto llvm = qcor::mlir_compile(src, "test_kernel1", qcor::OutputType::LLVMIR, false); std::cout << "LLVM:\n" << llvm << "\n"; // Get the main kernel section only llvm = llvm.substr(llvm.find("@__internal_mlir_test_kernel1")); const auto last = llvm.find_first_of("}"); llvm = llvm.substr(0, last + 1); std::cout << "LLVM:\n" << llvm << "\n"; // One X gate left EXPECT_EQ(countSubstring(llvm, "__quantum__qis__x"), 1); } } int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); auto ret = RUN_ALL_TESTS(); Loading
mlir/parsers/qasm3/utils/symbol_table.cpp +19 −3 Original line number Diff line number Diff line Loading @@ -26,7 +26,8 @@ void ScopedSymbolTable::replace_symbol(mlir::Value old_value, } } int64_t ScopedSymbolTable::evaluate_constant_integer_expression( std::optional<int64_t> ScopedSymbolTable::try_evaluate_constant_integer_expression( const std::string expr_str) { auto all_constants = get_constant_integer_variables(); std::vector<std::string> variable_names; Loading Loading @@ -58,7 +59,22 @@ int64_t ScopedSymbolTable::evaluate_constant_integer_expression( if (parser.compile(expr_str, expr)) { ref = expr.value(); } else { printErrorMessage("Failed to evaluate cnostant integer expression: " + // Cannot eval to a const int. return std::nullopt; } return (int64_t)ref; } int64_t ScopedSymbolTable::evaluate_constant_integer_expression( const std::string expr_str) { const std::optional<int64_t> try_eval = try_evaluate_constant_integer_expression(expr_str); double ref = 0.0; if (try_eval.has_value()) { ref = try_eval.value(); } else { printErrorMessage("Failed to evaluate constant integer expression: " + expr_str + ". Must be a constant integer type."); } Loading
mlir/parsers/qasm3/utils/symbol_table.hpp +4 −0 Original line number Diff line number Diff line Loading @@ -221,7 +221,11 @@ class ScopedSymbolTable { return ret; } // Eval a const int expression (throw if failed) int64_t evaluate_constant_integer_expression(const std::string expr); // Returns null if this expression cannot be const-eval to an integer value. std::optional<int64_t> try_evaluate_constant_integer_expression(const std::string expr); mlir::FuncOp get_seen_function(const std::string name) { if (!seen_functions.count(name)) { Loading
mlir/parsers/qasm3/visitor_handlers/loop_stmt_handler.cpp +48 −19 Original line number Diff line number Diff line Loading @@ -230,6 +230,15 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( c_value = get_or_create_constant_integer_value(c, location, int_type, symbol_table, builder); const auto const_eval_a_val = symbol_table.try_evaluate_constant_integer_expression( range->expression(0)->getText()); if (const_eval_a_val.has_value()) { // std::cout << "A val = " << const_eval_a_val.value() << "\n"; a_value = get_or_create_constant_integer_value(const_eval_a_val.value(), location, int_type, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name); exp_generator.visit(range->expression(0)); Loading @@ -237,8 +246,18 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( if (a_value.getType().isa<mlir::MemRefType>()) { a_value = builder.create<mlir::LoadOp>(location, a_value); } } if (n_expr == 3) { const auto const_eval_b_val = symbol_table.try_evaluate_constant_integer_expression( range->expression(2)->getText()); if (const_eval_b_val.has_value()) { // std::cout << "B val = " << const_eval_b_val.value() << "\n"; b_value = get_or_create_constant_integer_value( const_eval_b_val.value(), location, int_type, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name); exp_generator.visit(range->expression(2)); Loading @@ -246,7 +265,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( if (b_value.getType().isa<mlir::MemRefType>()) { b_value = builder.create<mlir::LoadOp>(location, b_value); } } if (symbol_table.has_symbol(range->expression(1)->getText())) { printErrorMessage("You must provide loop step as a constant value.", context); Loading @@ -263,6 +282,15 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( c, location, a_value.getType(), symbol_table, builder); } } else { const auto const_eval_b_val = symbol_table.try_evaluate_constant_integer_expression( range->expression(1)->getText()); if (const_eval_b_val.has_value()) { // std::cout << "B val = " << const_eval_b_val.value() << "\n"; b_value = get_or_create_constant_integer_value( const_eval_b_val.value(), location, int_type, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name); Loading @@ -272,6 +300,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement( b_value = builder.create<mlir::LoadOp>(location, b_value); } } } const std::string program_block_str = program_block->getText(); // std::cout << "HOWDY:\n" << program_block_str << "\n"; Loading