Loading mlir/dialect/include/Quantum/QuantumOps.td +6 −2 Original line number Diff line number Diff line Loading @@ -229,10 +229,14 @@ def TupleUnpackOp : QuantumOp<"tupleUnpack", []> { } def CreateCallableOp : QuantumOp<"createCallable", []> { let arguments = (ins FlatSymbolRefAttr:$functors); let arguments = (ins FlatSymbolRefAttr:$functors, Variadic<AnyType>:$captures); let results = (outs CallableType:$callable); let printer = [{ auto op = *this; p << "q.createCallable" << "(" << op.functors() << ") : " << op.callable().getType(); }]; p << "q.createCallable" << "(" << op.functors() << ") "; if (!op.captures().empty()) { p << "capture " << op.captures(); } }]; } def YieldOp : QuantumOp<"yield", [NoSideEffect, Terminator]> { Loading mlir/parsers/qasm3/visitor_handlers/conditional_handler.cpp +57 −4 Original line number Diff line number Diff line Loading @@ -57,6 +57,57 @@ std::optional<BitComparisonExpression> tryParseSimpleBooleanExpression( return std::nullopt; } // Callable running-off captured vars... mlir::Value create_capture_callable_gen( mlir::OpBuilder &builder, const std::string &func_name, mlir::ModuleOp &moduleOp, mlir::FuncOp &wrapped_func, std::vector<mlir::Value> &captured_vars) { auto context = builder.getContext(); auto main_block = builder.saveInsertionPoint(); mlir::Identifier dialect = mlir::Identifier::get("quantum", context); llvm::StringRef tuple_type_name("Tuple"); auto tuple_type = mlir::OpaqueType::get(context, dialect, tuple_type_name); llvm::StringRef array_type_name("Array"); auto array_type = mlir::OpaqueType::get(context, dialect, array_type_name); llvm::StringRef callable_type_name("Callable"); auto callable_type = mlir::OpaqueType::get(context, dialect, callable_type_name); llvm::StringRef qubit_type_name("Qubit"); auto qubit_type = mlir::OpaqueType::get(context, dialect, qubit_type_name); const std::vector<mlir::Type> argument_types{tuple_type, tuple_type, tuple_type}; auto func_type = builder.getFunctionType(argument_types, llvm::None); const std::string BODY_WRAPPER_SUFFIX = "__body__wrapper"; std::vector<mlir::FuncOp> all_wrapper_funcs; // Body wrapper: const std::string wrapper_fn_name = func_name + BODY_WRAPPER_SUFFIX; mlir::FuncOp function_op(mlir::FuncOp::create(builder.getUnknownLoc(), wrapper_fn_name, func_type)); function_op.setVisibility(mlir::SymbolTable::Visibility::Private); auto &entryBlock = *function_op.addEntryBlock(); builder.setInsertionPointToStart(&entryBlock); auto arguments = entryBlock.getArguments(); assert(arguments.size() == 3); // Unpack from **captured** vars (not input args...) // i.e., Tuple # 0 mlir::Value arg_tuple = arguments[0]; auto fn_type = wrapped_func.getType().cast<mlir::FunctionType>(); mlir::TypeRange arg_types(fn_type.getInputs()); auto unpackOp = builder.create<mlir::quantum::TupleUnpackOp>( builder.getUnknownLoc(), arg_types, arg_tuple); auto call_op = builder.create<mlir::CallOp>(builder.getUnknownLoc(), wrapped_func, unpackOp.result()); builder.create<mlir::ReturnOp>(builder.getUnknownLoc()); moduleOp.push_back(function_op); builder.restoreInsertionPoint(main_block); auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>( builder.getUnknownLoc(), callable_type, builder.getSymbolRefAttr(wrapped_func), /*captures*/ llvm::makeArrayRef(captured_vars)); return callable_create_op; } } // namespace namespace qcor { Loading @@ -70,7 +121,6 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( auto bit_check_conditional = tryParseSimpleBooleanExpression(*conditional_expr); // Currently, we're only support If (not else yet) if (bit_check_conditional.has_value() && context->programBlock().size() == 1 && symbol_table.try_lookup_meas_result(bit_check_conditional->var_name) Loading Loading @@ -104,12 +154,13 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( ss << (void *)antr_node; return ss.str(); }; const std::string tmp_fun_name = const std::string tmp_func_name = "if_body_" + toString(context->programBlock(0)); auto func_type = builder.getFunctionType(argument_types, llvm::None); auto proto = mlir::FuncOp::create(builder.getUnknownLoc(), tmp_fun_name, func_type); mlir::FuncOp::create(builder.getUnknownLoc(), tmp_func_name, func_type); mlir::FuncOp function(proto); function.setVisibility(mlir::SymbolTable::Visibility::Private); auto &entryBlock = *function.addEntryBlock(); builder.setInsertionPointToStart(&entryBlock); symbol_table.enter_new_scope(); Loading @@ -121,10 +172,12 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( builder.create<mlir::ReturnOp>(builder.getUnknownLoc()); builder.restoreInsertionPoint(main_block); symbol_table.exit_scope(); symbol_table.add_seen_function(tmp_fun_name, function); symbol_table.add_seen_function(tmp_func_name, function); symbol_table.set_last_created_block(nullptr); m_module.push_back(function); auto then_body_callable = create_capture_callable_gen( builder, tmp_func_name, m_module, function, argument_values); // Create a call to the function: // FIXME: this should be wrapped as a callable... auto then_body_builder = ifOp.getThenBodyBuilder(); Loading mlir/parsers/qasm3/visitor_handlers/subroutine_handler.cpp +2 −1 Original line number Diff line number Diff line Loading @@ -181,7 +181,8 @@ void add_callable_gen(mlir::OpBuilder &builder, const std::string &func_name, builder.setInsertionPointToStart(&create_callable_entryBlock); auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>( builder.getUnknownLoc(), callable_type, builder.getSymbolRefAttr(wrapped_func)); builder.getSymbolRefAttr(wrapped_func), /*captures*/ llvm::makeArrayRef(std::vector<mlir::Value>{})); builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), callable_create_op.callable()); moduleOp.push_back(create_callable_function_op); Loading Loading
mlir/dialect/include/Quantum/QuantumOps.td +6 −2 Original line number Diff line number Diff line Loading @@ -229,10 +229,14 @@ def TupleUnpackOp : QuantumOp<"tupleUnpack", []> { } def CreateCallableOp : QuantumOp<"createCallable", []> { let arguments = (ins FlatSymbolRefAttr:$functors); let arguments = (ins FlatSymbolRefAttr:$functors, Variadic<AnyType>:$captures); let results = (outs CallableType:$callable); let printer = [{ auto op = *this; p << "q.createCallable" << "(" << op.functors() << ") : " << op.callable().getType(); }]; p << "q.createCallable" << "(" << op.functors() << ") "; if (!op.captures().empty()) { p << "capture " << op.captures(); } }]; } def YieldOp : QuantumOp<"yield", [NoSideEffect, Terminator]> { Loading
mlir/parsers/qasm3/visitor_handlers/conditional_handler.cpp +57 −4 Original line number Diff line number Diff line Loading @@ -57,6 +57,57 @@ std::optional<BitComparisonExpression> tryParseSimpleBooleanExpression( return std::nullopt; } // Callable running-off captured vars... mlir::Value create_capture_callable_gen( mlir::OpBuilder &builder, const std::string &func_name, mlir::ModuleOp &moduleOp, mlir::FuncOp &wrapped_func, std::vector<mlir::Value> &captured_vars) { auto context = builder.getContext(); auto main_block = builder.saveInsertionPoint(); mlir::Identifier dialect = mlir::Identifier::get("quantum", context); llvm::StringRef tuple_type_name("Tuple"); auto tuple_type = mlir::OpaqueType::get(context, dialect, tuple_type_name); llvm::StringRef array_type_name("Array"); auto array_type = mlir::OpaqueType::get(context, dialect, array_type_name); llvm::StringRef callable_type_name("Callable"); auto callable_type = mlir::OpaqueType::get(context, dialect, callable_type_name); llvm::StringRef qubit_type_name("Qubit"); auto qubit_type = mlir::OpaqueType::get(context, dialect, qubit_type_name); const std::vector<mlir::Type> argument_types{tuple_type, tuple_type, tuple_type}; auto func_type = builder.getFunctionType(argument_types, llvm::None); const std::string BODY_WRAPPER_SUFFIX = "__body__wrapper"; std::vector<mlir::FuncOp> all_wrapper_funcs; // Body wrapper: const std::string wrapper_fn_name = func_name + BODY_WRAPPER_SUFFIX; mlir::FuncOp function_op(mlir::FuncOp::create(builder.getUnknownLoc(), wrapper_fn_name, func_type)); function_op.setVisibility(mlir::SymbolTable::Visibility::Private); auto &entryBlock = *function_op.addEntryBlock(); builder.setInsertionPointToStart(&entryBlock); auto arguments = entryBlock.getArguments(); assert(arguments.size() == 3); // Unpack from **captured** vars (not input args...) // i.e., Tuple # 0 mlir::Value arg_tuple = arguments[0]; auto fn_type = wrapped_func.getType().cast<mlir::FunctionType>(); mlir::TypeRange arg_types(fn_type.getInputs()); auto unpackOp = builder.create<mlir::quantum::TupleUnpackOp>( builder.getUnknownLoc(), arg_types, arg_tuple); auto call_op = builder.create<mlir::CallOp>(builder.getUnknownLoc(), wrapped_func, unpackOp.result()); builder.create<mlir::ReturnOp>(builder.getUnknownLoc()); moduleOp.push_back(function_op); builder.restoreInsertionPoint(main_block); auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>( builder.getUnknownLoc(), callable_type, builder.getSymbolRefAttr(wrapped_func), /*captures*/ llvm::makeArrayRef(captured_vars)); return callable_create_op; } } // namespace namespace qcor { Loading @@ -70,7 +121,6 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( auto bit_check_conditional = tryParseSimpleBooleanExpression(*conditional_expr); // Currently, we're only support If (not else yet) if (bit_check_conditional.has_value() && context->programBlock().size() == 1 && symbol_table.try_lookup_meas_result(bit_check_conditional->var_name) Loading Loading @@ -104,12 +154,13 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( ss << (void *)antr_node; return ss.str(); }; const std::string tmp_fun_name = const std::string tmp_func_name = "if_body_" + toString(context->programBlock(0)); auto func_type = builder.getFunctionType(argument_types, llvm::None); auto proto = mlir::FuncOp::create(builder.getUnknownLoc(), tmp_fun_name, func_type); mlir::FuncOp::create(builder.getUnknownLoc(), tmp_func_name, func_type); mlir::FuncOp function(proto); function.setVisibility(mlir::SymbolTable::Visibility::Private); auto &entryBlock = *function.addEntryBlock(); builder.setInsertionPointToStart(&entryBlock); symbol_table.enter_new_scope(); Loading @@ -121,10 +172,12 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( builder.create<mlir::ReturnOp>(builder.getUnknownLoc()); builder.restoreInsertionPoint(main_block); symbol_table.exit_scope(); symbol_table.add_seen_function(tmp_fun_name, function); symbol_table.add_seen_function(tmp_func_name, function); symbol_table.set_last_created_block(nullptr); m_module.push_back(function); auto then_body_callable = create_capture_callable_gen( builder, tmp_func_name, m_module, function, argument_values); // Create a call to the function: // FIXME: this should be wrapped as a callable... auto then_body_builder = ifOp.getThenBodyBuilder(); Loading
mlir/parsers/qasm3/visitor_handlers/subroutine_handler.cpp +2 −1 Original line number Diff line number Diff line Loading @@ -181,7 +181,8 @@ void add_callable_gen(mlir::OpBuilder &builder, const std::string &func_name, builder.setInsertionPointToStart(&create_callable_entryBlock); auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>( builder.getUnknownLoc(), callable_type, builder.getSymbolRefAttr(wrapped_func)); builder.getSymbolRefAttr(wrapped_func), /*captures*/ llvm::makeArrayRef(std::vector<mlir::Value>{})); builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), callable_create_op.callable()); moduleOp.push_back(create_callable_function_op); Loading