Commit 7f9f4bc2 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

updates to add opaque qubit, result, and array types, cleaned up qir rt api signatures

parent 74a45c51
Loading
Loading
Loading
Loading
+276 −353

File changed.

Preview size limit exceeded, changes collapsed.

+68 −117
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ namespace quantum {
class InstOp;
class QallocOp;
class QInstOp;
class ReturnOp;
class ExtractQubitOp;
}  // namespace quantum
}  // namespace mlir

@@ -107,64 +107,38 @@ class QuantumDialect : public mlir::Dialect {
  explicit QuantumDialect(mlir::MLIRContext *ctx);
  static llvm::StringRef getDialectNamespace() { return "quantum"; }
};
class InstOpAdaptor {
class ExtractQubitOpAdaptor {
public:
  InstOpAdaptor(::mlir::ValueRange values,
                ::mlir::DictionaryAttr attrs = nullptr);
  InstOpAdaptor(InstOp &op);
  ExtractQubitOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  ExtractQubitOpAdaptor(ExtractQubitOp& op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
  ::mlir::ValueRange qubits();
  ::mlir::StringAttr name();
  ::mlir::DenseElementsAttr params();
  ::mlir::Value qreg();
  ::mlir::Value idx();
  ::mlir::LogicalResult verify(::mlir::Location loc);

private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class InstOp
    : public ::mlir::Op<
          InstOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult,
          ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::VariadicOperands> {
class ExtractQubitOp : public ::mlir::Op<ExtractQubitOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::NOperands<2>::Impl> {
public:
  using Op::Op;
  using Op::print;
  using Adaptor = InstOpAdaptor;
  using Adaptor = ExtractQubitOpAdaptor;
  static ::llvm::StringRef getOperationName();
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::Operation::operand_range getODSOperands(unsigned index);
  ::mlir::Operation::operand_range qubits();
  ::mlir::MutableOperandRange qubitsMutable();
  ::mlir::Value qreg();
  ::mlir::Value idx();
  ::mlir::MutableOperandRange qregMutable();
  ::mlir::MutableOperandRange idxMutable();
  std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
  ::mlir::Operation::result_range getODSResults(unsigned index);
  ::mlir::StringAttr nameAttr();
  ::llvm::StringRef name();
  ::mlir::DenseElementsAttr paramsAttr();
  ::llvm::Optional<::mlir::DenseElementsAttr> params();
  void nameAttr(::mlir::StringAttr attr);
  void paramsAttr(::mlir::DenseElementsAttr attr);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::mlir::StringAttr name,
                    ::mlir::ValueRange qubits,
                    /*optional*/ ::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::StringAttr name,
                    ::mlir::ValueRange qubits,
                    /*optional*/ ::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::llvm::StringRef name,
                    ::mlir::ValueRange qubits,
                    /*optional*/ ::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::llvm::StringRef name,
                    ::mlir::ValueRange qubits,
                    /*optional*/ ::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
                    ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::Value qbit();
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type qbit, ::mlir::Value qreg, ::mlir::Value idx);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value qreg, ::mlir::Value idx);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
};
} // namespace quantum
@@ -173,61 +147,48 @@ namespace mlir {
namespace quantum {

//===----------------------------------------------------------------------===//
// ::mlir::quantum::QallocOp declarations
// ::mlir::quantum::InstOp declarations
//===----------------------------------------------------------------------===//

class QallocOpAdaptor {
class InstOpAdaptor {
public:
  QallocOpAdaptor(::mlir::ValueRange values,
                  ::mlir::DictionaryAttr attrs = nullptr);
  QallocOpAdaptor(QallocOp &op);
  InstOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  InstOpAdaptor(InstOp&op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
  ::mlir::IntegerAttr size();
  ::mlir::ValueRange qubits();
  ::mlir::StringAttr name();
  ::mlir::DenseElementsAttr params();
  ::mlir::LogicalResult verify(::mlir::Location loc);

private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class QallocOp
    : public ::mlir::Op<
          QallocOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::OneResult,
          ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::ZeroOperands> {
class InstOp : public ::mlir::Op<InstOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::VariadicResults, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::VariadicOperands> {
public:
  using Op::Op;
  using Op::print;
  using Adaptor = QallocOpAdaptor;
  using Adaptor = InstOpAdaptor;
  static ::llvm::StringRef getOperationName();
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::Operation::operand_range getODSOperands(unsigned index);
  ::mlir::Operation::operand_range qubits();
  ::mlir::MutableOperandRange qubitsMutable();
  std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
  ::mlir::Operation::result_range getODSResults(unsigned index);
  ::mlir::Value qubits();
  ::mlir::IntegerAttr sizeAttr();
  ::llvm::APInt size();
  ::mlir::Value bit();
  ::mlir::StringAttr nameAttr();
  ::llvm::StringRef name();
  void sizeAttr(::mlir::IntegerAttr attr);
  ::mlir::DenseElementsAttr paramsAttr();
  ::llvm::Optional< ::mlir::DenseElementsAttr > params();
  void nameAttr(::mlir::StringAttr attr);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::mlir::Type qubits,
                    ::mlir::IntegerAttr size, ::mlir::StringAttr name);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size,
                    ::mlir::StringAttr name);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::mlir::Type qubits,
                    ::mlir::IntegerAttr size, ::llvm::StringRef name);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size,
                    ::llvm::StringRef name);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
                    ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  void paramsAttr(::mlir::DenseElementsAttr attr);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type bit, ::mlir::StringAttr name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::StringAttr name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type bit, ::llvm::StringRef name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef name, ::mlir::ValueRange qubits, /*optional*/::mlir::DenseElementsAttr params);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
};
} // namespace quantum
@@ -236,56 +197,46 @@ namespace mlir {
namespace quantum {

//===----------------------------------------------------------------------===//
// ::mlir::quantum::ReturnOp declarations
// ::mlir::quantum::QallocOp declarations
//===----------------------------------------------------------------------===//

class ReturnOpAdaptor {
class QallocOpAdaptor {
public:
  ReturnOpAdaptor(::mlir::ValueRange values,
                  ::mlir::DictionaryAttr attrs = nullptr);
  ReturnOpAdaptor(ReturnOp &op);
  QallocOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  QallocOpAdaptor(QallocOp&op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
  ::mlir::ValueRange input();
  ::mlir::IntegerAttr size();
  ::mlir::StringAttr name();
  ::mlir::LogicalResult verify(::mlir::Location loc);

private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class ReturnOp
    : public ::mlir::Op<
          ReturnOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult,
          ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::VariadicOperands,
          ::mlir::MemoryEffectOpInterface::Trait,
          ::mlir::OpTrait::HasParent<FuncOp>::Impl,
          ::mlir::OpTrait::IsTerminator> {
class QallocOp : public ::mlir::Op<QallocOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::ZeroOperands> {
public:
  using Op::Op;
  using Op::print;
  using Adaptor = ReturnOpAdaptor;
  using Adaptor = QallocOpAdaptor;
  static ::llvm::StringRef getOperationName();
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::Operation::operand_range getODSOperands(unsigned index);
  ::mlir::Operation::operand_range input();
  ::mlir::MutableOperandRange inputMutable();
  std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
  ::mlir::Operation::result_range getODSResults(unsigned index);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState);
  static void build(::mlir::OpBuilder &odsBuilder,
                    ::mlir::OperationState &odsState, ::mlir::ValueRange input);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState,
                    ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
                    ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::Value qubits();
  ::mlir::IntegerAttr sizeAttr();
  ::llvm::APInt size();
  ::mlir::StringAttr nameAttr();
  ::llvm::StringRef name();
  void sizeAttr(::mlir::IntegerAttr attr);
  void nameAttr(::mlir::StringAttr attr);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type qubits, ::mlir::IntegerAttr size, ::mlir::StringAttr name);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size, ::mlir::StringAttr name);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type qubits, ::mlir::IntegerAttr size, ::llvm::StringRef name);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::IntegerAttr size, ::llvm::StringRef name);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
                                   ::mlir::OperationState &result);
  void print(::mlir::OpAsmPrinter &p);
  void getEffects(::mlir::SmallVectorImpl<::mlir::SideEffects::EffectInstance<
                      ::mlir::MemoryEffects::Effect>> &effects);

  bool hasOperand() { return getNumOperands() != 0; }
};
}  // namespace quantum
}  // namespace mlir
+134 −54
Original line number Diff line number Diff line
@@ -24,12 +24,20 @@ StaqToMLIR::StaqToMLIR(mlir::MLIRContext &context) : builder(&context) {
  builder.setInsertionPointToStart(&entryBlock);
  theModule.push_back(function);
  function_names.push_back("main");

  llvm::StringRef qubit_type_name("Qubit"), array_type_name("Array"),
      result_type_name("Result");
  mlir::Identifier dialect = mlir::Identifier::get("quantum", &context);
  qubit_type = mlir::OpaqueType::get(dialect, qubit_type_name, &context);
  array_type = mlir::OpaqueType::get(dialect, array_type_name, &context);
  result_type = mlir::OpaqueType::get(dialect, result_type_name, &context);
}

