Loading mlir/dialect/lib/Quantum/QuantumOps.cpp +2 −1 Original line number Diff line number Diff line Loading @@ -8,7 +8,8 @@ bool isOpaqueTypeWithName(mlir::Type type, std::string dialect, if (type.isa<mlir::OpaqueType>() && dialect == "quantum") { if (type_name == "Qubit" || type_name == "Result" || type_name == "Array" || type_name == "ArgvType" || type_name == "QregType" || type_name == "StringType") { type_name == "StringType" || type_name == "Tuple" || type_name == "Callable") { return true; } } Loading mlir/transforms/lowering/CallableLowering.cpp 0 → 100644 +87 −0 Original line number Diff line number Diff line #include "CallableLowering.hpp" #include <iostream> #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" #include "mlir/InitAllDialects.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" namespace qcor { LogicalResult TupleUnpackOpLowering::matchAndRewrite( Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { assert(operands.size() == 1); ModuleOp parentModule = op->getParentOfType<ModuleOp>(); auto context = parentModule->getContext(); std::cout << "Before:\n"; parentModule.dump(); // Cast the tuple to a struct type: auto tuple_unpack_op = cast<mlir::quantum::TupleUnpackOp>(op); std::vector<Type> unpacked_type_list; std::vector<Type> tuple_struct_type_list; for (const auto &result : tuple_unpack_op.result()) { if (result.getType().isa<mlir::OpaqueType>() && result.getType().cast<mlir::OpaqueType>().getTypeData() == "Array") { auto array_type = LLVM::LLVMPointerType::get(get_quantum_type("Array", context)); unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(array_type)); tuple_struct_type_list.emplace_back(array_type); } else if (result.getType().isa<mlir::FloatType>()) { auto float_type = mlir::FloatType::getF64(context); unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(float_type)); tuple_struct_type_list.emplace_back(float_type); } } auto unpacked_type = LLVM::LLVMStructType::getLiteral( context, llvm::ArrayRef<Type>(tuple_struct_type_list)); auto location = parentModule->getLoc(); auto bitcast = rewriter.create<LLVM::BitcastOp>( location, LLVM::LLVMPointerType::get(unpacked_type), operands[0]); std::vector<mlir::Value> unpacked_vals; for (size_t idx = 0; idx < unpacked_type_list.size(); ++idx) { mlir::Value idx_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), idx)); auto getelementptr = rewriter.create<LLVM::GEPOp>( location, unpacked_type_list[idx], bitcast, idx_cst); auto load_op = rewriter.create<LLVM::LoadOp>( location, tuple_struct_type_list[idx], getelementptr.res()); unpacked_vals.emplace_back(load_op.res()); } for (size_t idx = 0; idx < unpacked_vals.size(); ++idx) { mlir::Value unpack_result = *(std::next(tuple_unpack_op.result_begin(), idx)); unpack_result.replaceAllUsesWith(unpacked_vals[idx]); } rewriter.eraseOp(op); std::cout << "After:\n"; parentModule.dump(); return success(); } LogicalResult CreateCallableOpLowering::matchAndRewrite( Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { return success(); } } // namespace qcor No newline at end of file mlir/transforms/lowering/CallableLowering.hpp 0 → 100644 +28 −0 Original line number Diff line number Diff line #pragma once #include "quantum_to_llvm.hpp" namespace qcor { class TupleUnpackOpLowering : public ConversionPattern { protected: public: explicit TupleUnpackOpLowering(MLIRContext *context) : ConversionPattern(mlir::quantum::TupleUnpackOp::getOperationName(), 1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override; }; class CreateCallableOpLowering : public ConversionPattern { protected: public: explicit CreateCallableOpLowering(MLIRContext *context) : ConversionPattern(mlir::quantum::CreateCallableOp::getOperationName(), 1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override; }; } // namespace qcor No newline at end of file mlir/transforms/quantum_to_llvm.cpp +8 −1 Original line number Diff line number Diff line Loading @@ -21,6 +21,7 @@ #include "lowering/SetQregOpLowering.hpp" #include "lowering/StdAtanOpLowering.hpp" #include "lowering/ValueSemanticsInstOpLowering.hpp" #include "lowering/CallableLowering.hpp" namespace qcor { mlir::Type get_quantum_type(std::string type, mlir::MLIRContext *context) { Loading @@ -44,8 +45,12 @@ struct QuantumLLVMTypeConverter : public LLVMTypeConverter { return LLVM::LLVMPointerType::get(get_quantum_type("qreg", context)); } else if (type.getTypeData() == "Array") { return LLVM::LLVMPointerType::get(get_quantum_type("Array", context)); } else if (type.getTypeData() == "Callable") { return LLVM::LLVMPointerType::get(get_quantum_type("Callable", context)); } else if (type.getTypeData() == "Tuple") { return LLVM::LLVMPointerType::get(get_quantum_type("Tuple", context)); } std::cout << "ERROR WE DONT KNOW WAHT THIS TYPE IS\n"; std::cout << "ERROR WE DONT KNOW WHAT THIS TYPE IS\n"; exit(0); return mlir::IntegerType::get(context, 64); } Loading Loading @@ -105,6 +110,8 @@ void QuantumToLLVMLoweringPass::runOnOperation() { patterns.insert<EndAdjointURegionOpLowering>(&getContext()); patterns.insert<StartCtrlURegionOpLowering>(&getContext()); patterns.insert<EndCtrlURegionOpLowering>(&getContext()); patterns.insert<TupleUnpackOpLowering>(&getContext()); patterns.insert<CreateCallableOpLowering>(&getContext()); // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. Loading Loading
mlir/dialect/lib/Quantum/QuantumOps.cpp +2 −1 Original line number Diff line number Diff line Loading @@ -8,7 +8,8 @@ bool isOpaqueTypeWithName(mlir::Type type, std::string dialect, if (type.isa<mlir::OpaqueType>() && dialect == "quantum") { if (type_name == "Qubit" || type_name == "Result" || type_name == "Array" || type_name == "ArgvType" || type_name == "QregType" || type_name == "StringType") { type_name == "StringType" || type_name == "Tuple" || type_name == "Callable") { return true; } } Loading
mlir/transforms/lowering/CallableLowering.cpp 0 → 100644 +87 −0 Original line number Diff line number Diff line #include "CallableLowering.hpp" #include <iostream> #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" #include "mlir/InitAllDialects.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" namespace qcor { LogicalResult TupleUnpackOpLowering::matchAndRewrite( Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { assert(operands.size() == 1); ModuleOp parentModule = op->getParentOfType<ModuleOp>(); auto context = parentModule->getContext(); std::cout << "Before:\n"; parentModule.dump(); // Cast the tuple to a struct type: auto tuple_unpack_op = cast<mlir::quantum::TupleUnpackOp>(op); std::vector<Type> unpacked_type_list; std::vector<Type> tuple_struct_type_list; for (const auto &result : tuple_unpack_op.result()) { if (result.getType().isa<mlir::OpaqueType>() && result.getType().cast<mlir::OpaqueType>().getTypeData() == "Array") { auto array_type = LLVM::LLVMPointerType::get(get_quantum_type("Array", context)); unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(array_type)); tuple_struct_type_list.emplace_back(array_type); } else if (result.getType().isa<mlir::FloatType>()) { auto float_type = mlir::FloatType::getF64(context); unpacked_type_list.emplace_back(LLVM::LLVMPointerType::get(float_type)); tuple_struct_type_list.emplace_back(float_type); } } auto unpacked_type = LLVM::LLVMStructType::getLiteral( context, llvm::ArrayRef<Type>(tuple_struct_type_list)); auto location = parentModule->getLoc(); auto bitcast = rewriter.create<LLVM::BitcastOp>( location, LLVM::LLVMPointerType::get(unpacked_type), operands[0]); std::vector<mlir::Value> unpacked_vals; for (size_t idx = 0; idx < unpacked_type_list.size(); ++idx) { mlir::Value idx_cst = rewriter.create<LLVM::ConstantOp>( location, IntegerType::get(rewriter.getContext(), 64), rewriter.getIntegerAttr(rewriter.getIndexType(), idx)); auto getelementptr = rewriter.create<LLVM::GEPOp>( location, unpacked_type_list[idx], bitcast, idx_cst); auto load_op = rewriter.create<LLVM::LoadOp>( location, tuple_struct_type_list[idx], getelementptr.res()); unpacked_vals.emplace_back(load_op.res()); } for (size_t idx = 0; idx < unpacked_vals.size(); ++idx) { mlir::Value unpack_result = *(std::next(tuple_unpack_op.result_begin(), idx)); unpack_result.replaceAllUsesWith(unpacked_vals[idx]); } rewriter.eraseOp(op); std::cout << "After:\n"; parentModule.dump(); return success(); } LogicalResult CreateCallableOpLowering::matchAndRewrite( Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { return success(); } } // namespace qcor No newline at end of file
mlir/transforms/lowering/CallableLowering.hpp 0 → 100644 +28 −0 Original line number Diff line number Diff line #pragma once #include "quantum_to_llvm.hpp" namespace qcor { class TupleUnpackOpLowering : public ConversionPattern { protected: public: explicit TupleUnpackOpLowering(MLIRContext *context) : ConversionPattern(mlir::quantum::TupleUnpackOp::getOperationName(), 1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override; }; class CreateCallableOpLowering : public ConversionPattern { protected: public: explicit CreateCallableOpLowering(MLIRContext *context) : ConversionPattern(mlir::quantum::CreateCallableOp::getOperationName(), 1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override; }; } // namespace qcor No newline at end of file
mlir/transforms/quantum_to_llvm.cpp +8 −1 Original line number Diff line number Diff line Loading @@ -21,6 +21,7 @@ #include "lowering/SetQregOpLowering.hpp" #include "lowering/StdAtanOpLowering.hpp" #include "lowering/ValueSemanticsInstOpLowering.hpp" #include "lowering/CallableLowering.hpp" namespace qcor { mlir::Type get_quantum_type(std::string type, mlir::MLIRContext *context) { Loading @@ -44,8 +45,12 @@ struct QuantumLLVMTypeConverter : public LLVMTypeConverter { return LLVM::LLVMPointerType::get(get_quantum_type("qreg", context)); } else if (type.getTypeData() == "Array") { return LLVM::LLVMPointerType::get(get_quantum_type("Array", context)); } else if (type.getTypeData() == "Callable") { return LLVM::LLVMPointerType::get(get_quantum_type("Callable", context)); } else if (type.getTypeData() == "Tuple") { return LLVM::LLVMPointerType::get(get_quantum_type("Tuple", context)); } std::cout << "ERROR WE DONT KNOW WAHT THIS TYPE IS\n"; std::cout << "ERROR WE DONT KNOW WHAT THIS TYPE IS\n"; exit(0); return mlir::IntegerType::get(context, 64); } Loading Loading @@ -105,6 +110,8 @@ void QuantumToLLVMLoweringPass::runOnOperation() { patterns.insert<EndAdjointURegionOpLowering>(&getContext()); patterns.insert<StartCtrlURegionOpLowering>(&getContext()); patterns.insert<EndCtrlURegionOpLowering>(&getContext()); patterns.insert<TupleUnpackOpLowering>(&getContext()); patterns.insert<CreateCallableOpLowering>(&getContext()); // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. Loading