Loading mlir/parsers/qasm3/openqasmv3_mlir_generator.cpp +183 −141 Original line number Diff line number Diff line Loading @@ -7,13 +7,47 @@ namespace qcor { void OpenQasmV3MLIRGenerator::initialize_mlirgen( const std::string func_name, std::vector<mlir::Type> arg_types, std::vector<std::string> arg_var_names, mlir::Type return_type) { mlir::FunctionType func_type2; if (return_type) { func_type2 = builder.getFunctionType(llvm::makeArrayRef(arg_types), return_type); } else { func_type2 = builder.getFunctionType(llvm::makeArrayRef(arg_types), llvm::None); } auto proto2 = mlir::FuncOp::create( builder.getUnknownLoc(), "__internal_mlir_" + func_name, func_type2); mlir::FuncOp function2(proto2); std::string file_name = "internal_mlirgen_qcor_"; auto save_main_entry_block = function2.addEntryBlock(); builder.setInsertionPointToStart(save_main_entry_block); m_module.push_back(function2); main_entry_block = save_main_entry_block; // Configure block arguments visitor = std::make_shared<qasm3_visitor>(builder, m_module, file_name); auto symbol_table = visitor->getScopedSymbolTable(); auto arguments = main_entry_block->getArguments(); for (int i = 0; i < arg_var_names.size(); i++) { symbol_table->add_symbol(arg_var_names[i], arguments[i]); } add_main = false; if (!return_type) { add_custom_return = true; } return; } void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point, const std::string function) { file_name = function; add_entry_point = _add_entry_point; m_module = mlir::ModuleOp::create(builder.getUnknownLoc()); // Useful opaque type defs llvm::StringRef qubit_type_name("Qubit"), array_type_name("Array"), result_type_name("Result"); Loading @@ -38,8 +72,8 @@ void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point, if (add_entry_point) { std::vector<mlir::Type> arg_types_vec{int_type, argv_type}; auto func_type = builder.getFunctionType(llvm::makeArrayRef(arg_types_vec), int_type); auto func_type = builder.getFunctionType( llvm::makeArrayRef(arg_types_vec), int_type); auto proto = mlir::FuncOp::create(builder.getUnknownLoc(), "main", func_type); mlir::FuncOp function(proto); Loading Loading @@ -97,7 +131,9 @@ void OpenQasmV3MLIRGenerator::mlirgen(const std::string &src) { using namespace antlr4; using namespace qasm3; if (!visitor) { visitor = std::make_shared<qasm3_visitor>(builder, m_module, file_name); } ANTLRInputStream input(src); qasm3Lexer lexer(&input); Loading Loading @@ -139,34 +175,40 @@ void OpenQasmV3MLIRGenerator::mlirgen(const std::string &src) { void OpenQasmV3MLIRGenerator::finalize_mlirgen() { auto scoped_symbol_table = visitor->getScopedSymbolTable(); if (auto b = scoped_symbol_table.get_last_created_block()) { if (auto b = scoped_symbol_table->get_last_created_block()) { builder.setInsertionPointToEnd(b); } auto all_qalloc_ops = scoped_symbol_table.get_global_symbols_of_type<mlir::quantum::QallocOp>(); scoped_symbol_table ->get_global_symbols_of_type<mlir::quantum::QallocOp>(); for (auto op : all_qalloc_ops) { builder.create<mlir::quantum::DeallocOp>(builder.getUnknownLoc(), op); } // Add any function names that we created. auto fnames = scoped_symbol_table.get_seen_function_names(); auto fnames = scoped_symbol_table->get_seen_function_names(); for (auto f : fnames) { function_names.push_back(f); } if (add_main) { if (auto b = scoped_symbol_table.get_last_created_block()) { if (auto b = scoped_symbol_table->get_last_created_block()) { builder.setInsertionPointToEnd(b); } else { builder.setInsertionPointToEnd(main_entry_block); } auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 0); auto ret = builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr); auto ret = builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr); builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::ArrayRef<mlir::Value>(ret)); } if (add_custom_return) { builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::ArrayRef<mlir::Value>()); } } } // namespace qcor No newline at end of file mlir/parsers/qasm3/openqasmv3_mlir_generator.hpp +16 −1 Original line number Diff line number Diff line Loading @@ -10,6 +10,10 @@ class OpenQasmV3MLIRGenerator : public qcor::QuantumMLIRGenerator { protected: std::string file_name = "main"; bool add_entry_point = true; bool add_custom_return = false; mlir::Type return_type; mlir::Type qubit_type; mlir::Type array_type; mlir::Type result_type; Loading @@ -20,7 +24,18 @@ class OpenQasmV3MLIRGenerator : public qcor::QuantumMLIRGenerator { public: OpenQasmV3MLIRGenerator(mlir::MLIRContext &context) : QuantumMLIRGenerator(context) {} : QuantumMLIRGenerator(context) { m_module = mlir::ModuleOp::create(builder.getUnknownLoc()); } OpenQasmV3MLIRGenerator(mlir::OpBuilder b, mlir::MLIRContext &ctx) : QuantumMLIRGenerator(b, ctx) { m_module = mlir::ModuleOp::create(builder.getUnknownLoc()); } void initialize_mlirgen(const std::string func_name, std::vector<mlir::Type> arg_types, std::vector<std::string> arg_var_names, mlir::Type return_type); void initialize_mlirgen(bool add_entry_point = true, const std::string file_name = "") override; void mlirgen(const std::string &src) override; Loading mlir/parsers/qasm3/qasm3_visitor.hpp +1 −1 Original line number Diff line number Diff line Loading @@ -27,7 +27,7 @@ namespace qcor { class qasm3_visitor : public qasm3::qasm3BaseVisitor { public: // Return the symbol table. ScopedSymbolTable& getScopedSymbolTable() { return symbol_table; } ScopedSymbolTable* getScopedSymbolTable() { return &symbol_table; } // The constructor, instantiates commonly used opaque types qasm3_visitor(mlir::OpBuilder b, mlir::ModuleOp m, std::string& fname) Loading mlir/parsers/qasm3/visitor_handlers/subroutine_handler.cpp +21 −12 Original line number Diff line number Diff line Loading @@ -166,7 +166,8 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement( value = symbol_table.get_symbol(ret_stmt); // Actually return value if it is a bit[], // load and return if it is a bit // printErrorMessage("Putting this here til I fix this"); if (current_function_return_type) { // this means it is a subroutine if (!current_function_return_type.isa<mlir::MemRefType>()) { if (current_function_return_type.isa<mlir::IntegerType>() && current_function_return_type.getIntOrFloatBitWidth() == 1) { Loading @@ -176,10 +177,19 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement( llvm::ArrayRef<mlir::Value> zero_index(tmp); value = builder.create<mlir::LoadOp>(location, value, zero_index); } else { value = builder.create<mlir::LoadOp>(location, value); //, zero_index); value = builder.create<mlir::LoadOp>(location, value); //, zero_index); } } else { printErrorMessage("We do not return memrefs from subroutines.", context); } } else { printErrorMessage("We do not return memrefs from subroutines.", context); if (auto t = value.getType().dyn_cast_or_null<mlir::MemRefType>()) { if (t.getRank() == 0) { value = builder.create<mlir::LoadOp>(location, value); } } } } else { Loading @@ -192,7 +202,6 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement( visitChildren(context->statement()); value = symbol_table.get_last_value_added(); } } is_return_stmt = false; Loading mlir/transforms/lowering/PrintOpLowering.cpp +13 −0 Original line number Diff line number Diff line Loading @@ -114,6 +114,11 @@ LogicalResult PrintOpLowering::matchAndRewrite( frmt_spec += "%d"; } ss << "_bit_array_b_" << dim; } else if (mem_ref_type.getElementType().isa<mlir::IntegerType>() && mem_ref_type.getRank() == 0 && mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) { frmt_spec += "%d"; ss << "_bit_array_b_0"; } } else { std::cout << "Currently invalid type to print.\n"; Loading Loading @@ -155,6 +160,7 @@ LogicalResult PrintOpLowering::matchAndRewrite( mem_ref_type.getRank() > 0 && mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) { // This is a bit array... auto dim = mem_ref_type.getShape()[0]; for (int i = 0; i < dim; i++) { auto attr = mlir::IntegerAttr::get(rewriter.getIndexType(), i); Loading @@ -163,6 +169,13 @@ LogicalResult PrintOpLowering::matchAndRewrite( loc, o, llvm::makeArrayRef(std::vector<mlir::Value>{ii})); args.push_back(z); } continue; } else if (mem_ref_type.getElementType().isa<mlir::IntegerType>() && mem_ref_type.getRank() == 0 && mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) { auto z = rewriter.create<mlir::LoadOp>(loc, o); args.push_back(z); continue; } } Loading Loading
mlir/parsers/qasm3/openqasmv3_mlir_generator.cpp +183 −141 Original line number Diff line number Diff line Loading @@ -7,13 +7,47 @@ namespace qcor { void OpenQasmV3MLIRGenerator::initialize_mlirgen( const std::string func_name, std::vector<mlir::Type> arg_types, std::vector<std::string> arg_var_names, mlir::Type return_type) { mlir::FunctionType func_type2; if (return_type) { func_type2 = builder.getFunctionType(llvm::makeArrayRef(arg_types), return_type); } else { func_type2 = builder.getFunctionType(llvm::makeArrayRef(arg_types), llvm::None); } auto proto2 = mlir::FuncOp::create( builder.getUnknownLoc(), "__internal_mlir_" + func_name, func_type2); mlir::FuncOp function2(proto2); std::string file_name = "internal_mlirgen_qcor_"; auto save_main_entry_block = function2.addEntryBlock(); builder.setInsertionPointToStart(save_main_entry_block); m_module.push_back(function2); main_entry_block = save_main_entry_block; // Configure block arguments visitor = std::make_shared<qasm3_visitor>(builder, m_module, file_name); auto symbol_table = visitor->getScopedSymbolTable(); auto arguments = main_entry_block->getArguments(); for (int i = 0; i < arg_var_names.size(); i++) { symbol_table->add_symbol(arg_var_names[i], arguments[i]); } add_main = false; if (!return_type) { add_custom_return = true; } return; } void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point, const std::string function) { file_name = function; add_entry_point = _add_entry_point; m_module = mlir::ModuleOp::create(builder.getUnknownLoc()); // Useful opaque type defs llvm::StringRef qubit_type_name("Qubit"), array_type_name("Array"), result_type_name("Result"); Loading @@ -38,8 +72,8 @@ void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point, if (add_entry_point) { std::vector<mlir::Type> arg_types_vec{int_type, argv_type}; auto func_type = builder.getFunctionType(llvm::makeArrayRef(arg_types_vec), int_type); auto func_type = builder.getFunctionType( llvm::makeArrayRef(arg_types_vec), int_type); auto proto = mlir::FuncOp::create(builder.getUnknownLoc(), "main", func_type); mlir::FuncOp function(proto); Loading Loading @@ -97,7 +131,9 @@ void OpenQasmV3MLIRGenerator::mlirgen(const std::string &src) { using namespace antlr4; using namespace qasm3; if (!visitor) { visitor = std::make_shared<qasm3_visitor>(builder, m_module, file_name); } ANTLRInputStream input(src); qasm3Lexer lexer(&input); Loading Loading @@ -139,34 +175,40 @@ void OpenQasmV3MLIRGenerator::mlirgen(const std::string &src) { void OpenQasmV3MLIRGenerator::finalize_mlirgen() { auto scoped_symbol_table = visitor->getScopedSymbolTable(); if (auto b = scoped_symbol_table.get_last_created_block()) { if (auto b = scoped_symbol_table->get_last_created_block()) { builder.setInsertionPointToEnd(b); } auto all_qalloc_ops = scoped_symbol_table.get_global_symbols_of_type<mlir::quantum::QallocOp>(); scoped_symbol_table ->get_global_symbols_of_type<mlir::quantum::QallocOp>(); for (auto op : all_qalloc_ops) { builder.create<mlir::quantum::DeallocOp>(builder.getUnknownLoc(), op); } // Add any function names that we created. auto fnames = scoped_symbol_table.get_seen_function_names(); auto fnames = scoped_symbol_table->get_seen_function_names(); for (auto f : fnames) { function_names.push_back(f); } if (add_main) { if (auto b = scoped_symbol_table.get_last_created_block()) { if (auto b = scoped_symbol_table->get_last_created_block()) { builder.setInsertionPointToEnd(b); } else { builder.setInsertionPointToEnd(main_entry_block); } auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 0); auto ret = builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr); auto ret = builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr); builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::ArrayRef<mlir::Value>(ret)); } if (add_custom_return) { builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::ArrayRef<mlir::Value>()); } } } // namespace qcor No newline at end of file
mlir/parsers/qasm3/openqasmv3_mlir_generator.hpp +16 −1 Original line number Diff line number Diff line Loading @@ -10,6 +10,10 @@ class OpenQasmV3MLIRGenerator : public qcor::QuantumMLIRGenerator { protected: std::string file_name = "main"; bool add_entry_point = true; bool add_custom_return = false; mlir::Type return_type; mlir::Type qubit_type; mlir::Type array_type; mlir::Type result_type; Loading @@ -20,7 +24,18 @@ class OpenQasmV3MLIRGenerator : public qcor::QuantumMLIRGenerator { public: OpenQasmV3MLIRGenerator(mlir::MLIRContext &context) : QuantumMLIRGenerator(context) {} : QuantumMLIRGenerator(context) { m_module = mlir::ModuleOp::create(builder.getUnknownLoc()); } OpenQasmV3MLIRGenerator(mlir::OpBuilder b, mlir::MLIRContext &ctx) : QuantumMLIRGenerator(b, ctx) { m_module = mlir::ModuleOp::create(builder.getUnknownLoc()); } void initialize_mlirgen(const std::string func_name, std::vector<mlir::Type> arg_types, std::vector<std::string> arg_var_names, mlir::Type return_type); void initialize_mlirgen(bool add_entry_point = true, const std::string file_name = "") override; void mlirgen(const std::string &src) override; Loading
mlir/parsers/qasm3/qasm3_visitor.hpp +1 −1 Original line number Diff line number Diff line Loading @@ -27,7 +27,7 @@ namespace qcor { class qasm3_visitor : public qasm3::qasm3BaseVisitor { public: // Return the symbol table. ScopedSymbolTable& getScopedSymbolTable() { return symbol_table; } ScopedSymbolTable* getScopedSymbolTable() { return &symbol_table; } // The constructor, instantiates commonly used opaque types qasm3_visitor(mlir::OpBuilder b, mlir::ModuleOp m, std::string& fname) Loading
mlir/parsers/qasm3/visitor_handlers/subroutine_handler.cpp +21 −12 Original line number Diff line number Diff line Loading @@ -166,7 +166,8 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement( value = symbol_table.get_symbol(ret_stmt); // Actually return value if it is a bit[], // load and return if it is a bit // printErrorMessage("Putting this here til I fix this"); if (current_function_return_type) { // this means it is a subroutine if (!current_function_return_type.isa<mlir::MemRefType>()) { if (current_function_return_type.isa<mlir::IntegerType>() && current_function_return_type.getIntOrFloatBitWidth() == 1) { Loading @@ -176,10 +177,19 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement( llvm::ArrayRef<mlir::Value> zero_index(tmp); value = builder.create<mlir::LoadOp>(location, value, zero_index); } else { value = builder.create<mlir::LoadOp>(location, value); //, zero_index); value = builder.create<mlir::LoadOp>(location, value); //, zero_index); } } else { printErrorMessage("We do not return memrefs from subroutines.", context); } } else { printErrorMessage("We do not return memrefs from subroutines.", context); if (auto t = value.getType().dyn_cast_or_null<mlir::MemRefType>()) { if (t.getRank() == 0) { value = builder.create<mlir::LoadOp>(location, value); } } } } else { Loading @@ -192,7 +202,6 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement( visitChildren(context->statement()); value = symbol_table.get_last_value_added(); } } is_return_stmt = false; Loading
mlir/transforms/lowering/PrintOpLowering.cpp +13 −0 Original line number Diff line number Diff line Loading @@ -114,6 +114,11 @@ LogicalResult PrintOpLowering::matchAndRewrite( frmt_spec += "%d"; } ss << "_bit_array_b_" << dim; } else if (mem_ref_type.getElementType().isa<mlir::IntegerType>() && mem_ref_type.getRank() == 0 && mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) { frmt_spec += "%d"; ss << "_bit_array_b_0"; } } else { std::cout << "Currently invalid type to print.\n"; Loading Loading @@ -155,6 +160,7 @@ LogicalResult PrintOpLowering::matchAndRewrite( mem_ref_type.getRank() > 0 && mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) { // This is a bit array... auto dim = mem_ref_type.getShape()[0]; for (int i = 0; i < dim; i++) { auto attr = mlir::IntegerAttr::get(rewriter.getIndexType(), i); Loading @@ -163,6 +169,13 @@ LogicalResult PrintOpLowering::matchAndRewrite( loc, o, llvm::makeArrayRef(std::vector<mlir::Value>{ii})); args.push_back(z); } continue; } else if (mem_ref_type.getElementType().isa<mlir::IntegerType>() && mem_ref_type.getRank() == 0 && mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) { auto z = rewriter.create<mlir::LoadOp>(loc, o); args.push_back(z); continue; } } Loading