Commit b7d032a2 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Refactor MLIR lowering passes into separate files



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent aee2f861
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ link_directories(${XACC_ROOT}/lib)
#target_link_libraries(qasm3VisitorTester qcor-mlir-api gtest gtest_main)

add_executable(qasm3VisitorAliasTester test_alias_handler.cpp)
add_test(NAME qcor_qasm3_quantum_alias_decl_tester COMMAND test_alias_handler)
add_test(NAME qcor_qasm3_quantum_alias_decl_tester COMMAND qasm3VisitorAliasTester)
target_include_directories(qasm3VisitorAliasTester PRIVATE . ../../ ${XACC_ROOT}/include/gtest)
target_link_libraries(qasm3VisitorAliasTester qcor-mlir-api gtest gtest_main)

+121 −0
Original line number Diff line number Diff line
#pragma once

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/InitAllDialects.h"
#include "quantum_to_llvm.hpp"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>

namespace qcor {
class AssignQubitOpConversion : public ConversionPattern {
protected:
  std::map<std::string, mlir::Value> &variables;

public:
  // CTor: store seen variables
  explicit AssignQubitOpConversion(MLIRContext *context,
                                   std::map<std::string, mlir::Value> &vars)
      : ConversionPattern(mlir::quantum::AssignQubitOp::getOperationName(), 1,
                          context),
        variables(vars) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Local Declarations
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();
    auto location = parentModule->getLoc();
    // Unpack destination and source array and indices
    auto dest_array = operands[0];
    auto dest_idx = operands[1];
    auto src_array = operands[2];
    auto src_idx = operands[3];
    FlatSymbolRefAttr array_get_elem_fn_ptr = [&]() {
      static const std::string qir_get_qubit_from_array =
          "__quantum__rt__array_get_element_ptr_1d";
      if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(
              qir_get_qubit_from_array)) {
        return SymbolRefAttr::get(qir_get_qubit_from_array, context);
      } else {
        // prototype should be (int64* : qreg, int64 : element) -> int64* :
        // qubit
        auto qubit_array_type =
            LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
        auto qubit_index_type = IntegerType::get(context, 64);
        // ret is i8*
        auto qbit_element_ptr_type =
            LLVM::LLVMPointerType::get(IntegerType::get(context, 8));

        auto get_ptr_qbit_ftype = LLVM::LLVMFunctionType::get(
            qbit_element_ptr_type,
            llvm::ArrayRef<Type>{qubit_array_type, qubit_index_type}, false);

        PatternRewriter::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(parentModule.getBody());
        rewriter.create<LLVM::LLVMFuncOp>(location, qir_get_qubit_from_array,
                                          get_ptr_qbit_ftype);

        return mlir::SymbolRefAttr::get(qir_get_qubit_from_array, context);
      }
    }();

    // Create the CallOp for the get element ptr 1d function
    auto get_dest_qbit_qir_call = rewriter.create<mlir::CallOp>(
        location, array_get_elem_fn_ptr,
        LLVM::LLVMPointerType::get(IntegerType::get(context, 8)),
        llvm::makeArrayRef(std::vector<mlir::Value>{dest_array, dest_idx}));

    auto get_src_qbit_qir_call = rewriter.create<mlir::CallOp>(
        location, array_get_elem_fn_ptr,
        LLVM::LLVMPointerType::get(IntegerType::get(context, 8)),
        llvm::makeArrayRef(std::vector<mlir::Value>{src_array, src_idx}));

    // Load source qubit
    auto src_bitcast = rewriter.create<LLVM::BitcastOp>(
        location,
        LLVM::LLVMPointerType::get(
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))),
        get_src_qbit_qir_call.getResult(0));

    auto real_casted_src_qubit = rewriter.create<LLVM::LoadOp>(
        location,
        LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context)),
        src_bitcast.res());

    // Destination: just cast the raw ptr to Qubit** to store the source Qubit*
    // to. Get the destination raw ptr (int8) and cast to Qubit**
    auto dest_bitcast = rewriter.create<LLVM::BitcastOp>(
        location,
        LLVM::LLVMPointerType::get(
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))),
        get_dest_qbit_qir_call.getResult(0));

    // Store source (Qubit*) to destination (Qubit**)
    rewriter.create<LLVM::StoreOp>(location, real_casted_src_qubit,
                                   dest_bitcast);
    rewriter.eraseOp(op);
    // std::cout << "After assign:\n";
    // parentModule.dump();
    return success();
  }
};
} // namespace qcor
 No newline at end of file
