Commit a26ef774 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

adding qubit deallocation to mlir, transforms, and qrt api

parent 6cc513c2
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -29,4 +29,3 @@ add_subdirectory(parsers)
add_subdirectory(transforms)
add_subdirectory(qir_qrt)
add_subdirectory(tools)
add_subdirectory(tests)
+217 −1
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ bool isOpaqueTypeWithName(mlir::Type type, std::string dialect, std::string type
}
QuantumDialect::QuantumDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<QuantumDialect>()) {
  addOperations<InstOp, QallocOp, ExtractQubitOp>();
  addOperations<InstOp, QallocOp, ExtractQubitOp, DeallocOp, QRTInitOp>();
}
ExtractQubitOpAdaptor::ExtractQubitOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs)  : odsOperands(values), odsAttrs(attrs) {

@@ -393,6 +393,125 @@ void InstOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir
  return ::mlir::success();
}

QRTInitOpAdaptor::QRTInitOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs)  : odsOperands(values), odsAttrs(attrs) {

}

QRTInitOpAdaptor::QRTInitOpAdaptor(QRTInitOp&op)  : odsOperands(op->getOperands()), odsAttrs(op->getAttrDictionary()) {

}

std::pair<unsigned, unsigned> QRTInitOpAdaptor::getODSOperandIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::ValueRange QRTInitOpAdaptor::getODSOperands(unsigned index) {
  auto valueRange = getODSOperandIndexAndLength(index);
  return {std::next(odsOperands.begin(), valueRange.first),
           std::next(odsOperands.begin(), valueRange.first + valueRange.second)};
}

::mlir::Value QRTInitOpAdaptor::argc() {
  return *getODSOperands(0).begin();
}

::mlir::Value QRTInitOpAdaptor::argv() {
  return *getODSOperands(1).begin();
}

::mlir::LogicalResult QRTInitOpAdaptor::verify(::mlir::Location loc) {
  return ::mlir::success();
}

::llvm::StringRef QRTInitOp::getOperationName() {
  return "quantum.init";
}

std::pair<unsigned, unsigned> QRTInitOp::getODSOperandIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::Operation::operand_range QRTInitOp::getODSOperands(unsigned index) {
  auto valueRange = getODSOperandIndexAndLength(index);
  return {std::next(getOperation()->operand_begin(), valueRange.first),
           std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)};
}

::mlir::Value QRTInitOp::argc() {
  return *getODSOperands(0).begin();
}

::mlir::Value QRTInitOp::argv() {
  return *getODSOperands(1).begin();
}

::mlir::MutableOperandRange QRTInitOp::argcMutable() {
  auto range = getODSOperandIndexAndLength(0);
  return ::mlir::MutableOperandRange(getOperation(), range.first, range.second);
}

::mlir::MutableOperandRange QRTInitOp::argvMutable() {
  auto range = getODSOperandIndexAndLength(1);
  return ::mlir::MutableOperandRange(getOperation(), range.first, range.second);
}

std::pair<unsigned, unsigned> QRTInitOp::getODSResultIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::Operation::result_range QRTInitOp::getODSResults(unsigned index) {
  auto valueRange = getODSResultIndexAndLength(index);
  return {std::next(getOperation()->result_begin(), valueRange.first),
           std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)};
}

void QRTInitOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value argc, ::mlir::Value argv) {
  odsState.addOperands(argc);
  odsState.addOperands(argv);
}

void QRTInitOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value argc, ::mlir::Value argv) {
  odsState.addOperands(argc);
  odsState.addOperands(argv);
  assert(resultTypes.size() == 0u && "mismatched number of results");
  odsState.addTypes(resultTypes);
}

void QRTInitOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
  assert(operands.size() == 2u && "mismatched number of parameters");
  odsState.addOperands(operands);
  odsState.addAttributes(attributes);
  assert(resultTypes.size() == 0u && "mismatched number of return types");
  odsState.addTypes(resultTypes);
}