void StaqToMLIR::addReturn() {
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());

  std::vector<llvm::StringRef> tmp(function_names.begin(), function_names.end());
  std::vector<llvm::StringRef> tmp(function_names.begin(),
                                   function_names.end());

  auto function_names_datatype =
      mlir::VectorType::get({function_names.size()}, builder.getI64Type());
@@ -57,7 +65,7 @@ void StaqToMLIR::visit(GateDecl &gate_function) {

    auto n_args = gate_function.q_params().size();
    for (std::size_t i = 0; i < n_args; i++) {
      arg_types.push_back(builder.getI64Type());
      arg_types.push_back(qubit_type);
    }

    auto func_type = builder.getFunctionType(arg_types, llvm::None);
@@ -103,11 +111,8 @@ void StaqToMLIR::visit(RegisterDecl &d) {
    auto integer_attr = mlir::IntegerAttr::get(integer_type, size);

    auto str_attr = builder.getStringAttr(name);

    auto returntype = mlir::VectorType::get({size}, integer_type);

    auto allocation = builder.create<mlir::quantum::QallocOp>(
        location, returntype, integer_attr, str_attr);
        location, array_type, integer_attr, str_attr);
    qubit_allocations.insert({name, allocation});
  }
}
@@ -132,16 +137,24 @@ void StaqToMLIR::visit(MeasureStmt &m) {
    // throw an error
  }

  std::uint64_t qidx = m.q_arg().offset().value();
  auto qbit_key = std::make_pair(qreg_var_name, qidx);
  mlir::Value qbit_value;
  if (extracted_qubits.count(qbit_key)) {
    qbit_value = extracted_qubits[qbit_key];
  } else {
    auto qubits = qubit_allocations[qreg_var_name].qubits();

  std::uint64_t qidx = m.q_arg().offset().value();
    auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), qidx);
  mlir::Value pos2 = builder.create<mlir::ConstantOp>(location, integer_attr);
  mlir::Value qbit_value =
      builder.create<mlir::vector::ExtractElementOp>(location, qubits, pos2);
    mlir::Value pos = builder.create<mlir::ConstantOp>(location, integer_attr);
    qbit_value = builder.create<mlir::quantum::ExtractQubitOp>(
        location, qubit_type, qubits, pos);

    extracted_qubits.insert({std::make_pair(qreg_var_name, qidx), qbit_value});
  }
  qubits_for_inst.push_back(qbit_value);

  builder.create<mlir::quantum::InstOp>(location, str_attr,
  builder.create<mlir::quantum::InstOp>(location, result_type, str_attr,
                                        llvm::makeArrayRef(qubits_for_inst),
                                        params_dataAttribute);
}
@@ -170,19 +183,28 @@ void StaqToMLIR::visit(UGate &u) {
    // throw an error
  }

  std::uint64_t qidx = u.arg().offset().value();
  auto qbit_key = std::make_pair(qreg_var_name, qidx);
  mlir::Value qbit_value;
  if (extracted_qubits.count(qbit_key)) {
    qbit_value = extracted_qubits[qbit_key];
  } else {
    auto qubits = qubit_allocations[qreg_var_name].qubits();

  std::uint64_t qidx = u.arg().offset().value();
    auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), qidx);
  mlir::Value pos2 = builder.create<mlir::ConstantOp>(location, integer_attr);
  mlir::Value ctrl_qbit_value =
      builder.create<mlir::vector::ExtractElementOp>(location, qubits, pos2);
  qubits_for_inst.push_back(ctrl_qbit_value);
    mlir::Value pos = builder.create<mlir::ConstantOp>(location, integer_attr);
    qbit_value = builder.create<mlir::quantum::ExtractQubitOp>(
        location, qubit_type, qubits, pos);

  builder.create<mlir::quantum::InstOp>(location, str_attr,
                                        llvm::makeArrayRef(qubits_for_inst),
                                        params_dataAttribute);
    extracted_qubits.insert({std::make_pair(qreg_var_name, qidx), qbit_value});
  }
  qubits_for_inst.push_back(qbit_value);

  builder.create<mlir::quantum::InstOp>(
      location, mlir::NoneType::get(builder.getContext()), str_attr,
      llvm::makeArrayRef(qubits_for_inst), params_dataAttribute);
}

