Commit 83b19b0c authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

fully remove call op lowering and func arg converter, using custom llvm type converter now

parent 3cdb1636
Loading
Loading
Loading
Loading
Loading
+0 −248
Original line number Diff line number Diff line
@@ -238,18 +238,10 @@ class QRTInitOpLowering : public ConversionPattern {
      symbol_ref = mlir::SymbolRefAttr::get(qir_qrt_initialize, context);
    }

    // auto initOp = cast<mlir::quantum::QRTInitOp>(op);
    // auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
    // parentFunc.dump();
    // auto args = parentFunc.body().getArguments();
    // std::cout << "HERE:\n";
    // args[0].dump();
    // args[1].dump();
    // create a CallOp for the new quantum runtime initialize
    // function.
    rewriter.create<mlir::CallOp>(
        loc, symbol_ref, IntegerType::get(context, 32), operands);
        // ArrayRef<Value>({variables["main_argc"], variables["main_argv"]}));

    // Remove the old QuantumDialect QallocOp
    rewriter.eraseOp(op);
@@ -305,13 +297,6 @@ class QRTFinalizeOpLowering : public ConversionPattern {
      symbol_ref = mlir::SymbolRefAttr::get(qir_qrt_finalize, context);
    }

    // auto initOp = cast<mlir::quantum::QRTInitOp>(op);
    // auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
    // parentFunc.dump();
    // auto args = parentFunc.body().getArguments();
    // std::cout << "HERE:\n";
    // args[0].dump();
    // args[1].dump();
    // create a CallOp for the new quantum runtime initialize
    // function.
    rewriter.create<mlir::CallOp>(
@@ -373,18 +358,10 @@ class SetQregOpLowering : public ConversionPattern {
      symbol_ref = mlir::SymbolRefAttr::get(qir_qrt_finalize, context);
    }

    // auto initOp = cast<mlir::quantum::QRTInitOp>(op);
    // auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
    // parentFunc.dump();
    // auto args = parentFunc.body().getArguments();
    // std::cout << "HERE:\n";
    // args[0].dump();
    // args[1].dump();
    // create a CallOp for the new quantum runtime initialize
    // function.
    rewriter.create<mlir::CallOp>(
        loc, symbol_ref, LLVM::LLVMVoidType::get(context), operands);
        // ArrayRef<Value>({variables["_incoming_qreg_variable"]}));

    // Remove the old QuantumDialect QallocOp
    rewriter.eraseOp(op);
@@ -393,227 +370,6 @@ class SetQregOpLowering : public ConversionPattern {
  }
};

class QuantumStdCallArgConverter : public ConversionPattern {
 protected:
  MLIRContext *context;
  std::vector<std::string> &module_function_names;

 public:
  explicit QuantumStdCallArgConverter(MLIRContext *ctx,
                                      std::vector<std::string> &f_names)
      : ConversionPattern(mlir::CallOp::getOperationName(), 1, ctx),
        context(ctx),
        module_function_names(f_names) {}
  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();

    auto callOp = cast<mlir::CallOp>(op);
    auto name = callOp.callee().str();
    if (std::find(module_function_names.begin(), module_function_names.end(),
                  callOp.callee().str()) != std::end(module_function_names) &&
        callOp.getNumOperands() > 0) {
      auto res = callOp.getResultTypes()[0];

      std::vector<mlir::Type> tmp_arg_types;
      for (unsigned i = 0; i < callOp.getNumOperands(); i++) {
        auto input_type = callOp.getOperand(i).getType();
        if (input_type.isa<mlir::OpaqueType>()) {
          auto casted = input_type.cast<mlir::OpaqueType>();
          mlir::Type t;
          if (casted.getTypeData() == "Qubit") {
            t = LLVM::LLVMPointerType::get(
                get_quantum_type("Qubit", this->context));
          } else if (casted.getTypeData() == "Array") {
            t = LLVM::LLVMPointerType::get(
                get_quantum_type("Array", this->context));
          }
          tmp_arg_types.push_back(t);
        } else {
          mlir::LLVMTypeConverter converter(rewriter.getContext());
          if (auto mem_type = input_type.dyn_cast_or_null<mlir::MemRefType>()) {
            input_type = converter.convertType(mem_type);
          }
          tmp_arg_types.push_back(input_type);
        }
      }
      auto proto = mlir::FuncOp::create(
          loc, name,
          rewriter.getFunctionType(llvm::makeArrayRef(tmp_arg_types), res));
      mlir::FuncOp func(proto);

      rewriter.replaceOpWithNewOp<mlir::CallOp>(callOp, func, operands);

      return success();
    }
    return failure();
  }
};

class QuantumFuncArgConverter : public ConversionPattern {
 protected:
  std::unique_ptr<mlir::TypeConverter> my_tc;
  MLIRContext *context;
  std::map<std::string, mlir::Value> &variables;