+92 −0
Original line number Diff line number Diff line
#pragma once

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/InitAllDialects.h"
#include "quantum_to_llvm.hpp"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
namespace qcor {
class CreateStringLiteralOpLowering : public ConversionPattern {
private:
  std::map<std::string, mlir::Value> &variables;

  /// Return a value representing an access into a global string with the given
  /// name, creating the string if necessary.
  static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
                                       StringRef name, StringRef value,
                                       ModuleOp module) {
    // Create the global at the entry of the module.
    LLVM::GlobalOp global;
    if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
      OpBuilder::InsertionGuard insertGuard(builder);
      builder.setInsertionPointToStart(module.getBody());
      auto type = LLVM::LLVMArrayType::get(
          IntegerType::get(builder.getContext(), 8), value.size());
      global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
                                              LLVM::Linkage::Internal, name,
                                              builder.getStringAttr(value));
    }

    // Get the pointer to the first character in the global string.
    Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
    Value cst0 = builder.create<LLVM::ConstantOp>(
        loc, IntegerType::get(builder.getContext(), 64),
        builder.getIntegerAttr(builder.getIndexType(), 0));
    return builder.create<LLVM::GEPOp>(
        loc,
        LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
        globalPtr, ArrayRef<Value>({cst0, cst0}));
  }

public:
  // Constructor, store seen variables
  explicit CreateStringLiteralOpLowering(MLIRContext *context,
                                         std::map<std::string, mlir::Value> &v)
      : ConversionPattern(
            mlir::quantum::CreateStringLiteralOp::getOperationName(), 1,
            context),
        variables(v) {}

  // Match any Operation that is the QallocOp
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Local Declarations, get location, parentModule
    // and the context
    auto loc = op->getLoc();
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto slOp = cast<mlir::quantum::CreateStringLiteralOp>(op);
    auto slOpText = slOp.text();
    auto slVarName = slOp.varname();

    Value new_global_str = getOrCreateGlobalString(
        loc, rewriter, slVarName,
        StringRef(slOpText.str().c_str(), slOpText.str().length() + 1),
        parentModule);

    variables.insert({slVarName.str(), new_global_str});

    rewriter.eraseOp(op);

    return success();
  }
};
} // namespace qcor
 No newline at end of file
+99 −0
Original line number Diff line number Diff line
#pragma once

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/InitAllDialects.h"
#include "quantum_to_llvm.hpp"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>

namespace qcor {
// declare void @__quantum__rt__qubit_release_array(%Array*)
class DeallocOpLowering : public ConversionPattern {
protected:
  // Constant string for runtime function name
  inline static const std::string qir_qubit_array_deallocate =
      "__quantum__rt__qubit_release_array";
  // Rudimentary symbol table, seen variables
  std::map<std::string, mlir::Value> &variables;

  // %Array* @__quantum__rt__qubit_allocate_array(i64 %nQubits)
public:
  // Constructor, store seen variables
  explicit DeallocOpLowering(MLIRContext *context,
                             std::map<std::string, mlir::Value> &vars)
      : ConversionPattern(mlir::quantum::DeallocOp::getOperationName(), 1,
                          context),
        variables(vars) {}

  // Match any Operation that is the QallocOp
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Local Declarations, get location, parentModule
    // and the context
    auto loc = op->getLoc();
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();

    // First step is to get a reference to the Symbol Reference for the
    // qalloc QIR runtime function, this will only declare it once and reuse
    // each time it is seen
    FlatSymbolRefAttr symbol_ref;
    if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(
            qir_qubit_array_deallocate)) {
      symbol_ref = SymbolRefAttr::get(qir_qubit_array_deallocate, context);
    } else {
      // prototype is (Array*) -> void
      auto void_type = LLVM::LLVMVoidType::get(context);
      auto array_qbit_type =
          LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
      auto dealloc_ftype =
          LLVM::LLVMFunctionType::get(void_type, array_qbit_type, false);

      // Insert the function declaration
      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(parentModule.getBody());
      rewriter.create<LLVM::LLVMFuncOp>(
          parentModule->getLoc(), qir_qubit_array_deallocate, dealloc_ftype);
      symbol_ref =
          mlir::SymbolRefAttr::get(qir_qubit_array_deallocate, context);
    }

