Loading mlir/tools/qcor-mlir-tool.cpp +4 −1 Original line number Diff line number Diff line Loading @@ -129,7 +129,10 @@ int main(int argc, char **argv) { // Now lower MLIR to LLVM IR llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToLLVMIR(*module, llvmContext); if (!llvmModule) { llvm::errs() << "Failed to emit LLVM IR\n"; return -1; } // Optimize the LLVM IR llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); Loading mlir/transforms/lowering/CallableLowering.cpp +29 −27 Original line number Diff line number Diff line Loading @@ -30,52 +30,54 @@ LogicalResult TupleUnpackOpLowering::matchAndRewrite( assert(operands.size() == 1); ModuleOp parentModule = op->getParentOfType<ModuleOp>(); auto context = parentModule->getContext(); std::cout << "Before:\n"; parentModule.dump(); // Cast the tuple to a struct type: auto tuple_unpack_op = cast<mlir::quantum::TupleUnpackOp>(op); std::vector<Type> unpacked_type_list; std::vector<Type> tuple_struct_type_list; mlir::SmallVector<mlir::Type> tuple_struct_type_list; for (const auto &result : tuple_unpack_op.result()) { if (result.getType().isa<mlir::OpaqueType>() && result.getType().cast<mlir::OpaqueType>().getTypeData() == "Array") { auto array_type = LLVM::LLVMPointerType::get(get_quantum_type("Array", context)); unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(array_type)); tuple_struct_type_list.emplace_back(array_type); tuple_struct_type_list.push_back(array_type); } else if (result.getType().isa<mlir::FloatType>()) { auto float_type = mlir::FloatType::getF64(context); unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(float_type)); tuple_struct_type_list.emplace_back(float_type); tuple_struct_type_list.push_back(float_type); } } auto unpacked_type = LLVM::LLVMStructType::getLiteral( context, llvm::ArrayRef<Type>(tuple_struct_type_list)); auto unpacked_struct_type = LLVM::LLVMStructType::getLiteral(context, tuple_struct_type_list); auto location = parentModule->getLoc(); auto bitcast = rewriter.create<LLVM::BitcastOp>( location, LLVM::LLVMPointerType::get(unpacked_type), operands[0]); auto structPtr = rewriter .create<LLVM::BitcastOp>( location, LLVM::LLVMPointerType::get(unpacked_struct_type), operands[0]) .res(); std::vector<mlir::Value> unpacked_vals; for (size_t idx = 0; idx < unpacked_type_list.size(); ++idx) { mlir::SmallVector<mlir::Value> unpacked_vals; mlir::Value zero_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); for (size_t idx = 0; idx < tuple_struct_type_list.size(); ++idx) { mlir::Value idx_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), idx)); auto getelementptr = rewriter.create<LLVM::GEPOp>( location, unpacked_type_list[idx], bitcast, idx_cst); auto field_ptr = rewriter .create<LLVM::GEPOp>( location, LLVM::LLVMPointerType::get(tuple_struct_type_list[idx]), structPtr, mlir::ArrayRef<mlir::Value>({zero_cst, idx_cst})) .res(); auto load_op = rewriter.create<LLVM::LoadOp>( location, tuple_struct_type_list[idx], getelementptr.res()); unpacked_vals.emplace_back(load_op.res()); location, tuple_struct_type_list[idx], field_ptr); unpacked_vals.push_back(load_op.res()); } for (size_t idx = 0; idx < unpacked_vals.size(); ++idx) { mlir::Value unpack_result = *(std::next(tuple_unpack_op.result_begin(), idx)); unpack_result.replaceAllUsesWith(unpacked_vals[idx]); } rewriter.eraseOp(op); std::cout << "After:\n"; parentModule.dump(); rewriter.replaceOp(op, unpacked_vals); // std::cout << "After:\n"; // parentModule.dump(); return success(); } Loading Loading
mlir/tools/qcor-mlir-tool.cpp +4 −1 Original line number Diff line number Diff line Loading @@ -129,7 +129,10 @@ int main(int argc, char **argv) { // Now lower MLIR to LLVM IR llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToLLVMIR(*module, llvmContext); if (!llvmModule) { llvm::errs() << "Failed to emit LLVM IR\n"; return -1; } // Optimize the LLVM IR llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); Loading
mlir/transforms/lowering/CallableLowering.cpp +29 −27 Original line number Diff line number Diff line Loading @@ -30,52 +30,54 @@ LogicalResult TupleUnpackOpLowering::matchAndRewrite( assert(operands.size() == 1); ModuleOp parentModule = op->getParentOfType<ModuleOp>(); auto context = parentModule->getContext(); std::cout << "Before:\n"; parentModule.dump(); // Cast the tuple to a struct type: auto tuple_unpack_op = cast<mlir::quantum::TupleUnpackOp>(op); std::vector<Type> unpacked_type_list; std::vector<Type> tuple_struct_type_list; mlir::SmallVector<mlir::Type> tuple_struct_type_list; for (const auto &result : tuple_unpack_op.result()) { if (result.getType().isa<mlir::OpaqueType>() && result.getType().cast<mlir::OpaqueType>().getTypeData() == "Array") { auto array_type = LLVM::LLVMPointerType::get(get_quantum_type("Array", context)); unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(array_type)); tuple_struct_type_list.emplace_back(array_type); tuple_struct_type_list.push_back(array_type); } else if (result.getType().isa<mlir::FloatType>()) { auto float_type = mlir::FloatType::getF64(context); unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(float_type)); tuple_struct_type_list.emplace_back(float_type); tuple_struct_type_list.push_back(float_type); } } auto unpacked_type = LLVM::LLVMStructType::getLiteral( context, llvm::ArrayRef<Type>(tuple_struct_type_list)); auto unpacked_struct_type = LLVM::LLVMStructType::getLiteral(context, tuple_struct_type_list); auto location = parentModule->getLoc(); auto bitcast = rewriter.create<LLVM::BitcastOp>( location, LLVM::LLVMPointerType::get(unpacked_type), operands[0]); auto structPtr = rewriter .create<LLVM::BitcastOp>( location, LLVM::LLVMPointerType::get(unpacked_struct_type), operands[0]) .res(); std::vector<mlir::Value> unpacked_vals; for (size_t idx = 0; idx < unpacked_type_list.size(); ++idx) { mlir::SmallVector<mlir::Value> unpacked_vals; mlir::Value zero_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); for (size_t idx = 0; idx < tuple_struct_type_list.size(); ++idx) { mlir::Value idx_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), idx)); auto getelementptr = rewriter.create<LLVM::GEPOp>( location, unpacked_type_list[idx], bitcast, idx_cst); auto field_ptr = rewriter .create<LLVM::GEPOp>( location, LLVM::LLVMPointerType::get(tuple_struct_type_list[idx]), structPtr, mlir::ArrayRef<mlir::Value>({zero_cst, idx_cst})) .res(); auto load_op = rewriter.create<LLVM::LoadOp>( location, tuple_struct_type_list[idx], getelementptr.res()); unpacked_vals.emplace_back(load_op.res()); location, tuple_struct_type_list[idx], field_ptr); unpacked_vals.push_back(load_op.res()); } for (size_t idx = 0; idx < unpacked_vals.size(); ++idx) { mlir::Value unpack_result = *(std::next(tuple_unpack_op.result_begin(), idx)); unpack_result.replaceAllUsesWith(unpacked_vals[idx]); } rewriter.eraseOp(op); std::cout << "After:\n"; parentModule.dump(); rewriter.replaceOp(op, unpacked_vals); // std::cout << "After:\n"; // parentModule.dump(); return success(); } Loading