Loading mlir/parsers/qasm3/visitor_handlers/conditional_handler.cpp +28 −4 Original line number Diff line number Diff line Loading @@ -95,6 +95,21 @@ mlir::Value create_capture_callable_gen( wrapped_func, unpackOp.result()); builder.create<mlir::ReturnOp>(builder.getUnknownLoc()); moduleOp.push_back(function_op); // !! We only ever invoke the body functor, create dummy functors for adj/ctrl for (const auto &suffix : {"__adj__wrapper", "__ctl__wrapper", "__ctladj__wrapper"}) { builder.restoreInsertionPoint(main_block); const std::string temp_fn_name = func_name + suffix; mlir::FuncOp fn_op( mlir::FuncOp::create(builder.getUnknownLoc(), temp_fn_name, func_type)); fn_op.setVisibility(mlir::SymbolTable::Visibility::Private); auto &entryBlock = *fn_op.addEntryBlock(); builder.setInsertionPointToStart(&entryBlock); builder.create<mlir::ReturnOp>(builder.getUnknownLoc()); moduleOp.push_back(fn_op); } builder.restoreInsertionPoint(main_block); auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>( builder.getUnknownLoc(), callable_type, Loading Loading @@ -132,12 +147,17 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( std::vector<mlir::Type> argument_types; std::vector<std::string> argument_names; std::vector<mlir::Value> argument_values; // Narrow the list of supported types for tuple unpack... // We don't support all types atm. for (auto &[k, v] : all_vars) { // QIR types and Float (rotation angles) if (v.getType().isa<mlir::OpaqueType>() || v.getType().isa<mlir::FloatType>()) { argument_names.emplace_back(k); argument_values.emplace_back(v); argument_types.emplace_back(v.getType()); } } // Use the ANTLR node ptr (hex) as id for this temp. function const auto toString = [](auto *antr_node) { Loading Loading @@ -165,6 +185,10 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( symbol_table.exit_scope(); symbol_table.add_seen_function(tmp_func_name, function); symbol_table.set_last_created_block(nullptr); for (int i = 0; i < arguments.size(); ++i) { symbol_table.replace_symbol(symbol_table.get_symbol(argument_names[i]), argument_values[i]); } m_module.push_back(function); auto then_body_callable = create_capture_callable_gen( Loading mlir/transforms/lowering/CallableLowering.cpp +114 −11 Original line number Diff line number Diff line Loading @@ -52,6 +52,9 @@ LogicalResult TupleUnpackOpLowering::matchAndRewrite( tuple_struct_type_list.push_back(mlir::FloatType::getF64(context)); } else if (result.getType().isa<mlir::IntegerType>()) { tuple_struct_type_list.push_back(mlir::IntegerType::get(context, 64)); } else { std::cout << "WE DON'T SUPPORT TUPLE UNPACK FOR THE TYPE\n"; exit(0); } } Loading Loading @@ -126,12 +129,12 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( value_1_const, /*alignment=*/0); const std::string kernel_name = create_callable_op.functors().str(); const std::string BODY_WRAPPER_NAME = kernel_name + "__body__wrapper"; const std::string ADJOINT_WRAPPER_NAME = kernel_name + "__adj__wrapper"; const std::string CTRL_WRAPPER_NAME = kernel_name + "__ctl__wrapper"; const std::string CTRL_ADJOINT_WRAPPER_NAME = kernel_name + "__ctladj__wrapper"; const std::string CTRL_ADJOINT_WRAPPER_NAME = kernel_name + "__ctladj__wrapper"; const std::vector<mlir::Value> functor_ptr_values{ // Base Loading Loading @@ -180,9 +183,9 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( // TuplePtr capture) auto create_callable_ftype = LLVM::LLVMFunctionType::get( callable_return_type, llvm::ArrayRef<Type>{LLVM::LLVMPointerType::get(callable_entry_fn_array_type), LLVM::LLVMPointerType::get(callback_fn_array_type), tuple_type}, llvm::ArrayRef<Type>{ LLVM::LLVMPointerType::get(callable_entry_fn_array_type), LLVM::LLVMPointerType::get(callback_fn_array_type), tuple_type}, false); // Insert the function declaration Loading @@ -196,12 +199,112 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( // Callbacks and captured tuple ==> null mlir::Value callbacks_nullPtr = rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callback_fn_array_type)); mlir::Value tuple_nullPtr = rewriter.create<LLVM::NullOp>(location, tuple_type); mlir::Value capture_tuple_ptr = [&]() { if (create_callable_op.captures().empty()) { auto op = rewriter.create<LLVM::NullOp>(location, tuple_type); return op.res(); } else { mlir::SmallVector<mlir::Type> tuple_struct_type_list; size_t tuple_size_in_bytes = 0; for (const auto &captured_var : create_callable_op.captures()) { if (captured_var.getType().isa<mlir::OpaqueType>() && captured_var.getType().cast<mlir::OpaqueType>().getTypeData() == "Array") { tuple_struct_type_list.push_back( LLVM::LLVMPointerType::get(get_quantum_type("Array", context))); tuple_size_in_bytes += sizeof(void *); } else if (captured_var.getType().isa<mlir::OpaqueType>() && captured_var.getType() .cast<mlir::OpaqueType>() .getTypeData() == "Qubit") { tuple_struct_type_list.push_back( LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))); tuple_size_in_bytes += sizeof(void *); } else if (captured_var.getType().isa<mlir::OpaqueType>() && captured_var.getType() .cast<mlir::OpaqueType>() .getTypeData() == "Tuple") { tuple_struct_type_list.push_back( LLVM::LLVMPointerType::get(get_quantum_type("Tuple", context))); tuple_size_in_bytes += sizeof(void *); } else if (captured_var.getType().isa<mlir::FloatType>()) { tuple_struct_type_list.push_back(mlir::FloatType::getF64(context)); tuple_size_in_bytes += sizeof(double); } else if (captured_var.getType().isa<mlir::IntegerType>()) { tuple_struct_type_list.push_back(mlir::IntegerType::get(context, 64)); tuple_size_in_bytes += sizeof(int64_t); } else { std::cout << "WE DON'T SUPPORT TUPLE PACK FOR THE TYPE\n"; exit(0); } } mlir::Value tuple_size_value = rewriter.create<LLVM::ConstantOp>( location, mlir::IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr( mlir::IntegerType::get(rewriter.getContext(), 64), tuple_size_in_bytes)); // Tuple create signature: TuplePtr __quantum__rt__tuple_create(int64_t // size) FlatSymbolRefAttr tuple_create_symbol_ref = [&]() { const std::string qir_tuple_create_fn_name = "__quantum__rt__tuple_create"; if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>( qir_tuple_create_fn_name)) { return SymbolRefAttr::get(qir_tuple_create_fn_name, context); } else { auto ftype = LLVM::LLVMFunctionType::get( tuple_type, llvm::ArrayRef<Type>{mlir::IntegerType::get(context, 64)}, false); // Insert the function declaration PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(parentModule.getBody()); rewriter.create<LLVM::LLVMFuncOp>(parentModule->getLoc(), qir_tuple_create_fn_name, ftype); return mlir::SymbolRefAttr::get(qir_tuple_create_fn_name, context); } }(); auto createTupleCallOp = rewriter.create<mlir::CallOp>( location, tuple_create_symbol_ref, tuple_type, ArrayRef<Value>({tuple_size_value})); mlir::Value tuplePtr = createTupleCallOp.getResult(0); // Store to tuple: auto tuple_struct_type = LLVM::LLVMStructType::getLiteral(context, tuple_struct_type_list); auto structPtr = rewriter .create<LLVM::BitcastOp>( location, LLVM::LLVMPointerType::get(tuple_struct_type), tuplePtr) .res(); mlir::Value zero_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); for (size_t idx = 0; idx < tuple_struct_type_list.size(); ++idx) { mlir::Value idx_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getIndexType(), idx)); auto field_ptr = rewriter .create<LLVM::GEPOp>( location, LLVM::LLVMPointerType::get(tuple_struct_type_list[idx]), structPtr, ArrayRef<Value>({zero_cst, idx_cst})) .res(); auto store_op = rewriter.create<LLVM::StoreOp>( location, create_callable_op.captures()[idx], field_ptr); } return tuplePtr; } }(); auto createCallableCallOp = rewriter.create<mlir::CallOp>( location, qir_symbol_ref, callable_return_type, ArrayRef<Value>( {callable_entry_fn_array, callbacks_nullPtr, tuple_nullPtr})); {callable_entry_fn_array, callbacks_nullPtr, capture_tuple_ptr})); rewriter.replaceOp(op, createCallableCallOp.getResult(0)); return success(); } Loading Loading
mlir/parsers/qasm3/visitor_handlers/conditional_handler.cpp +28 −4 Original line number Diff line number Diff line Loading @@ -95,6 +95,21 @@ mlir::Value create_capture_callable_gen( wrapped_func, unpackOp.result()); builder.create<mlir::ReturnOp>(builder.getUnknownLoc()); moduleOp.push_back(function_op); // !! We only ever invoke the body functor, create dummy functors for adj/ctrl for (const auto &suffix : {"__adj__wrapper", "__ctl__wrapper", "__ctladj__wrapper"}) { builder.restoreInsertionPoint(main_block); const std::string temp_fn_name = func_name + suffix; mlir::FuncOp fn_op( mlir::FuncOp::create(builder.getUnknownLoc(), temp_fn_name, func_type)); fn_op.setVisibility(mlir::SymbolTable::Visibility::Private); auto &entryBlock = *fn_op.addEntryBlock(); builder.setInsertionPointToStart(&entryBlock); builder.create<mlir::ReturnOp>(builder.getUnknownLoc()); moduleOp.push_back(fn_op); } builder.restoreInsertionPoint(main_block); auto callable_create_op = builder.create<mlir::quantum::CreateCallableOp>( builder.getUnknownLoc(), callable_type, Loading Loading @@ -132,12 +147,17 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( std::vector<mlir::Type> argument_types; std::vector<std::string> argument_names; std::vector<mlir::Value> argument_values; // Narrow the list of supported types for tuple unpack... // We don't support all types atm. for (auto &[k, v] : all_vars) { // QIR types and Float (rotation angles) if (v.getType().isa<mlir::OpaqueType>() || v.getType().isa<mlir::FloatType>()) { argument_names.emplace_back(k); argument_values.emplace_back(v); argument_types.emplace_back(v.getType()); } } // Use the ANTLR node ptr (hex) as id for this temp. function const auto toString = [](auto *antr_node) { Loading Loading @@ -165,6 +185,10 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement( symbol_table.exit_scope(); symbol_table.add_seen_function(tmp_func_name, function); symbol_table.set_last_created_block(nullptr); for (int i = 0; i < arguments.size(); ++i) { symbol_table.replace_symbol(symbol_table.get_symbol(argument_names[i]), argument_values[i]); } m_module.push_back(function); auto then_body_callable = create_capture_callable_gen( Loading
mlir/transforms/lowering/CallableLowering.cpp +114 −11 Original line number Diff line number Diff line Loading @@ -52,6 +52,9 @@ LogicalResult TupleUnpackOpLowering::matchAndRewrite( tuple_struct_type_list.push_back(mlir::FloatType::getF64(context)); } else if (result.getType().isa<mlir::IntegerType>()) { tuple_struct_type_list.push_back(mlir::IntegerType::get(context, 64)); } else { std::cout << "WE DON'T SUPPORT TUPLE UNPACK FOR THE TYPE\n"; exit(0); } } Loading Loading @@ -126,12 +129,12 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( value_1_const, /*alignment=*/0); const std::string kernel_name = create_callable_op.functors().str(); const std::string BODY_WRAPPER_NAME = kernel_name + "__body__wrapper"; const std::string ADJOINT_WRAPPER_NAME = kernel_name + "__adj__wrapper"; const std::string CTRL_WRAPPER_NAME = kernel_name + "__ctl__wrapper"; const std::string CTRL_ADJOINT_WRAPPER_NAME = kernel_name + "__ctladj__wrapper"; const std::string CTRL_ADJOINT_WRAPPER_NAME = kernel_name + "__ctladj__wrapper"; const std::vector<mlir::Value> functor_ptr_values{ // Base Loading Loading @@ -180,9 +183,9 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( // TuplePtr capture) auto create_callable_ftype = LLVM::LLVMFunctionType::get( callable_return_type, llvm::ArrayRef<Type>{LLVM::LLVMPointerType::get(callable_entry_fn_array_type), LLVM::LLVMPointerType::get(callback_fn_array_type), tuple_type}, llvm::ArrayRef<Type>{ LLVM::LLVMPointerType::get(callable_entry_fn_array_type), LLVM::LLVMPointerType::get(callback_fn_array_type), tuple_type}, false); // Insert the function declaration Loading @@ -196,12 +199,112 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( // Callbacks and captured tuple ==> null mlir::Value callbacks_nullPtr = rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callback_fn_array_type)); mlir::Value tuple_nullPtr = rewriter.create<LLVM::NullOp>(location, tuple_type); mlir::Value capture_tuple_ptr = [&]() { if (create_callable_op.captures().empty()) { auto op = rewriter.create<LLVM::NullOp>(location, tuple_type); return op.res(); } else { mlir::SmallVector<mlir::Type> tuple_struct_type_list; size_t tuple_size_in_bytes = 0; for (const auto &captured_var : create_callable_op.captures()) { if (captured_var.getType().isa<mlir::OpaqueType>() && captured_var.getType().cast<mlir::OpaqueType>().getTypeData() == "Array") { tuple_struct_type_list.push_back( LLVM::LLVMPointerType::get(get_quantum_type("Array", context))); tuple_size_in_bytes += sizeof(void *); } else if (captured_var.getType().isa<mlir::OpaqueType>() && captured_var.getType() .cast<mlir::OpaqueType>() .getTypeData() == "Qubit") { tuple_struct_type_list.push_back( LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))); tuple_size_in_bytes += sizeof(void *); } else if (captured_var.getType().isa<mlir::OpaqueType>() && captured_var.getType() .cast<mlir::OpaqueType>() .getTypeData() == "Tuple") { tuple_struct_type_list.push_back( LLVM::LLVMPointerType::get(get_quantum_type("Tuple", context))); tuple_size_in_bytes += sizeof(void *); } else if (captured_var.getType().isa<mlir::FloatType>()) { tuple_struct_type_list.push_back(mlir::FloatType::getF64(context)); tuple_size_in_bytes += sizeof(double); } else if (captured_var.getType().isa<mlir::IntegerType>()) { tuple_struct_type_list.push_back(mlir::IntegerType::get(context, 64)); tuple_size_in_bytes += sizeof(int64_t); } else { std::cout << "WE DON'T SUPPORT TUPLE PACK FOR THE TYPE\n"; exit(0); } } mlir::Value tuple_size_value = rewriter.create<LLVM::ConstantOp>( location, mlir::IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr( mlir::IntegerType::get(rewriter.getContext(), 64), tuple_size_in_bytes)); // Tuple create signature: TuplePtr __quantum__rt__tuple_create(int64_t // size) FlatSymbolRefAttr tuple_create_symbol_ref = [&]() { const std::string qir_tuple_create_fn_name = "__quantum__rt__tuple_create"; if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>( qir_tuple_create_fn_name)) { return SymbolRefAttr::get(qir_tuple_create_fn_name, context); } else { auto ftype = LLVM::LLVMFunctionType::get( tuple_type, llvm::ArrayRef<Type>{mlir::IntegerType::get(context, 64)}, false); // Insert the function declaration PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(parentModule.getBody()); rewriter.create<LLVM::LLVMFuncOp>(parentModule->getLoc(), qir_tuple_create_fn_name, ftype); return mlir::SymbolRefAttr::get(qir_tuple_create_fn_name, context); } }(); auto createTupleCallOp = rewriter.create<mlir::CallOp>( location, tuple_create_symbol_ref, tuple_type, ArrayRef<Value>({tuple_size_value})); mlir::Value tuplePtr = createTupleCallOp.getResult(0); // Store to tuple: auto tuple_struct_type = LLVM::LLVMStructType::getLiteral(context, tuple_struct_type_list); auto structPtr = rewriter .create<LLVM::BitcastOp>( location, LLVM::LLVMPointerType::get(tuple_struct_type), tuplePtr) .res(); mlir::Value zero_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); for (size_t idx = 0; idx < tuple_struct_type_list.size(); ++idx) { mlir::Value idx_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getIndexType(), idx)); auto field_ptr = rewriter .create<LLVM::GEPOp>( location, LLVM::LLVMPointerType::get(tuple_struct_type_list[idx]), structPtr, ArrayRef<Value>({zero_cst, idx_cst})) .res(); auto store_op = rewriter.create<LLVM::StoreOp>( location, create_callable_op.captures()[idx], field_ptr); } return tuplePtr; } }(); auto createCallableCallOp = rewriter.create<mlir::CallOp>( location, qir_symbol_ref, callable_return_type, ArrayRef<Value>( {callable_entry_fn_array, callbacks_nullPtr, tuple_nullPtr})); {callable_entry_fn_array, callbacks_nullPtr, capture_tuple_ptr})); rewriter.replaceOp(op, createCallableCallOp.getResult(0)); return success(); } Loading