::mlir::LogicalResult QRTInitOp::verify() {
  if (failed(QRTInitOpAdaptor(*this).verify(this->getLoc()))) return ::mlir::failure();
  {
    unsigned index = 0; (void)index;
    auto valueGroup0 = getODSOperands(0);
    for (::mlir::Value v : valueGroup0) {
      (void)v;
      if (!((v.getType().isInteger(32)))) {
        return emitOpError("operand #") << index << " must be 32-bit integer, but got " << v.getType();
      }
      ++index;
    }
    auto valueGroup1 = getODSOperands(1);
    for (::mlir::Value v : valueGroup1) {
      (void)v;
      if (!((isOpaqueTypeWithName(v.getType(), "quantum", "ArgvType")))) {
        return emitOpError("operand #") << index << " must be opaque argv type, but got " << v.getType();
      }
      ++index;
    }
  }
  {
    unsigned index = 0; (void)index;
  }
  return ::mlir::success();
}

} // namespace quantum
} // namespace mlir
namespace mlir {
@@ -553,6 +672,103 @@ void QallocOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::ml
  return ::mlir::success();
}

DeallocOpAdaptor::DeallocOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs)  : odsOperands(values
), odsAttrs(attrs) {

}

DeallocOpAdaptor::DeallocOpAdaptor(DeallocOp&op)  : odsOperands(op->getOperands()), odsAttrs(op->getAttrDictionary()) {

}

std::pair<unsigned, unsigned> DeallocOpAdaptor::getODSOperandIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::ValueRange DeallocOpAdaptor::getODSOperands(unsigned index) {
  auto valueRange = getODSOperandIndexAndLength(index);
  return {std::next(odsOperands.begin(), valueRange.first),
           std::next(odsOperands.begin(), valueRange.first + valueRange.second)};
}

::mlir::Value DeallocOpAdaptor::qubits() {
  return *getODSOperands(0).begin();
}

::mlir::LogicalResult DeallocOpAdaptor::verify(::mlir::Location loc) {
  return ::mlir::success();
}

::llvm::StringRef DeallocOp::getOperationName() {
  return "quantum.dealloc";
}

std::pair<unsigned, unsigned> DeallocOp::getODSOperandIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::Operation::operand_range DeallocOp::getODSOperands(unsigned index) {
  auto valueRange = getODSOperandIndexAndLength(index);
  return {std::next(getOperation()->operand_begin(), valueRange.first),
           std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)};
}

::mlir::Value DeallocOp::qubits() {
  return *getODSOperands(0).begin();
}

::mlir::MutableOperandRange DeallocOp::qubitsMutable() {
  auto range = getODSOperandIndexAndLength(0);
  return ::mlir::MutableOperandRange(getOperation(), range.first, range.second);
}

std::pair<unsigned, unsigned> DeallocOp::getODSResultIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::Operation::result_range DeallocOp::getODSResults(unsigned index) {
  auto valueRange = getODSResultIndexAndLength(index);
  return {std::next(getOperation()->result_begin(), valueRange.first),
           std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)};
}

void DeallocOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value qubits) {
  odsState.addOperands(qubits);
}

void DeallocOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value qubits) {
  odsState.addOperands(qubits);
  assert(resultTypes.size() == 0u && "mismatched number of results");
  odsState.addTypes(resultTypes);
}

void DeallocOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
  assert(operands.size() == 1u && "mismatched number of parameters");
  odsState.addOperands(operands);
  odsState.addAttributes(attributes);
  assert(resultTypes.size() == 0u && "mismatched number of return types");
  odsState.addTypes(resultTypes);
}

::mlir::LogicalResult DeallocOp::verify() {
  if (failed(DeallocOpAdaptor(*this).verify(this->getLoc()))) return ::mlir::failure();
  {
    unsigned index = 0; (void)index;
    auto valueGroup0 = getODSOperands(0);
    for (::mlir::Value v : valueGroup0) {
      (void)v;
      if (!((isOpaqueTypeWithName(v.getType(), "quantum", "Array")))) {
        return emitOpError("operand #") << index << " must be opaque array type, but got " << v.getType();
      }
      ++index;
    }
  }
  {
    unsigned index = 0; (void)index;
  }
  return ::mlir::success();
}

