Commit 6620ec46 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

separating transforms into separate directory

parent 30edc38d
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -24,6 +24,6 @@ link_directories(${LLVM_BUILD_LIBRARY_DIR})

add_subdirectory(dialect)
add_subdirectory(parsers)
#add_subdirectory(transforms)
add_subdirectory(transforms)

add_subdirectory(tests)
+2 −2
Original line number Diff line number Diff line
@@ -22,8 +22,8 @@ target_compile_features(QasmTester
target_include_directories(QasmTester PRIVATE . ../dialect)

llvm_update_compile_flags(QasmTester)
target_link_libraries(QasmTester PUBLIC ${LIBS} staq-mlir-visitor)
target_link_libraries(QasmTester PUBLIC quantum-to-llvm-lowering staq-mlir-visitor )

set_target_properties(QasmTester
                        PROPERTIES INSTALL_RPATH "${CMAKE_BINARY_DIR}/mlir/parsers/openqasm")
                        PROPERTIES INSTALL_RPATH "/home/cades/.mlir/lib:${CMAKE_BINARY_DIR}/mlir/parsers/openqasm:${CMAKE_BINARY_DIR}/lib")
mlir_check_all_link_libraries(QasmTester)
+6 −322
Original line number Diff line number Diff line
@@ -37,333 +37,17 @@
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

#include "optimization/simplify.hpp"
#include "quantum_dialect.hpp"
#include "staq_parser.hpp"
#include "transformations/desugar.hpp"
#include "transformations/inline.hpp"
#include "transformations/oracle_synthesizer.hpp"

#include "quantum_to_llvm.hpp"

using namespace mlir;
using namespace staq;
std::map<std::string, std::string> inst_map{{"cx", "cnot"}, {"measure", "mz"}};

class QallocOpLowering : public ConversionPattern {
 protected:
  std::string qir_qubit_array_allocate = "__quantum__rt__qubit_allocate_array";
  std::map<std::string, mlir::Value> &variables;

 public:
  explicit QallocOpLowering(MLIRContext *context,
                            std::map<std::string, mlir::Value> &vars)
      : ConversionPattern(mlir::quantum::QallocOp::getOperationName(), 1,
                          context),
        variables(vars) {}

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();

    FlatSymbolRefAttr symbol_ref;
    if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(qir_qubit_array_allocate)) {
      symbol_ref = SymbolRefAttr::get(qir_qubit_array_allocate, context);
    } else {
      auto qubit_type = LLVM::LLVMType::getInt64Ty(context);
      auto array_qbit_type = LLVM::LLVMType::getInt64Ty(context).getPointerTo();
      auto qalloc_ftype =
          LLVM::LLVMType::getFunctionTy(array_qbit_type, qubit_type, true);

      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(parentModule.getBody());
      rewriter.create<LLVM::LLVMFuncOp>(parentModule->getLoc(),
                                        qir_qubit_array_allocate, qalloc_ftype);

      symbol_ref = mlir::SymbolRefAttr::get(qir_qubit_array_allocate, context);
    }
    auto qallocOp = cast<mlir::quantum::QallocOp>(op);
    auto size = qallocOp.size();
    auto qreg_name = qallocOp.name().str();

    Value create_size_int = rewriter.create<LLVM::ConstantOp>(
        loc, LLVM::LLVMType::getInt64Ty(rewriter.getContext()),
        rewriter.getIntegerAttr(rewriter.getI64Type(), size));

    auto array_qbit_type = LLVM::LLVMType::getInt64Ty(context).getPointerTo();
    auto qalloc_qir_call = rewriter.create<mlir::CallOp>(
        loc, symbol_ref, array_qbit_type, ArrayRef<Value>({create_size_int}));

    auto qbit_array = qalloc_qir_call.getResult(0);

    rewriter.eraseOp(op);

    variables.insert({qreg_name, qbit_array});

    return success();
  }
};

class InstOpLowering : public ConversionPattern {
 protected:
  std::string qir_get_qubit_from_array =
      "__quantum__rt__array_get_element_ptr_1d";
  std::map<std::string, mlir::Value> &variables;
  std::map<mlir::Operation *, std::string> &qubit_extract_map;

 public:
  explicit InstOpLowering(MLIRContext *context,
                          std::map<std::string, mlir::Value> &vars,
                          std::map<mlir::Operation *, std::string> &qem)
      : ConversionPattern(mlir::quantum::InstOp::getOperationName(), 1,
                          context),
        variables(vars),
        qubit_extract_map(qem) {}

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();

    // Now get Instruction name and the bits it operates on with qreg names
    auto instOp = cast<mlir::quantum::InstOp>(op);
    auto inst_name = instOp.name().str();
    inst_name = (inst_map.count(inst_name) ? inst_map[inst_name] : inst_name);

