Commit 74a45c51 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

update to support gate functions

parent 5b610b82
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -27,5 +27,5 @@ get_filename_component(MLIR_INSTALL_DIR "${MLIR_DIR}/../../.." ABSOLUTE)
add_subdirectory(dialect)
add_subdirectory(parsers)
add_subdirectory(transforms)

add_subdirectory(tools)
add_subdirectory(tests)
+1 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ namespace quantum {
QuantumDialect::QuantumDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<QuantumDialect>()) {
  addOperations<InstOp, QallocOp, ReturnOp>();
  addTypes<QubitType>();
}
InstOpAdaptor::InstOpAdaptor(::mlir::ValueRange values,
                             ::mlir::DictionaryAttr attrs)
+177 −42
Original line number Diff line number Diff line
@@ -15,6 +15,93 @@ class ReturnOp;

namespace mlir {
namespace quantum {
// struct QubitTypeStorage : public TypeStorage {
//   QubitTypeStorage(std::int64_t _qubit_idx, std::string _enclosed_register)
//       : qubit_idx(_qubit_idx), enclosed_register(_enclosed_register) {}

//   /// The hash key for this storage is a pair of the integer and type params.
//   using KeyTy = std::pair<std::int64_t, std::string>;

//   /// Define the comparison function for the key type.
//   bool operator==(const KeyTy &key) const {
//     return key == KeyTy(qubit_idx, enclosed_register);
//   }

//   /// Define a hash function for the key type.
//   /// Note: This isn't necessary because std::pair, unsigned, and Type all have
//   /// hash functions already available.
//   static llvm::hash_code hashKey(const KeyTy &key) {
//     return llvm::hash_combine(key.first, key.second);
//   }

//   /// Define a construction function for the key type.
//   /// Note: This isn't necessary because KeyTy can be directly constructed with
//   /// the given parameters.
//   static KeyTy getKey(std::int64_t _qubit_idx, std::string enc_reg) {
//     return KeyTy(_qubit_idx, enc_reg);
//   }

//   /// Define a construction method for creating a new instance of this storage.
//   static QubitTypeStorage *construct(TypeStorageAllocator &allocator,
//                                      const KeyTy &key) {
//     return new (allocator.allocate<QubitTypeStorage>())
//         QubitTypeStorage(key.first, key.second);
//   }

//   /// The parametric data held by the storage class.
//   std::int64_t qubit_idx;
//   std::string enclosed_register;
// };

// class QubitType : public Type::TypeBase<QubitType, Type, QubitTypeStorage> {
//  public:
//   /// Inherit some necessary constructors from 'TypeBase'.
//   using Base::Base;

//   /// This method is used to get an instance of the 'ComplexType'. This method
//   /// asserts that all of the construction invariants were satisfied. To
//   /// gracefully handle failed construction, getChecked should be used instead.
//   static QubitType get(mlir::MLIRContext* ctx, int64_t qbit, std::string enc_reg) {
//     // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
//     // of this type. All parameters to the storage class are passed after the
//     // context.
//     return Base::get(ctx, qbit, enc_reg);
//   }

//   /// This method is used to get an instance of the 'ComplexType', defined at
//   /// the given location. If any of the construction invariants are invalid,
//   /// errors are emitted with the provided location and a null type is returned.
//   /// Note: This method is completely optional.
//   static QubitType getChecked(std::int64_t qbit, std::string enc_reg, Location location) {
//     // Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued
//     // instance of this type. All parameters to the storage class are passed
//     // after the location.
//     return Base::getChecked(location, qbit, enc_reg);
//   }

//   /// This method is used to verify the construction invariants passed into the
//   /// 'get' and 'getChecked' methods. Note: This method is completely optional.
//   static LogicalResult verifyConstructionInvariants(Location loc,
//                                                     std::int64_t qbit, std::string enc_reg) {
//     // Our type only allows non-zero parameters.
//     if (qbit < 0)
//       return emitError(loc) << "non-zero parameter passed to 'QubitType'";
//     return success();
//   }

//   /// Return the parameter value.
//   std::int64_t getQubitIndex() {
//     // 'getImpl' returns a pointer to our internal storage instance.
//     return getImpl()->qubit_idx;
//   }

//   /// Return the integer parameter type.
//   std::string getEnclosedRegister() {
//     // 'getImpl' returns a pointer to our internal storage instance.
//     return getImpl()->enclosed_register;
//   }
// };

class QuantumDialect : public mlir::Dialect {
 public:
  explicit QuantumDialect(mlir::MLIRContext *ctx);
@@ -22,7 +109,8 @@ class QuantumDialect : public mlir::Dialect {
};
class InstOpAdaptor {
 public:
  InstOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  InstOpAdaptor(::mlir::ValueRange values,
                ::mlir::DictionaryAttr attrs = nullptr);
  InstOpAdaptor(InstOp &op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
@@ -35,7 +123,10 @@ private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class InstOp : public ::mlir::Op<InstOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::VariadicOperands> {
class InstOp
    : public ::mlir::Op<
          InstOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult,
          ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::VariadicOperands> {
 public:
  using Op::Op;
  using Op::print;
@@ -53,11 +144,27 @@ public:
  ::llvm::Optional<::mlir::DenseElementsAttr> params();
  void nameAttr(::mlir::StringAttr attr);
  void paramsAttr(::mlir::DenseElementsAttr attr);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::StringAttr name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::mlir::StringAttr name,
                    ::mlir::ValueRange qubits,
                    /*optional*/ ::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::StringAttr name,
                    ::mlir::ValueRange qubits,
                    /*optional*/ ::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::llvm::StringRef name,
                    ::mlir::ValueRange qubits,
                    /*optional*/ ::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::llvm::StringRef name,
                    ::mlir::ValueRange qubits,
                    /*optional*/ ::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
                    ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
};
}  // namespace quantum
@@ -71,7 +178,8 @@ namespace quantum {

class QallocOpAdaptor {
 public:
  QallocOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  QallocOpAdaptor(::mlir::ValueRange values,
                  ::mlir::DictionaryAttr attrs = nullptr);
  QallocOpAdaptor(QallocOp &op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
@@ -83,7 +191,10 @@ private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class QallocOp : public ::mlir::Op<QallocOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::ZeroOperands> {
class QallocOp
    : public ::mlir::Op<
          QallocOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::OneResult,
          ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::ZeroOperands> {
 public:
  using Op::Op;
  using Op::print;
@@ -100,11 +211,23 @@ public:
  ::llvm::StringRef name();
  void sizeAttr(::mlir::IntegerAttr attr);
  void nameAttr(::mlir::StringAttr attr);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type qubits, ::mlir::IntegerAttr size, ::mlir::StringAttr name);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size, ::mlir::StringAttr name);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type qubits, ::mlir::IntegerAttr size, ::llvm::StringRef name);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size, ::llvm::StringRef name);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::mlir::Type qubits,
                    ::mlir::IntegerAttr size, ::mlir::StringAttr name);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size,
                    ::mlir::StringAttr name);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::mlir::Type qubits,
                    ::mlir::IntegerAttr size, ::llvm::StringRef name);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size,
                    ::llvm::StringRef name);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
                    ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
};
}  // namespace quantum
@@ -118,7 +241,8 @@ namespace quantum {

class ReturnOpAdaptor {
 public:
  ReturnOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  ReturnOpAdaptor(::mlir::ValueRange values,
                  ::mlir::DictionaryAttr attrs = nullptr);
  ReturnOpAdaptor(ReturnOp &op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
@@ -129,7 +253,13 @@ private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class ReturnOp : public ::mlir::Op<ReturnOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::VariadicOperands, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::HasParent<FuncOp>::Impl, ::mlir::OpTrait::IsTerminator> {
class ReturnOp
    : public ::mlir::Op<
          ReturnOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult,
          ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::VariadicOperands,
          ::mlir::MemoryEffectOpInterface::Trait,
          ::mlir::OpTrait::HasParent<FuncOp>::Impl,
          ::mlir::OpTrait::IsTerminator> {
 public:
  using Op::Op;
  using Op::print;
@@ -141,16 +271,21 @@ public:
  ::mlir::MutableOperandRange inputMutable();
  std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
  ::mlir::Operation::result_range getODSResults(unsigned index);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::mlir::ValueRange input);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
                    ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
                                   ::mlir::OperationState &result);
  void print(::mlir::OpAsmPrinter &p);
  void getEffects(::mlir::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
  void getEffects(::mlir::SmallVectorImpl<::mlir::SideEffects::EffectInstance<
                      ::mlir::MemoryEffects::Effect>> &effects);

  bool hasOperand() { return getNumOperands() != 0; }
  
};
}  // namespace quantum
}  // namespace mlir
@@ -176,7 +311,7 @@ def QuantumDialect : Dialect {


def InstOp : Op<QuantumDialect, "inst", []> {
   let arguments = (ins StrAttr:$name, StringElementsAttr:$qreg_names, IndexElementsAttr:$qubits, F64ElementsAttr:$params);
   let results = (outs);
   let arguments = (ins StrAttr:$name, StringElementsAttr:$qreg_names,
IndexElementsAttr:$qubits, F64ElementsAttr:$params); let results = (outs);
}
*/
 No newline at end of file
+72 −10
Original line number Diff line number Diff line
@@ -19,16 +19,72 @@ StaqToMLIR::StaqToMLIR(mlir::MLIRContext &context) : builder(&context) {
  auto func_type = builder.getFunctionType(arg_types, llvm::None);
  auto proto = mlir::FuncOp::create(builder.getUnknownLoc(), "main", func_type);
  mlir::FuncOp function(proto);
  auto &entryBlock = *function.addEntryBlock();
  main_entry_point = function.addEntryBlock();
  auto &entryBlock = *main_entry_point;
  builder.setInsertionPointToStart(&entryBlock);
  theModule.push_back(function);
  function_names.push_back("main");
}

void StaqToMLIR::addReturn() {
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());

  std::vector<llvm::StringRef> tmp(function_names.begin(), function_names.end());
  
  auto function_names_datatype =
      mlir::VectorType::get({function_names.size()}, builder.getI64Type());
  auto function_names_ref = llvm::makeArrayRef(tmp);
  auto attrs = mlir::DenseStringElementsAttr::get(function_names_datatype,
                                                  function_names_ref);

  mlir::Identifier id =
      mlir::Identifier::get("quantum.internal_functions", builder.getContext());

  theModule.setAttrs(
      llvm::makeArrayRef({mlir::NamedAttribute(std::make_pair(id, attrs))}));
}

void StaqToMLIR::visit(GateDecl &gate_function) {
  auto name = gate_function.id();
  static std::vector<std::string> builtins{
      "u3", "u2",   "u1",  "cx",  "id",  "u0",  "x",   "y",  "z",
      "h",  "s",    "sdg", "t",   "tdg", "rx",  "ry",  "rz", "cz",
      "cy", "swap", "ch",  "ccx", "crz", "cu1", "cu2", "cu3"};
  if (std::find(builtins.begin(), builtins.end(), name) == builtins.end()) {
    std::vector<mlir::Type> arg_types;

    function_names.push_back(name);

    auto n_args = gate_function.q_params().size();
    for (std::size_t i = 0; i < n_args; i++) {
      arg_types.push_back(builder.getI64Type());
    }

    auto func_type = builder.getFunctionType(arg_types, llvm::None);
    auto proto = mlir::FuncOp::create(builder.getUnknownLoc(), name, func_type);
    mlir::FuncOp function(proto);
    auto &entryBlock = *function.addEntryBlock();
    builder.setInsertionPointToStart(&entryBlock);

    auto arguments = entryBlock.getArguments();

    for (std::size_t i = 0; i < n_args; i++) {
      auto argument = arguments[i];
      auto arg_name = gate_function.q_params()[i];
      temporary_sub_kernel_args.insert({arg_name, argument});
    }
    in_sub_kernel = true;

    gate_function.foreach_stmt([this](Gate &g) { g.accept(*this); });

void StaqToMLIR::visit(GateDecl &) {}
    in_sub_kernel = false;
    temporary_sub_kernel_args.clear();
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
    theModule.push_back(function);

    builder.setInsertionPointToStart(main_entry_point);
  }
}

void StaqToMLIR::visit(RegisterDecl &d) {
  if (d.is_quantum()) {
@@ -210,14 +266,20 @@ void StaqToMLIR::visit(DeclaredGate &g) {
      // throw an error
    }

    if (!in_sub_kernel) {
      auto qubits = qubit_allocations[qreg_var_name].qubits();

      std::uint64_t qidx = g.qarg(i).offset().value();
      auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), qidx);
    mlir::Value pos = builder.create<mlir::ConstantOp>(location, integer_attr);
      mlir::Value pos =
          builder.create<mlir::ConstantOp>(location, integer_attr);
      mlir::Value qbit_value =
          builder.create<mlir::vector::ExtractElementOp>(location, qubits, pos);
      qubits_for_inst.push_back(qbit_value);
    } else {
      auto qubit_kernel_arg = temporary_sub_kernel_args[qreg_var_name];
      qubits_for_inst.push_back(qubit_kernel_arg);
    }
  }

  builder.create<mlir::quantum::InstOp>(location, str_attr,
+7 −1
Original line number Diff line number Diff line
#pragma once

// Turn off Staq warnings
#pragma GCC diagnostic ignored "-Wsuggest-override"
#pragma GCC diagnostic ignored "-Wdeprecated-copy"
#pragma GCC diagnostic ignored "-Wunused-function"
@@ -26,6 +28,10 @@ class StaqToMLIR : public staq::ast::Visitor {
  mlir::ModuleOp theModule;
  mlir::OpBuilder builder;
  std::map<std::string, mlir::quantum::QallocOp> qubit_allocations;
  bool in_sub_kernel = false;
  std::map<std::string, mlir::Value> temporary_sub_kernel_args;
  mlir::Block * main_entry_point;
  std::vector<std::string> function_names;
  
 public:
  StaqToMLIR(mlir::MLIRContext &context);
Loading