 public:
  explicit QuantumFuncArgConverter(MLIRContext *ctx,
                                   std::map<std::string, mlir::Value> &vars)
      : ConversionPattern(mlir::FuncOp::getOperationName(), 1, ctx),
        context(ctx),
        variables(vars) {
    my_tc = std::make_unique<mlir::TypeConverter>();
    my_tc->addConversion([this](mlir::Type type) -> mlir::Optional<mlir::Type> {
      if (type.isa<mlir::OpaqueType>()) {
        auto casted = type.cast<mlir::OpaqueType>();
        if (casted.getTypeData() == "Qubit") {
          return LLVM::LLVMPointerType::get(
              get_quantum_type("Qubit", this->context));
        } else if (casted.getTypeData() == "ArgvType") {
          return LLVM::LLVMPointerType::get(
              LLVM::LLVMPointerType::get(IntegerType::get(context, 8)));
        } else if (casted.getTypeData() == "qreg") {
          return LLVM::LLVMPointerType::get(
              get_quantum_type("qreg", this->context));
        } else if (casted.getTypeData() == "Array") {
          return LLVM::LLVMPointerType::get(
              get_quantum_type("Array", this->context));
        }
      } else if (type.isa<mlir::IntegerType>()) {
        return IntegerType::get(this->context, 32);
      } else if (type.isa<mlir::FloatType>()) {
        return FloatType::getF64(this->context);
      }
      std::cout << "HELLO WORLD HERE:\n";
      type.dump();
      return llvm::None;
    });
    typeConverter = my_tc.get();
  }
  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto context = parentModule->getContext();
    auto funcOp = cast<mlir::FuncOp>(op);
    auto ftype = funcOp.type().cast<FunctionType>();

    auto func_name = funcOp.getName().str();

    if (func_name == "main") {
      auto charstarstar = LLVM::LLVMPointerType::get(
          LLVM::LLVMPointerType::get(IntegerType::get(context, 8)));
      std::vector<Type> tmp_arg_types{IntegerType::get(context, 32),
                                      charstarstar};

      auto new_main_signature =
          LLVM::LLVMFunctionType::get(IntegerType::get(context, 32),
                                      llvm::makeArrayRef(tmp_arg_types), false);

      auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, funcOp.sym_name(),
                                                         new_main_signature);
      rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                  newFuncOp.end());
      if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(),
                                             *typeConverter))) {
        return failure();
      }
      // rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter);

      auto args = newFuncOp.body().getArguments();
      variables.insert({"main_argc", args[0]});
      variables.insert({"main_argv", args[1]});

      rewriter.eraseOp(op);
      return success();
    }

    if (ftype.getNumInputs() == 1 &&
        ftype.getInput(0).isa<mlir::OpaqueType>() &&
        ftype.getInput(0).cast<mlir::OpaqueType>().getTypeData() == "qreg") {
      std::vector<mlir::Type> tmp_arg_types{
          LLVM::LLVMPointerType::get(get_quantum_type("qreg", context))};

      auto new_func_signature =
          LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
                                      llvm::makeArrayRef(tmp_arg_types), false);

      auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, funcOp.sym_name(),
                                                         new_func_signature);

      rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                  newFuncOp.end());

      if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(),
                                             *typeConverter))) {
        return failure();
      }

      rewriter.eraseOp(op);
      auto arg = newFuncOp.body().getArguments()[0];

      variables.insert({"_incoming_qreg_variable", arg});
      return success();
    }

    // Not main, sub quantum kernel, convert Qubit to Qubit*
    if (ftype.getNumInputs() > 0) {
      std::vector<mlir::Type> tmp_arg_types;
      for (unsigned i = 0; i < ftype.getNumInputs(); i++) {
        auto input_type = ftype.getInput(i);
        if (input_type.isa<mlir::OpaqueType>()) {
          auto casted = input_type.cast<mlir::OpaqueType>();
          mlir::Type t;
          if (casted.getTypeData() == "Qubit") {
            t = LLVM::LLVMPointerType::get(
                get_quantum_type("Qubit", this->context));
          } else if (casted.getTypeData() == "Array") {
            t = LLVM::LLVMPointerType::get(
                get_quantum_type("Array", this->context));
          }
          tmp_arg_types.push_back(t);
        } else {
          mlir::LLVMTypeConverter converter(rewriter.getContext());
          if (auto mem_type = input_type.dyn_cast_or_null<mlir::MemRefType>()) {
            input_type = converter.convertType(mem_type);
          }
          tmp_arg_types.push_back(input_type);
        }
      }

      mlir::Type res = LLVM::LLVMVoidType::get(context);
      if (ftype.getNumResults()) {
        res = ftype.getResult(0);
        mlir::LLVMTypeConverter converter(rewriter.getContext());
        if (auto mem_type = res.dyn_cast_or_null<mlir::MemRefType>()) {
          res = converter.convertType(mem_type);
        }
      }

      auto new_func_signature = LLVM::LLVMFunctionType::get(
          res, llvm::makeArrayRef(tmp_arg_types), false);

      auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, funcOp.sym_name(),
                                                         new_func_signature);

      rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                  newFuncOp.end());

      if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(),
                                             *typeConverter))) {
        return failure();
      }
      rewriter.eraseOp(op);
      return success();
    }

    return failure();
  }
};

// The goal of InstOpLowering is to convert all QuantumDialect
// InstOp (quantum.inst) to the corresponding __quantum__qis__INST(int64*, ...)
// call
@@ -1126,7 +882,6 @@ void QuantumToLLVMLoweringPass::runOnOperation() {
  QuantumLLVMTypeConverter typeConverter(&getContext());

  OwningRewritePatternList patterns;
  // patterns.insert<QuantumStdCallArgConverter>(&getContext(), function_names);
  patterns.insert<StdAtanOpLowering>(&getContext());

  populateStdToLLVMConversionPatterns(typeConverter, patterns);
@@ -1135,9 +890,6 @@ void QuantumToLLVMLoweringPass::runOnOperation() {
  std::map<std::string, mlir::Value> variables;
  std::map<mlir::Operation *, std::string> qubit_extract_map;

  // Add our custom conversion passes
  // patterns.insert<QuantumFuncArgConverter>(&getContext(), variables);

  patterns.insert<CreateStringLiteralOpLowering>(&getContext(), variables);
  patterns.insert<PrintOpLowering>(&getContext(), variables);