    std::vector<mlir::Value> qbit_results;
    for (auto operand : operands) {
      auto extract_op =
          operand.getDefiningOp<vector::ExtractElementOp>().getOperation();
      std::string get_qbit_call_qreg_key = qubit_extract_map[extract_op];
      mlir::Value qbit_result = variables[get_qbit_call_qreg_key];
      qbit_results.push_back(qbit_result);
    }

    // // Need to find the quantum instruction function
    // // Should be void __quantum__qis__INST(Qubit q) for example
    FlatSymbolRefAttr q_symbol_ref;
    std::string q_function_name =
        "__quantum__qis__" +
        (inst_map.count(inst_name) ? inst_map[inst_name] : inst_name);
    if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(q_function_name)) {
      q_symbol_ref = SymbolRefAttr::get(q_function_name, context);
    } else {
      LLVM::LLVMType ret_type = LLVM::LLVMType::getVoidTy(context);
      if (inst_name == "mz") {
        ret_type = LLVM::LLVMType::getInt64Ty(context);
      }

      std::vector<LLVM::LLVMType> tmp_arg_types;

      // FIXME loop over params too to add double types
      if (instOp.params()) {
        auto params = instOp.params().getValue();
        for (int i = 0; i < params.size(); i++) {
          auto param_type = LLVM::LLVMType::getDoubleTy(context);
          tmp_arg_types.push_back(param_type);
        }
      }

      // Need a Int64Type for each qubit argument
      for (int i = 0; i < operands.size(); i++) {
        auto qubit_index_type =
            LLVM::LLVMType::getInt64Ty(context).getPointerTo();
        tmp_arg_types.push_back(qubit_index_type);
      }

      // Create void (int, int) or void (int)
      auto get_ptr_qbit_ftype = LLVM::LLVMType::getFunctionTy(
          ret_type, llvm::makeArrayRef(tmp_arg_types), true);

      // Insert the function since it hasn't been seen yet
      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(parentModule.getBody());
      rewriter.create<LLVM::LLVMFuncOp>(parentModule->getLoc(), q_function_name,
                                        get_ptr_qbit_ftype);

      q_symbol_ref = mlir::SymbolRefAttr::get(q_function_name, context);
    }

    std::vector<mlir::Value> func_args;
    if (instOp.params()) {
      auto params = instOp.params().getValue();
      for (std::uint64_t i = 0; i < params.getNumElements(); i++) {
        auto param_double = params.template getValue<double>(llvm::makeArrayRef({i}));
        std::cout << "HELLO inst_name: " << inst_name << ", " << param_double
                  << "\n";
        auto double_attr =
            mlir::FloatAttr::get(rewriter.getF64Type(), param_double);

        Value const_double_op = rewriter.create<LLVM::ConstantOp>(
            loc, LLVM::LLVMType::getDoubleTy(rewriter.getContext()),
            double_attr);

        func_args.push_back(const_double_op);
      }
    }

    for (auto q : qbit_results) {
      func_args.push_back(q);
    }

    LLVM::LLVMType ret_type = LLVM::LLVMType::getVoidTy(context);
    if (inst_name == "mz") {
      ret_type = LLVM::LLVMType::getInt64Ty(context);
    }

    auto qinst_qir_call = rewriter.create<mlir::CallOp>(
        loc, q_symbol_ref, ret_type, llvm::makeArrayRef(func_args));

    // Notify the rewriter that this operation has been removed.
    rewriter.eraseOp(op);

    return success();
  }
};

class ExtractQubitOpConversion : public ConversionPattern {
 protected:
  LLVMTypeConverter &typeConverter;
  std::map<std::string, mlir::Value> &vars;
  std::map<mlir::Operation *, std::string> &qubit_extract_map;

 public:
  explicit ExtractQubitOpConversion(
      MLIRContext *context, LLVMTypeConverter &c,
      std::map<std::string, mlir::Value> &v,
      std::map<mlir::Operation *, std::string> &qem)
      : ConversionPattern(mlir::vector::ExtractElementOp::getOperationName(), 1,
                          context),
        typeConverter(c),
        vars(v),
        qubit_extract_map(qem) {}

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();

    auto adaptor = vector::ExtractElementOpAdaptor(operands);

    auto vectorType = cast<vector::ExtractElementOp>(op).getVectorType();

    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // LLVM::LLVMType::getInt64Ty(context).getPointerTo();

    // Bail if result type cannot be lowered.
    if (!llvmType) {
      return failure();
    }

    mlir::Value v = operands[0];
    mlir::Value v1 = operands[1];

    auto qalloc_op = v.getDefiningOp<quantum::QallocOp>();
    auto qbit_constant_op = v1.getDefiningOp<LLVM::ConstantOp>();

    // Get info about what qreg we are extracting what qbit from
    std::string qreg_name = qalloc_op.name().str();
    mlir::Attribute unknown_attr = qbit_constant_op.value();
    auto int_attr = unknown_attr.cast<mlir::IntegerAttr>();
    auto int_value = int_attr.getInt();
    auto qubit_var_name = qreg_name + "_" + std::to_string(int_value);

