Loading mlir/transforms/lowering/CallableLowering.cpp +42 −13 Original line number Diff line number Diff line Loading @@ -103,21 +103,51 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( false); FlatSymbolRefAttr symbol_ref = SymbolRefAttr::get(create_callable_op.functors(), context); // mlir::Value func_ptr = rewriter.create<LLVM::AddressOfOp>( // location, LLVM::LLVMPointerType::get(callable_entry_ftype), symbol_ref) auto callable_entry_fn_array_type = LLVM::LLVMArrayType::get( LLVM::LLVMPointerType::get(callable_entry_ftype), 4); auto callback_fn_array_type = LLVM::LLVMArrayType::get( LLVM::LLVMPointerType::get(capture_callback_ftype), 2); auto save_pt = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(&parentModule.getRegion().getBlocks().front()); const std::string functor_array_name = create_callable_op.functors().str() + "__Qops"; auto fPtr_array = ArrayAttr::get({symbol_ref, symbol_ref, symbol_ref, symbol_ref} , context); auto fPtr_array_const_global = rewriter.create<LLVM::GlobalOp>( location, callable_entry_fn_array_type, /*isConstant=*/true, LLVM::Linkage::Internal, functor_array_name.c_str(), fPtr_array); rewriter.restoreInsertionPoint(save_pt); mlir::Value value_1_const = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); mlir::Value callable_entry_fn_array = rewriter.create<LLVM::AllocaOp>( location, LLVM::LLVMPointerType::get(callable_entry_fn_array_type), value_1_const, /*alignment=*/0); const std::vector<mlir::Value> functor_ptr_values{ // Base rewriter.create<LLVM::AddressOfOp>( location, LLVM::LLVMPointerType::get(callable_entry_ftype), symbol_ref), // Adjoint rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callable_entry_ftype)), // Controlled rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callable_entry_ftype)), // Controlled Adjoint rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callable_entry_ftype)), }; mlir::Value zero_index = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); for (size_t func_idx = 0; func_idx < functor_ptr_values.size(); ++func_idx) { mlir::Value func_index_val = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), func_idx)); mlir::Value func_in_array_ptr = rewriter.create<LLVM::GEPOp>( location, LLVM::LLVMPointerType::get( LLVM::LLVMPointerType::get(callable_entry_ftype)), callable_entry_fn_array, ArrayRef<Value>({zero_index, func_index_val})); rewriter.create<LLVM::StoreOp>(location, functor_ptr_values[func_idx], func_in_array_ptr); } auto callable_return_type = LLVM::LLVMPointerType::get(get_quantum_type("Callable", context)); Loading @@ -144,8 +174,7 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( qir_symbol_ref = mlir::SymbolRefAttr::get(qir_create_callable, context); } mlir::Value callable_entry_nullPtr = rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callable_entry_fn_array_type)); // Callbacks and captured tuple ==> null mlir::Value callbacks_nullPtr = rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callback_fn_array_type)); mlir::Value tuple_nullPtr = Loading @@ -153,7 +182,7 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( auto createCallableCallOp = rewriter.create<mlir::CallOp>( location, qir_symbol_ref, callable_return_type, ArrayRef<Value>( {callable_entry_nullPtr, callbacks_nullPtr, tuple_nullPtr})); {callable_entry_fn_array, callbacks_nullPtr, tuple_nullPtr})); rewriter.replaceOp(op, createCallableCallOp.getResult(0)); return success(); } Loading Loading
mlir/transforms/lowering/CallableLowering.cpp +42 −13 Original line number Diff line number Diff line Loading @@ -103,21 +103,51 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( false); FlatSymbolRefAttr symbol_ref = SymbolRefAttr::get(create_callable_op.functors(), context); // mlir::Value func_ptr = rewriter.create<LLVM::AddressOfOp>( // location, LLVM::LLVMPointerType::get(callable_entry_ftype), symbol_ref) auto callable_entry_fn_array_type = LLVM::LLVMArrayType::get( LLVM::LLVMPointerType::get(callable_entry_ftype), 4); auto callback_fn_array_type = LLVM::LLVMArrayType::get( LLVM::LLVMPointerType::get(capture_callback_ftype), 2); auto save_pt = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(&parentModule.getRegion().getBlocks().front()); const std::string functor_array_name = create_callable_op.functors().str() + "__Qops"; auto fPtr_array = ArrayAttr::get({symbol_ref, symbol_ref, symbol_ref, symbol_ref} , context); auto fPtr_array_const_global = rewriter.create<LLVM::GlobalOp>( location, callable_entry_fn_array_type, /*isConstant=*/true, LLVM::Linkage::Internal, functor_array_name.c_str(), fPtr_array); rewriter.restoreInsertionPoint(save_pt); mlir::Value value_1_const = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); mlir::Value callable_entry_fn_array = rewriter.create<LLVM::AllocaOp>( location, LLVM::LLVMPointerType::get(callable_entry_fn_array_type), value_1_const, /*alignment=*/0); const std::vector<mlir::Value> functor_ptr_values{ // Base rewriter.create<LLVM::AddressOfOp>( location, LLVM::LLVMPointerType::get(callable_entry_ftype), symbol_ref), // Adjoint rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callable_entry_ftype)), // Controlled rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callable_entry_ftype)), // Controlled Adjoint rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callable_entry_ftype)), }; mlir::Value zero_index = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); for (size_t func_idx = 0; func_idx < functor_ptr_values.size(); ++func_idx) { mlir::Value func_index_val = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), func_idx)); mlir::Value func_in_array_ptr = rewriter.create<LLVM::GEPOp>( location, LLVM::LLVMPointerType::get( LLVM::LLVMPointerType::get(callable_entry_ftype)), callable_entry_fn_array, ArrayRef<Value>({zero_index, func_index_val})); rewriter.create<LLVM::StoreOp>(location, functor_ptr_values[func_idx], func_in_array_ptr); } auto callable_return_type = LLVM::LLVMPointerType::get(get_quantum_type("Callable", context)); Loading @@ -144,8 +174,7 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( qir_symbol_ref = mlir::SymbolRefAttr::get(qir_create_callable, context); } mlir::Value callable_entry_nullPtr = rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callable_entry_fn_array_type)); // Callbacks and captured tuple ==> null mlir::Value callbacks_nullPtr = rewriter.create<LLVM::NullOp>( location, LLVM::LLVMPointerType::get(callback_fn_array_type)); mlir::Value tuple_nullPtr = Loading @@ -153,7 +182,7 @@ LogicalResult CreateCallableOpLowering::matchAndRewrite( auto createCallableCallOp = rewriter.create<mlir::CallOp>( location, qir_symbol_ref, callable_return_type, ArrayRef<Value>( {callable_entry_nullPtr, callbacks_nullPtr, tuple_nullPtr})); {callable_entry_fn_array, callbacks_nullPtr, tuple_nullPtr})); rewriter.replaceOp(op, createCallableCallOp.getResult(0)); return success(); } Loading