    // Get as a QallocOp, get its allocatino size and qreg variable name
    auto deallocOp = cast<mlir::quantum::DeallocOp>(op);
    auto qubits_value = deallocOp.qubits();
    auto qreg_name_attr = qubits_value.getDefiningOp()->getAttr("name");
    auto name = qreg_name_attr.cast<::mlir::StringAttr>().getValue();
    auto qubits = variables[name.str()];

    // create a CallOp for the new quantum runtime de-allocation
    // function.
    rewriter.create<mlir::CallOp>(loc, symbol_ref,
                                  LLVM::LLVMVoidType::get(context),
                                  ArrayRef<Value>({qubits}));

    // Remove the old QuantumDialect QallocOp
    rewriter.eraseOp(op);

    return success();
  }
};
} // namespace qcor
 No newline at end of file
+112 −0
Original line number Diff line number Diff line
#pragma once

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/InitAllDialects.h"
#include "quantum_to_llvm.hpp"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>

namespace qcor {
// The goal of this OpConversion is to map vector.extract on a
// qalloc qubit vector to the MSFT QIR __quantum__rt__array_get_element_ptr_1d()
// call
class ExtractQubitOpConversion : public ConversionPattern {
protected:
  LLVMTypeConverter &typeConverter;
  inline static const std::string qir_get_qubit_from_array =
      "__quantum__rt__array_get_element_ptr_1d";
  std::map<std::string, mlir::Value> &vars;
  std::map<mlir::Operation *, std::string> &qubit_extract_map;

public:
  explicit ExtractQubitOpConversion(
      MLIRContext *context, LLVMTypeConverter &c,
      std::map<std::string, mlir::Value> &v,
      std::map<mlir::Operation *, std::string> &qem)
      : ConversionPattern(mlir::quantum::ExtractQubitOp::getOperationName(), 1,
                          context),
        typeConverter(c), vars(v), qubit_extract_map(qem) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Local Declarations
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();
    auto location = parentModule->getLoc();

    // First goal, get symbol for
    // %0 = call i8* @__quantum__rt__array_get_element_ptr_1d(%Array* %q, i64 0)
    // %1 = bitcast i8* %0 to %Qubit**
    // %.qb = load %Qubit*, %Qubit** %1
    FlatSymbolRefAttr symbol_ref;
    if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(qir_get_qubit_from_array)) {
      symbol_ref = SymbolRefAttr::get(qir_get_qubit_from_array, context);
    } else {
      // prototype should be (int64* : qreg, int64 : element) -> int64* : qubit
      auto qubit_array_type =
          LLVM::LLVMPointerType::get(get_quantum_type("Array", context));
      auto qubit_index_type = IntegerType::get(context, 64);
      // ret is i8*
      auto qbit_element_ptr_type =
          LLVM::LLVMPointerType::get(IntegerType::get(context, 8));

      auto get_ptr_qbit_ftype = LLVM::LLVMFunctionType::get(
          qbit_element_ptr_type,
          llvm::ArrayRef<Type>{qubit_array_type, qubit_index_type}, false);

      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(parentModule.getBody());
      rewriter.create<LLVM::LLVMFuncOp>(location, qir_get_qubit_from_array,
                                        get_ptr_qbit_ftype);

      symbol_ref = mlir::SymbolRefAttr::get(qir_get_qubit_from_array, context);
    }

    // Create the CallOp for the get element ptr 1d function
    auto array_qbit_type =
        LLVM::LLVMPointerType::get(IntegerType::get(context, 8));

    auto get_qbit_qir_call = rewriter.create<mlir::CallOp>(
        location, symbol_ref, array_qbit_type, operands);
    // ArrayRef<Value>({vars[qreg_name], adaptor.idx()}));

    auto bitcast = rewriter.create<LLVM::BitcastOp>(
        location,
        LLVM::LLVMPointerType::get(
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context))),
        get_qbit_qir_call.getResult(0));
    auto real_casted_qubit = rewriter.create<LLVM::LoadOp>(
        location,
        LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context)),
        bitcast.res());

    rewriter.replaceOp(op, real_casted_qubit.res());
    // Remember the variable name for this qubit
    // vars.insert({qubit_var_name, real_casted_qubit.res()});

    // STORE THAT THIS OP PRODUCES THIS QREG{IDX} VARIABLE NAME
    // qubit_extract_map.insert({op, qubit_var_name});

    return success();
  }
};
} // namespace qcor
 No newline at end of file
Loading