Commit 30edc38d authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

adding parameterized gates, lowering down to llvm complete

parent 08fa0e55
Loading
Loading
Loading
Loading
+133 −160
Original line number Diff line number Diff line
@@ -13,16 +13,30 @@ QuantumDialect::QuantumDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<QuantumDialect>()) {
  addOperations<InstOp, QallocOp, ReturnOp>();
}
InstOpAdaptor::InstOpAdaptor(::mlir::ValueRange values,
                             ::mlir::DictionaryAttr attrs)
    : odsOperands(values), odsAttrs(attrs) {}
InstOpAdaptor::InstOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs)  : odsOperands(values), odsAttrs(attrs) {

InstOpAdaptor::InstOpAdaptor(InstOp &op)
    : odsOperands(op->getOperands()), odsAttrs(op->getAttrDictionary()) {}
}

InstOpAdaptor::InstOpAdaptor(InstOp&op)  : odsOperands(op->getOperands()), odsAttrs(op->getAttrDictionary()) {

std::pair<unsigned, unsigned> InstOpAdaptor::getODSOperandIndexAndLength(
    unsigned index) {
  return {index, 1};
}

std::pair<unsigned, unsigned> InstOpAdaptor::getODSOperandIndexAndLength(unsigned index) {
  bool isVariadic[] = {true};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic operand corresponds to.
  // This assumes all static variadic operands have the same dynamic value count.
  int variadicSize = (odsOperands.size() - 0) / 1;
  // `index` passed in as the parameter is the static index which counts each
  // operand (variadic or not) as size 1. So here for each previous static variadic
  // operand, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static operand starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {start, size};
}

::mlir::ValueRange InstOpAdaptor::getODSOperands(unsigned index) {
@@ -31,117 +45,82 @@ std::pair<unsigned, unsigned> InstOpAdaptor::getODSOperandIndexAndLength(
           std::next(odsOperands.begin(), valueRange.first + valueRange.second)};
}

::mlir::StringAttr InstOpAdaptor::name() {
  assert(odsAttrs && "no attributes when constructing adapter");
  ::mlir::StringAttr attr = odsAttrs.get("name").cast<::mlir::StringAttr>();
  return attr;
}

::mlir::DenseElementsAttr InstOpAdaptor::qreg_names() {
  assert(odsAttrs && "no attributes when constructing adapter");
  ::mlir::DenseElementsAttr attr =
      odsAttrs.get("qreg_names").cast<::mlir::DenseElementsAttr>();
  return attr;
::mlir::ValueRange InstOpAdaptor::qubits() {
  return getODSOperands(0);
}

::mlir::DenseIntElementsAttr InstOpAdaptor::qubits() {
::mlir::StringAttr InstOpAdaptor::name() {
  assert(odsAttrs && "no attributes when constructing adapter");
  ::mlir::DenseIntElementsAttr attr =
      odsAttrs.get("qubits").cast<::mlir::DenseIntElementsAttr>();
  ::mlir::StringAttr attr = odsAttrs.get("name").cast<::mlir::StringAttr>();
  return attr;
}

::mlir::DenseElementsAttr InstOpAdaptor::params() {
  assert(odsAttrs && "no attributes when constructing adapter");
  ::mlir::DenseElementsAttr attr =
      odsAttrs.get("params").cast<::mlir::DenseElementsAttr>();
  ::mlir::DenseElementsAttr attr = odsAttrs.get("params").dyn_cast_or_null<::mlir::DenseElementsAttr>();
  return attr;
}

::mlir::LogicalResult InstOpAdaptor::verify(::mlir::Location loc) {
  {
  auto tblgen_name = odsAttrs.get("name");
    if (!tblgen_name)
      return emitError(loc,
                       "'quantum.inst' op "
                       "requires attribute 'name'");
    if (!((tblgen_name.isa<::mlir::StringAttr>())))
      return emitError(
          loc,
          "'quantum.inst' op "
          "attribute 'name' failed to satisfy constraint: string attribute");
  }
  {
    auto tblgen_qreg_names = odsAttrs.get("qreg_names");
    if (!tblgen_qreg_names)
      return emitError(loc,
                       "'quantum.inst' op "
                       "requires attribute 'qreg_names'");
    if (!((tblgen_qreg_names.isa<::mlir::DenseStringElementsAttr>())))
      return emitError(loc,
                       "'quantum.inst' op "
                       "attribute 'qreg_names' failed to satisfy constraint: "
                       "string elements attribute");
  }
  {
    auto tblgen_qubits = odsAttrs.get("qubits");
    if (!tblgen_qubits)
      return emitError(loc,
                       "'quantum.inst' op "
                       "requires attribute 'qubits'");
    if (!(((tblgen_qubits.isa<::mlir::DenseIntElementsAttr>())) &&
          ((tblgen_qubits.cast<::mlir::DenseIntElementsAttr>()
                .getType()
                .getElementType()
                .isIndex()))))
      return emitError(loc,
                       "'quantum.inst' op "
                       "attribute 'qubits' failed to satisfy constraint: index "
                       "elements attribute");
  if (!tblgen_name) return emitError(loc, "'quantum.inst' op ""requires attribute 'name'");
    if (!((tblgen_name.isa<::mlir::StringAttr>()))) return emitError(loc, "'quantum.inst' op ""attribute 'name' failed to satisfy constraint: string attribute");
  }
  {
  auto tblgen_params = odsAttrs.get("params");
    if (!tblgen_params)
      return emitError(loc,
                       "'quantum.inst' op "
                       "requires attribute 'params'");
    if (!((tblgen_params.isa<::mlir::DenseFPElementsAttr>() &&
           tblgen_params.cast<::mlir::DenseElementsAttr>()
               .getType()
               .getElementType()
               .isF64())))
      return emitError(loc,
                       "'quantum.inst' op "
                       "attribute 'params' failed to satisfy constraint: "
                       "64-bit float elements attribute");
  if (tblgen_params) {
    if (!((tblgen_params.isa<::mlir::DenseFPElementsAttr>() &&tblgen_params.cast<::mlir::DenseElementsAttr>().getType().getElementType().isF64()))) return emitError(loc, "'quantum.inst' op ""attribute 'params' failed to satisfy constraint: 64-bit float elements attribute");
  }
  }
  return ::mlir::success();
}

::llvm::StringRef InstOp::getOperationName() { return "quantum.inst"; }
::llvm::StringRef InstOp::getOperationName() {
  return "quantum.inst";
}

std::pair<unsigned, unsigned> InstOp::getODSOperandIndexAndLength(
    unsigned index) {
  return {index, 1};
std::pair<unsigned, unsigned> InstOp::getODSOperandIndexAndLength(unsigned index) {
  bool isVariadic[] = {true};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic operand corresponds to.
  // This assumes all static variadic operands have the same dynamic value count.
  int variadicSize = (getOperation()->getNumOperands() - 0) / 1;
  // `index` passed in as the parameter is the static index which counts each
  // operand (variadic or not) as size 1. So here for each previous static variadic
  // operand, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static operand starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {start, size};
}

::mlir::Operation::operand_range InstOp::getODSOperands(unsigned index) {
  auto valueRange = getODSOperandIndexAndLength(index);
  return {std::next(getOperation()->operand_begin(), valueRange.first),
          std::next(getOperation()->operand_begin(),
                    valueRange.first + valueRange.second)};
           std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)};
}

::mlir::Operation::operand_range InstOp::qubits() {
  return getODSOperands(0);
}

std::pair<unsigned, unsigned> InstOp::getODSResultIndexAndLength(
    unsigned index) {
::mlir::MutableOperandRange InstOp::qubitsMutable() {
  auto range = getODSOperandIndexAndLength(0);
  return ::mlir::MutableOperandRange(getOperation(), range.first, range.second);
}

std::pair<unsigned, unsigned> InstOp::getODSResultIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::Operation::result_range InstOp::getODSResults(unsigned index) {
  auto valueRange = getODSResultIndexAndLength(index);
  return {std::next(getOperation()->result_begin(), valueRange.first),
          std::next(getOperation()->result_begin(),
                    valueRange.first + valueRange.second)};
           std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)};
}

::mlir::StringAttr InstOp::nameAttr() {
@@ -153,103 +132,60 @@ std::pair<unsigned, unsigned> InstOp::getODSResultIndexAndLength(
  return attr.getValue();
}

::mlir::DenseElementsAttr InstOp::qreg_namesAttr() {
  return this->getAttr("qreg_names").cast<::mlir::DenseElementsAttr>();
}

::mlir::DenseElementsAttr InstOp::qreg_names() {
  auto attr = qreg_namesAttr();
  return attr;
}

::mlir::DenseIntElementsAttr InstOp::qubitsAttr() {
  return this->getAttr("qubits").cast<::mlir::DenseIntElementsAttr>();
}

::mlir::DenseIntElementsAttr InstOp::qubits() {
  auto attr = qubitsAttr();
  return attr;
}

::mlir::DenseElementsAttr InstOp::paramsAttr() {
  return this->getAttr("params").cast<::mlir::DenseElementsAttr>();
  return this->getAttr("params").dyn_cast_or_null<::mlir::DenseElementsAttr>();
}

::mlir::DenseElementsAttr InstOp::params() {
::llvm::Optional< ::mlir::DenseElementsAttr > InstOp::params() {
  auto attr = paramsAttr();
  return attr;
  return attr ? ::llvm::Optional< ::mlir::DenseElementsAttr >(attr) : (::llvm::None);
}

void InstOp::nameAttr(::mlir::StringAttr attr) {
  (*this)->setAttr("name", attr);
}

void InstOp::qreg_namesAttr(::mlir::DenseElementsAttr attr) {
  (*this)->setAttr("qreg_names", attr);
}

void InstOp::qubitsAttr(::mlir::DenseIntElementsAttr attr) {
  (*this)->setAttr("qubits", attr);
}

void InstOp::paramsAttr(::mlir::DenseElementsAttr attr) {
  (*this)->setAttr("params", attr);
}

void InstOp::build(::mlir::OpBuilder &odsBuilder,
                   ::mlir::OperationState &odsState, ::mlir::StringAttr name,
                   ::mlir::DenseElementsAttr qreg_names,
                   ::mlir::DenseIntElementsAttr qubits,
                   ::mlir::DenseElementsAttr params) {
void InstOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params) {
  odsState.addOperands(qubits);
  odsState.addAttribute("name", name);
  odsState.addAttribute("qreg_names", qreg_names);
  odsState.addAttribute("qubits", qubits);
  if (params) {
  odsState.addAttribute("params", params);
  }
}

void InstOp::build(::mlir::OpBuilder &odsBuilder,
                   ::mlir::OperationState &odsState,
                   ::mlir::TypeRange resultTypes, ::mlir::StringAttr name,
                   ::mlir::DenseElementsAttr qreg_names,
                   ::mlir::DenseIntElementsAttr qubits,
                   ::mlir::DenseElementsAttr params) {
void InstOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::StringAttr name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params) {
  odsState.addOperands(qubits);
  odsState.addAttribute("name", name);
  odsState.addAttribute("qreg_names", qreg_names);
  odsState.addAttribute("qubits", qubits);
  if (params) {
  odsState.addAttribute("params", params);
  }
  assert(resultTypes.size() == 0u && "mismatched number of results");
  odsState.addTypes(resultTypes);
}

void InstOp::build(::mlir::OpBuilder &odsBuilder,
                   ::mlir::OperationState &odsState, ::llvm::StringRef name,
                   ::mlir::DenseElementsAttr qreg_names,
                   ::mlir::DenseIntElementsAttr qubits,
                   ::mlir::DenseElementsAttr params) {
void InstOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params) {
  odsState.addOperands(qubits);
  odsState.addAttribute("name", odsBuilder.getStringAttr(name));
  odsState.addAttribute("qreg_names", qreg_names);
  odsState.addAttribute("qubits", qubits);
  if (params) {
  odsState.addAttribute("params", params);
  }
}

void InstOp::build(::mlir::OpBuilder &odsBuilder,
                   ::mlir::OperationState &odsState,
                   ::mlir::TypeRange resultTypes, ::llvm::StringRef name,
                   ::mlir::DenseElementsAttr qreg_names,
                   ::mlir::DenseIntElementsAttr qubits,
                   ::mlir::DenseElementsAttr params) {
void InstOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params) {
  odsState.addOperands(qubits);
  odsState.addAttribute("name", odsBuilder.getStringAttr(name));
  odsState.addAttribute("qreg_names", qreg_names);
  odsState.addAttribute("qubits", qubits);
  if (params) {
  odsState.addAttribute("params", params);
  }
  assert(resultTypes.size() == 0u && "mismatched number of results");
  odsState.addTypes(resultTypes);
}

void InstOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState,
                   ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
                   ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
  assert(operands.size() == 0u && "mismatched number of parameters");
void InstOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
  odsState.addOperands(operands);
  odsState.addAttributes(attributes);
  assert(resultTypes.size() == 0u && "mismatched number of return types");
@@ -257,19 +193,32 @@ void InstOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState,
}

::mlir::LogicalResult InstOp::verify() {
  if (failed(InstOpAdaptor(*this).verify(this->getLoc())))
    return ::mlir::failure();
  if (failed(InstOpAdaptor(*this).verify(this->getLoc()))) return ::mlir::failure();
  {
    unsigned index = 0;
    (void)index;
    unsigned index = 0; (void)index;
    auto valueGroup0 = getODSOperands(0);
    for (::mlir::Value v : valueGroup0) {
      (void)v;
      if (!((v.getType().isSignlessInteger(64)))) {
        return emitOpError("operand #") << index << " must be 64-bit signless integer, but got " << v.getType();
      }
      ++index;
    }
  }
  {
    unsigned index = 0;
    (void)index;
    unsigned index = 0; (void)index;
  }
  return ::mlir::success();
}

} // namespace quantum
} // namespace mlir
namespace mlir {
namespace quantum {

//===----------------------------------------------------------------------===//
// ::mlir::quantum::QallocOp definitions
//===----------------------------------------------------------------------===//

QallocOpAdaptor::QallocOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs)  : odsOperands(values), odsAttrs(attrs) {

@@ -339,6 +288,10 @@ std::pair<unsigned, unsigned> QallocOp::getODSResultIndexAndLength(unsigned inde
           std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)};
}

