Commit 079d4c4e authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

update quantum to llvm transform to support new variable rotation parameters

parent f1accb94
Loading
Loading
Loading
Loading
+38 −35
Original line number Diff line number Diff line
@@ -26,8 +26,7 @@ namespace {
using namespace mlir;
std::map<std::string, std::string> inst_map{{"cx", "cnot"}, {"measure", "mz"}};

mlir::Type get_quantum_type(std::string type,
                                      mlir::MLIRContext *context) {
mlir::Type get_quantum_type(std::string type, mlir::MLIRContext *context) {
  return LLVM::LLVMStructType::getOpaque(type, context);
}

@@ -420,6 +419,8 @@ class QuantumFuncArgConverter : public ConversionPattern {
        }
      } else if (type.isa<mlir::IntegerType>()) {
        return IntegerType::get(this->context, 32);
      } else if (type.isa<mlir::FloatType>()) {
        return FloatType::getF64(this->context);
      }
      return llvm::None;
    });
@@ -496,12 +497,18 @@ class QuantumFuncArgConverter : public ConversionPattern {
    if (ftype.getNumInputs() > 0) {
      std::vector<mlir::Type> tmp_arg_types;
      for (unsigned i = 0; i < ftype.getNumInputs(); i++) {
        auto input_type = ftype.getInput(i);
        if (input_type.isa<mlir::OpaqueType>()) {
          tmp_arg_types.push_back(
              LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context)));
        } else {
          tmp_arg_types.push_back(input_type);
        }
      }

      auto new_func_signature = LLVM::LLVMFunctionType::get(
          LLVM::LLVMVoidType::get(context), llvm::makeArrayRef(tmp_arg_types), false);
      auto new_func_signature =
          LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
                                      llvm::makeArrayRef(tmp_arg_types), false);

      auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, funcOp.sym_name(),
                                                         new_func_signature);
@@ -576,7 +583,7 @@ class InstOpLowering : public ConversionPattern {
    // qubits the InstOp is operating on, so lets get tehm
    // as mlir::Values to be used in the creation of the CallOp for this
    // quantum runtime function
    std::vector<mlir::Value> qbit_results;
    std::vector<mlir::Value> qbit_results, param_results;
    for (auto operand : operands) {
      // The Operand points to the vector::ExtractElementOp that produces the
      // qubit Value, get that Operation
@@ -584,13 +591,17 @@ class InstOpLowering : public ConversionPattern {
      auto extract_op =
          operand.getDefiningOp<quantum::ExtractQubitOp>().getOperation();
      if (!extract_op) {
        if (operand.isa<BlockArgument>()) {
        if (operand.isa<BlockArgument>() && !operand.getType().isa<mlir::FloatType>()) {
          // only add qubit types
          qbit_results.push_back(operand);
        } else {
          std::cout << "Failure creating LLVM CallOp qubit value for instop "
                    << inst_name << "\n";
          return mlir::failure();
        } 
        // else {
        //   std::cout << "dumpy here\n";
        //   operand.dump();
        //   // std::cout << "Failure creating LLVM CallOp qubit value for instop "
        //   //           << inst_name << "\n";
        //   // return mlir::failure();
        // }
      } else {
        // Now get the corresponding qubit variable name (q_0 for q[0])
        std::string get_qbit_call_qreg_key = qubit_extract_map[extract_op];
@@ -620,16 +631,17 @@ class InstOpLowering : public ConversionPattern {
      // Create Types for all function arguments, start with
      // double parameters (if instOp has them)
      std::vector<Type> tmp_arg_types;
      if (instOp.params()) {
        auto params = instOp.params().getValue();
        for (int i = 0; i < params.size(); i++) {
      // if (instOp.params()) {
      // auto params = instOp.params().getValue();
      for (auto param : instOp.params()) {
        // // for (int i = 0; i < params.size(); i++) {
        auto param_type = FloatType::getF64(context);
        tmp_arg_types.push_back(param_type);
      }
      }
      // }

      // Now, we need a Int64Type for each qubit argument
      for (std::size_t i = 0; i < operands.size(); i++) {
      // Now, we need a QubitType for each qubit argument
      for (auto qbit : instOp.qubits()) {
        auto qubit_index_type =
            LLVM::LLVMPointerType::get(get_quantum_type("Qubit", context));
        // IntegerType::get(context, 64).getPointerTo();
@@ -651,19 +663,9 @@ class InstOpLowering : public ConversionPattern {
    // Now create the vector containing the function Values,
    // double parameters first if we have them...
    std::vector<mlir::Value> func_args;
    if (instOp.params()) {
      auto params = instOp.params().getValue();
      for (std::int64_t i = 0; i < params.size(); i++) {
        auto param_double = params.template getValue<double>(
            llvm::makeArrayRef({(std::uint64_t)i}));
        auto double_attr =
            mlir::FloatAttr::get(rewriter.getF64Type(), param_double);

        Value const_double_op = rewriter.create<LLVM::ConstantOp>(
            loc, FloatType::getF64(rewriter.getContext()), double_attr);

        func_args.push_back(const_double_op);
      }
    // if (instOp.params()) {
    for (auto param : instOp.params()) {
      func_args.push_back(param);
    }

    // Followed by qubit values
@@ -777,6 +779,7 @@ class ExtractQubitOpConversion : public ConversionPattern {
    // Create the CallOp for the get element ptr 1d function
    auto array_qbit_type =
        LLVM::LLVMPointerType::get(IntegerType::get(context, 8));

    auto get_qbit_qir_call = rewriter.create<mlir::CallOp>(
        location, symbol_ref, array_qbit_type,
        ArrayRef<Value>({vars[qreg_name], adaptor.idx()}));