void StaqToMLIR::visit(CNOTGate &g) {
  auto pos = g.pos();
  auto line = pos.get_linenum();
@@ -200,37 +222,86 @@ void StaqToMLIR::visit(CNOTGate &g) {
  // ctrl qbits
  std::vector<mlir::Value> qubits_for_inst;
  auto qreg_ctrl_var_name = g.ctrl().var();
  auto qreg_tgt_var_name = g.tgt().var();

  if (!qubit_allocations.count(qreg_ctrl_var_name)) {
    // throw an error
  }

  if (!qubit_allocations.count(qreg_tgt_var_name)) {
    // throw an error
  }

  // Get CTRL Qubit MLIR Value
  std::uint64_t ctrl_idx = g.ctrl().offset().value();
  auto ctrlqkey = std::make_pair(qreg_ctrl_var_name, ctrl_idx);

  mlir::Value ctl_qbit_value;
  if (extracted_qubits.count(ctrlqkey)) {
    ctl_qbit_value = extracted_qubits[ctrlqkey];
  } else {
    auto qubits = qubit_allocations[qreg_ctrl_var_name].qubits();

  std::uint64_t qidx = g.ctrl().offset().value();
  auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), qidx);
  mlir::Value pos2 = builder.create<mlir::ConstantOp>(location, integer_attr);
  mlir::Value ctrl_qbit_value =
      builder.create<mlir::vector::ExtractElementOp>(location, qubits, pos2);
  qubits_for_inst.push_back(ctrl_qbit_value);
    auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), ctrl_idx);
    mlir::Value pos = builder.create<mlir::ConstantOp>(location, integer_attr);
    ctl_qbit_value = builder.create<mlir::quantum::ExtractQubitOp>(
        location, qubit_type, qubits, pos);

  // tgt qubit
  auto qreg_tgt_var_name = g.tgt().var();
  if (!qubit_allocations.count(qreg_tgt_var_name)) {
    // throw an error
    extracted_qubits.insert(
        {std::make_pair(qreg_ctrl_var_name, ctrl_idx), ctl_qbit_value});
  }
  qubits_for_inst.push_back(ctl_qbit_value);

  // Get Target Qubit MLIR Value
  std::uint64_t tgt_idx = g.tgt().offset().value();
  auto tgtqkey = std::make_pair(qreg_tgt_var_name, tgt_idx);

  mlir::Value tgt_qbit_value;
  if (extracted_qubits.count(tgtqkey)) {
    tgt_qbit_value = extracted_qubits[tgtqkey];
  } else {
    auto qubits = qubit_allocations[qreg_tgt_var_name].qubits();

  auto tgt_qubits = qubit_allocations[qreg_tgt_var_name].qubits();
    auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), tgt_idx);
    mlir::Value pos = builder.create<mlir::ConstantOp>(location, integer_attr);
    tgt_qbit_value = builder.create<mlir::quantum::ExtractQubitOp>(
        location, qubit_type, qubits, pos);

  std::uint64_t qidxt = g.tgt().offset().value();
  auto integer_attrt = mlir::IntegerAttr::get(builder.getI64Type(), qidxt);
  mlir::Value post = builder.create<mlir::ConstantOp>(location, integer_attrt);
  mlir::Value tgt_qbit_value = builder.create<mlir::vector::ExtractElementOp>(
      location, tgt_qubits, post);
    extracted_qubits.insert(
        {std::make_pair(qreg_tgt_var_name, tgt_idx), tgt_qbit_value});
  }
  qubits_for_inst.push_back(tgt_qbit_value);

  builder.create<mlir::quantum::InstOp>(location, str_attr,
                                        llvm::makeArrayRef(qubits_for_inst),
                                        params_dataAttribute);
  builder.create<mlir::quantum::InstOp>(
      location, mlir::NoneType::get(builder.getContext()), str_attr,
      llvm::makeArrayRef(qubits_for_inst), params_dataAttribute);

  // std::uint64_t qidx = g.ctrl().offset().value();
  // auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), qidx);
  // mlir::Value pos2 = builder.create<mlir::ConstantOp>(location,
  // integer_attr); mlir::Value ctrl_qbit_value =
  //     builder.create<mlir::vector::ExtractElementOp>(location, qubits, pos2);
  // qubits_for_inst.push_back(ctrl_qbit_value);

  // // tgt qubit
  // auto qreg_tgt_var_name = g.tgt().var();
  // if (!qubit_allocations.count(qreg_tgt_var_name)) {
  //   // throw an error
  // }

  // auto tgt_qubits = qubit_allocations[qreg_tgt_var_name].qubits();

  // std::uint64_t qidxt = g.tgt().offset().value();
  // auto integer_attrt = mlir::IntegerAttr::get(builder.getI64Type(), qidxt);
  // mlir::Value post = builder.create<mlir::ConstantOp>(location,
  // integer_attrt); mlir::Value tgt_qbit_value =
  // builder.create<mlir::vector::ExtractElementOp>(
  //     location, tgt_qubits, post);
  // qubits_for_inst.push_back(tgt_qbit_value);

  //   builder.create<mlir::quantum::InstOp>(location, str_attr,
  //                                         llvm::makeArrayRef(qubits_for_inst),
  //                                         params_dataAttribute);
}
//   void visit(BarrierGate&) = 0;
void StaqToMLIR::visit(DeclaredGate &g) {
@@ -267,14 +338,23 @@ void StaqToMLIR::visit(DeclaredGate &g) {
    }

    if (!in_sub_kernel) {
      std::uint64_t qidx = g.qarg(i).offset().value();
      auto qbit_key = std::make_pair(qreg_var_name, qidx);
      mlir::Value qbit_value;
      if (extracted_qubits.count(qbit_key)) {
        qbit_value = extracted_qubits[qbit_key];
      } else {
        auto qubits = qubit_allocations[qreg_var_name].qubits();

      std::uint64_t qidx = g.qarg(i).offset().value();
        auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), qidx);
        mlir::Value pos =
            builder.create<mlir::ConstantOp>(location, integer_attr);
      mlir::Value qbit_value =
          builder.create<mlir::vector::ExtractElementOp>(location, qubits, pos);
        qbit_value = builder.create<mlir::quantum::ExtractQubitOp>(
            location, qubit_type, qubits, pos);

        extracted_qubits.insert(
            {std::make_pair(qreg_var_name, qidx), qbit_value});
      }
      qubits_for_inst.push_back(qbit_value);
    } else {
      auto qubit_kernel_arg = temporary_sub_kernel_args[qreg_var_name];
@@ -282,9 +362,9 @@ void StaqToMLIR::visit(DeclaredGate &g) {
    }
  }

  builder.create<mlir::quantum::InstOp>(location, str_attr,
                                        llvm::makeArrayRef(qubits_for_inst),
                                        params_dataAttribute);
  builder.create<mlir::quantum::InstOp>(
      location, mlir::NoneType::get(builder.getContext()), str_attr,
      llvm::makeArrayRef(qubits_for_inst), params_dataAttribute);
}

}  // namespace qasm_parser
 No newline at end of file
+6 −0
Original line number Diff line number Diff line
@@ -33,6 +33,12 @@ class StaqToMLIR : public staq::ast::Visitor {
  mlir::Block * main_entry_point;
  std::vector<std::string> function_names;
  
  mlir::Type qubit_type;
  mlir::Type array_type;
  mlir::Type result_type;
  
  std::map<std::pair<std::string, std::uint64_t>, mlir::Value> extracted_qubits;

 public:
  StaqToMLIR(mlir::MLIRContext &context);
  mlir::ModuleOp module() {return theModule;}
+52 −41

File changed.

Preview size limit exceeded, changes collapsed.

Loading