    // Erase the old op
    rewriter.eraseOp(op);

    // Reuse the qubit if we've allocated it before.
    if (vars.count(qubit_var_name)) {
      qubit_extract_map.insert(
          {op, qreg_name + "_" + std::to_string(int_value)});
      return success();
    }

    auto context = parentModule->getContext();
    std::string qir_get_qubit_from_array =
        "__quantum__rt__array_get_element_ptr_1d";
    // First goal, get symbol for __quantum__rt__array_get_element_ptr_1d
    // function
    FlatSymbolRefAttr symbol_ref;
    if (parentModule.lookupSymbol<LLVM::LLVMFuncOp>(qir_get_qubit_from_array)) {
      symbol_ref = SymbolRefAttr::get(qir_get_qubit_from_array, context);
    } else {
      auto qubit_array_type =
          LLVM::LLVMType::getInt64Ty(context).getPointerTo();
      auto qubit_index_type = LLVM::LLVMType::getInt64Ty(context);

      auto qbit_element_ptr_type =
          LLVM::LLVMType::getInt64Ty(context).getPointerTo();
      auto get_ptr_qbit_ftype = LLVM::LLVMType::getFunctionTy(
          qbit_element_ptr_type,
          llvm::ArrayRef<LLVM::LLVMType>{qubit_array_type, qubit_index_type},
          true);

      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(parentModule.getBody());
      rewriter.create<LLVM::LLVMFuncOp>(
          parentModule->getLoc(), qir_get_qubit_from_array, get_ptr_qbit_ftype);

      symbol_ref = mlir::SymbolRefAttr::get(qir_get_qubit_from_array, context);
    }

    // Create the CallOp for the get element ptr 1d function
    auto array_qbit_type = LLVM::LLVMType::getInt64Ty(context).getPointerTo();
    auto get_qbit_qir_call = rewriter.create<mlir::CallOp>(
        parentModule->getLoc(), symbol_ref, array_qbit_type,
        ArrayRef<Value>({vars[qreg_name], adaptor.position()}));

    // Remember the variable name for this qubit
    vars.insert({qreg_name + "_" + std::to_string(int_value),
                 get_qbit_qir_call.getResult(0)});

    // STORE THAT THIS OP PRODUCES THIS QREG{IDX} VARIABLE NAME
    qubit_extract_map.insert({op, qreg_name + "_" + std::to_string(int_value)});

    return success();
  }
};

struct QuantumToLLVMLoweringPass
    : public PassWrapper<QuantumToLLVMLoweringPass, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }
  void runOnOperation() final;

 public:
  QuantumToLLVMLoweringPass() = default;
};

void QuantumToLLVMLoweringPass::runOnOperation() {
  LLVMConversionTarget target(getContext());
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
  LLVMTypeConverter typeConverter(&getContext());

  OwningRewritePatternList patterns;
  populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // Common variables to share across converteres
  std::map<std::string, mlir::Value> variables;
  std::map<mlir::Operation *, std::string> qubit_extract_map;

  // Add our custom conversion passes
  patterns.insert<QallocOpLowering>(&getContext(), variables);
  patterns.insert<InstOpLowering>(&getContext(), variables, qubit_extract_map);
  patterns.insert<ExtractQubitOpConversion>(&getContext(), typeConverter,
                                            variables, qubit_extract_map);

  // We want to completely lower to LLVM, so we use a `FullConversion`. This
  // ensures that only legal operations will remain after the conversion.
  auto module = getOperation();
  if (failed(applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
@@ -386,7 +70,7 @@ measure q -> c;
  try {
    prog = parser::parse_string(lineText);
    transformations::desugar(*prog);
    transformations::synthesize_oracles(*prog);
    // transformations::synthesize_oracles(*prog);
  } catch (std::exception &e) {
    std::stringstream ss;
    std::cout << e.what() << "\n";
@@ -407,10 +91,10 @@ measure q -> c;

  // Create the PassManager for lowering to LLVM MLIR and run it
  mlir::PassManager pm(&context);
  pm.addPass(std::make_unique<QuantumToLLVMLoweringPass>());
  pm.addPass(std::make_unique<qcor::QuantumToLLVMLoweringPass>());
  auto module = visitor.module();
  auto module_op = module.getOperation();
  pm.run(module_op);
  auto result = pm.run(module_op);
  std::cout << "Lowered to LLVM MLIR Dialect:\n";
  module_op->dump();

+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ set(LIBS
        quantum-dialect
        MLIROptLib
        MLIRTargetLLVMIR
        MLIRExecutionEngine
        )

add_mlir_library(${LIBRARY_NAME} SHARED ${SRC} LINK_LIBS PUBLIC ${LIBS})
+205 −163

File changed.

Preview size limit exceeded, changes collapsed.

Loading