static mlir::ParseResult parseQallocOp(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  // SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
+64 −87
Original line number Diff line number Diff line
@@ -8,99 +8,14 @@ namespace mlir {
namespace quantum {
class InstOp;
class QallocOp;
class QInstOp;
class DeallocOp;
class ExtractQubitOp;
class QRTInitOp;
}  // namespace quantum
}  // namespace mlir

namespace mlir {
namespace quantum {
// struct QubitTypeStorage : public TypeStorage {
//   QubitTypeStorage(std::int64_t _qubit_idx, std::string _enclosed_register)
//       : qubit_idx(_qubit_idx), enclosed_register(_enclosed_register) {}

//   /// The hash key for this storage is a pair of the integer and type params.
//   using KeyTy = std::pair<std::int64_t, std::string>;

//   /// Define the comparison function for the key type.
//   bool operator==(const KeyTy &key) const {
//     return key == KeyTy(qubit_idx, enclosed_register);
//   }

//   /// Define a hash function for the key type.
//   /// Note: This isn't necessary because std::pair, unsigned, and Type all have
//   /// hash functions already available.
//   static llvm::hash_code hashKey(const KeyTy &key) {
//     return llvm::hash_combine(key.first, key.second);
//   }

//   /// Define a construction function for the key type.
//   /// Note: This isn't necessary because KeyTy can be directly constructed with
//   /// the given parameters.
//   static KeyTy getKey(std::int64_t _qubit_idx, std::string enc_reg) {
//     return KeyTy(_qubit_idx, enc_reg);
//   }

//   /// Define a construction method for creating a new instance of this storage.
//   static QubitTypeStorage *construct(TypeStorageAllocator &allocator,
//                                      const KeyTy &key) {
//     return new (allocator.allocate<QubitTypeStorage>())
//         QubitTypeStorage(key.first, key.second);
//   }

//   /// The parametric data held by the storage class.
//   std::int64_t qubit_idx;
//   std::string enclosed_register;
// };

// class QubitType : public Type::TypeBase<QubitType, Type, QubitTypeStorage> {
//  public:
//   /// Inherit some necessary constructors from 'TypeBase'.
//   using Base::Base;

//   /// This method is used to get an instance of the 'ComplexType'. This method
//   /// asserts that all of the construction invariants were satisfied. To
//   /// gracefully handle failed construction, getChecked should be used instead.
//   static QubitType get(mlir::MLIRContext* ctx, int64_t qbit, std::string enc_reg) {
//     // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
//     // of this type. All parameters to the storage class are passed after the
//     // context.
//     return Base::get(ctx, qbit, enc_reg);
//   }

//   /// This method is used to get an instance of the 'ComplexType', defined at
//   /// the given location. If any of the construction invariants are invalid,
//   /// errors are emitted with the provided location and a null type is returned.
//   /// Note: This method is completely optional.
//   static QubitType getChecked(std::int64_t qbit, std::string enc_reg, Location location) {
//     // Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued
//     // instance of this type. All parameters to the storage class are passed
//     // after the location.
//     return Base::getChecked(location, qbit, enc_reg);
//   }

//   /// This method is used to verify the construction invariants passed into the
//   /// 'get' and 'getChecked' methods. Note: This method is completely optional.
//   static LogicalResult verifyConstructionInvariants(Location loc,
//                                                     std::int64_t qbit, std::string enc_reg) {
//     // Our type only allows non-zero parameters.
//     if (qbit < 0)
//       return emitError(loc) << "non-zero parameter passed to 'QubitType'";
//     return success();
//   }

//   /// Return the parameter value.
//   std::int64_t getQubitIndex() {
//     // 'getImpl' returns a pointer to our internal storage instance.
//     return getImpl()->qubit_idx;
//   }

//   /// Return the integer parameter type.
//   std::string getEnclosedRegister() {
//     // 'getImpl' returns a pointer to our internal storage instance.
//     return getImpl()->enclosed_register;
//   }
// };

class QuantumDialect : public mlir::Dialect {
 public:
@@ -199,7 +114,39 @@ namespace quantum {
//===----------------------------------------------------------------------===//
// ::mlir::quantum::QallocOp declarations
//===----------------------------------------------------------------------===//
class QRTInitOpAdaptor {
public:
  QRTInitOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  QRTInitOpAdaptor(QRTInitOp&op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
  ::mlir::Value argc();
  ::mlir::Value argv();
  ::mlir::LogicalResult verify(::mlir::Location loc);

private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class QRTInitOp : public ::mlir::Op<QRTInitOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::NOperands<2>::Impl> {
public:
  using Op::Op;
  using Op::print;
  using Adaptor = QRTInitOpAdaptor;
  static ::llvm::StringRef getOperationName();
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::Operation::operand_range getODSOperands(unsigned index);
  ::mlir::Value argc();
  ::mlir::Value argv();
  ::mlir::MutableOperandRange argcMutable();
  ::mlir::MutableOperandRange argvMutable();
  std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
  ::mlir::Operation::result_range getODSResults(unsigned index);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value argc, ::mlir::Value argv);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value argc, ::mlir::Value argv);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
};
class QallocOpAdaptor {
public:
  QallocOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
@@ -239,6 +186,36 @@ public:
  ::mlir::LogicalResult verify();
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
};
class DeallocOpAdaptor {
public:
  DeallocOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  DeallocOpAdaptor(DeallocOp&op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
  ::mlir::Value qubits();
  ::mlir::LogicalResult verify(::mlir::Location loc);

private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class DeallocOp : public ::mlir::Op<DeallocOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::OneOperand> {
public:
  using Op::Op;
  using Op::print;
  using Adaptor = DeallocOpAdaptor;
  static ::llvm::StringRef getOperationName();
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::Operation::operand_range getODSOperands(unsigned index);
  ::mlir::Value qubits();
  ::mlir::MutableOperandRange qubitsMutable();
  std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
  ::mlir::Operation::result_range getODSResults(unsigned index);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value qubits);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value qubits);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
};
}  // namespace quantum
}  // namespace mlir

+43 −0
Original line number Diff line number Diff line
#pragma once
#include <string>

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"

namespace qcor {
class QuantumMLIRGenerator {
 protected:
  mlir::MLIRContext& context;
  mlir::ModuleOp m_module;
  mlir::OpBuilder builder;
  mlir::Block* main_entry_block;

 public:
  QuantumMLIRGenerator(mlir::MLIRContext& ctx) : context(ctx), builder(&ctx) {}

  // This method can be implemented by subclasses to
  // introduce any initialization steps required for constructing
  // mlir using the quantum dialect. This may also be used for
  // introducing any initialization operations before
  // generation of the rest of the mlir code.
  virtual void initialize_mlirgen() = 0;

  // This method can be implemented by subclasses to map a
  // quantum code in a subclass-specific source language to
  // the internal generated MLIR ModuleOp instance
  virtual void mlirgen(const std::string& src) = 0;

  // Return the generated ModuleOp
  mlir::OwningModuleRef get_module() {
    return mlir::OwningModuleRef(mlir::OwningOpRef<mlir::ModuleOp>(m_module));
  }

  // Finalize method, override to provide any end operations
  // to the module (like a return_op).
  virtual void finalize_mlirgen() = 0;
};
}  // namespace qcor
 No newline at end of file
+2 −2
Original line number Diff line number Diff line

set(LIBRARY_NAME staq-mlir-visitor)
set(LIBRARY_NAME openqasm-mlir-generator)

file(GLOB SRC *.cpp generated/*.cpp)

@@ -10,6 +10,6 @@ target_compile_features(${LIBRARY_NAME}
target_compile_options(${LIBRARY_NAME} PUBLIC "-Wno-attributes -Wno-suggest-override")
target_include_directories(
  ${LIBRARY_NAME}
  PUBLIC . ../../dialect ${XACC_ROOT}/include/staq/include ${XACC_ROOT}/include/staq/libs)
  PUBLIC . .. ../../dialect ${XACC_ROOT}/include/staq/include ${XACC_ROOT}/include/staq/libs)

target_link_libraries(${LIBRARY_NAME} PUBLIC quantum-dialect MLIRVector MLIRStandard)
Loading