::mlir::Value QallocOp::qubits() {
  return *getODSResults(0).begin();
}

::mlir::IntegerAttr QallocOp::sizeAttr() {
  return this->getAttr("size").cast<::mlir::IntegerAttr>();
}
@@ -365,27 +318,29 @@ void QallocOp::nameAttr(::mlir::StringAttr attr) {
  (*this)->setAttr("name", attr);
}

void QallocOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::IntegerAttr size, ::mlir::StringAttr name) {
void QallocOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type qubits, ::mlir::IntegerAttr size, ::mlir::StringAttr name) {
  odsState.addAttribute("size", size);
  odsState.addAttribute("name", name);
  odsState.addTypes(qubits);
}

void QallocOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size, ::mlir::StringAttr name) {
  odsState.addAttribute("size", size);
  odsState.addAttribute("name", name);
  assert(resultTypes.size() == 0u && "mismatched number of results");
  assert(resultTypes.size() == 1u && "mismatched number of results");
  odsState.addTypes(resultTypes);
}

void QallocOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::IntegerAttr size, ::llvm::StringRef name) {
void QallocOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type qubits, ::mlir::IntegerAttr size, ::llvm::StringRef name) {
  odsState.addAttribute("size", size);
  odsState.addAttribute("name", odsBuilder.getStringAttr(name));
  odsState.addTypes(qubits);
}

void QallocOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size, ::llvm::StringRef name) {
  odsState.addAttribute("size", size);
  odsState.addAttribute("name", odsBuilder.getStringAttr(name));
  assert(resultTypes.size() == 0u && "mismatched number of results");
  assert(resultTypes.size() == 1u && "mismatched number of results");
  odsState.addTypes(resultTypes);
}

@@ -393,7 +348,7 @@ void QallocOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::ml
  assert(operands.size() == 0u && "mismatched number of parameters");
  odsState.addOperands(operands);
  odsState.addAttributes(attributes);
  assert(resultTypes.size() == 0u && "mismatched number of return types");
  assert(resultTypes.size() == 1u && "mismatched number of return types");
  odsState.addTypes(resultTypes);
}

@@ -404,10 +359,27 @@ void QallocOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::ml
  }
  {
    unsigned index = 0; (void)index;
    auto valueGroup0 = getODSResults(0);
    for (::mlir::Value v : valueGroup0) {
      (void)v;
      if (!(((v.getType().isa<::mlir::VectorType>())) && ((v.getType().cast<::mlir::ShapedType>().getElementType().isSignlessInteger(64))))) {
        return emitOpError("result #") << index << " must be vector of 64-bit signless integer values, but got " << v.getType();
      }
      ++index;
    }
  }
  return ::mlir::success();
}

} // namespace quantum
} // namespace mlir
namespace mlir {
namespace quantum {

//===----------------------------------------------------------------------===//
// ::mlir::quantum::ReturnOp definitions
//===----------------------------------------------------------------------===//

ReturnOpAdaptor::ReturnOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs)  : odsOperands(values), odsAttrs(attrs) {

}
@@ -565,6 +537,7 @@ void ReturnOp::print(::mlir::OpAsmPrinter &p) {
}

void ReturnOp::getEffects(::mlir::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) {

}

}  // namespace quantum
+36 −48

File changed.

Preview size limit exceeded, changes collapsed.

+1 −1
Original line number Diff line number Diff line
@@ -12,4 +12,4 @@ target_include_directories(
  ${LIBRARY_NAME}
  PUBLIC . ../../dialect ${XACC_ROOT}/include/staq/include ${XACC_ROOT}/include/staq/libs)

target_link_libraries(${LIBRARY_NAME} PUBLIC quantum-dialect)
target_link_libraries(${LIBRARY_NAME} PUBLIC quantum-dialect MLIRVector MLIRStandard)
+114 −90

File changed.

Preview size limit exceeded, changes collapsed.

+2 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "parser/parser.hpp"
#include "quantum_dialect.hpp"

using namespace staq::ast;

@@ -18,6 +19,7 @@ class StaqToMLIR : public staq::ast::Visitor {
 protected:
  mlir::ModuleOp theModule;
  mlir::OpBuilder builder;
  std::map<std::string, mlir::quantum::QallocOp> qubit_allocations;

 public:
  StaqToMLIR(mlir::MLIRContext &context);
Loading