Commit 42f81c84 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

adding set qreg op and qrt call. remove module attrs for function names.

parent 730ba97c
Loading
Loading
Loading
Loading
+101 −1
Original line number Diff line number Diff line
@@ -25,7 +25,7 @@ bool isOpaqueTypeWithName(mlir::Type type, std::string dialect,
QuantumDialect::QuantumDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<QuantumDialect>()) {
  addOperations<InstOp, QallocOp, ExtractQubitOp, DeallocOp, QRTInitOp,
                QRTFinalizeOp>();
                QRTFinalizeOp, SetQregOp>();
}
QRTFinalizeOpAdaptor::QRTFinalizeOpAdaptor(::mlir::ValueRange values,
                                           ::mlir::DictionaryAttr attrs)
@@ -1000,5 +1000,105 @@ static mlir::ParseResult parseQallocOp(mlir::OpAsmParser &parser,
                                    ::mlir::OperationState &result) {
  return ::parseQallocOp(parser, result);
}

//===----------------------------------------------------------------------===//
// ::mlir::quantum::SetQregOp definitions
//===----------------------------------------------------------------------===//

SetQregOpAdaptor::SetQregOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs)  : odsOperands(values), odsAttrs(attrs) {

}

SetQregOpAdaptor::SetQregOpAdaptor(SetQregOp&op)  : odsOperands(op->getOperands()), odsAttrs(op->getAttrDictionary()) {

}

std::pair<unsigned, unsigned> SetQregOpAdaptor::getODSOperandIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::ValueRange SetQregOpAdaptor::getODSOperands(unsigned index) {
  auto valueRange = getODSOperandIndexAndLength(index);
  return {std::next(odsOperands.begin(), valueRange.first),
           std::next(odsOperands.begin(), valueRange.first + valueRange.second)};
}

::mlir::Value SetQregOpAdaptor::qreg() {
  return *getODSOperands(0).begin();
}

::mlir::LogicalResult SetQregOpAdaptor::verify(::mlir::Location loc) {
  return ::mlir::success();
}

::llvm::StringRef SetQregOp::getOperationName() {
  return "quantum.set_qreg";
}

std::pair<unsigned, unsigned> SetQregOp::getODSOperandIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::Operation::operand_range SetQregOp::getODSOperands(unsigned index) {
  auto valueRange = getODSOperandIndexAndLength(index);
  return {std::next(getOperation()->operand_begin(), valueRange.first),
           std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)};
}

::mlir::Value SetQregOp::qreg() {
  return *getODSOperands(0).begin();
}

::mlir::MutableOperandRange SetQregOp::qregMutable() {
  auto range = getODSOperandIndexAndLength(0);
  return ::mlir::MutableOperandRange(getOperation(), range.first, range.second);
}

std::pair<unsigned, unsigned> SetQregOp::getODSResultIndexAndLength(unsigned index) {
  return {index, 1};
}

::mlir::Operation::result_range SetQregOp::getODSResults(unsigned index) {
  auto valueRange = getODSResultIndexAndLength(index);
  return {std::next(getOperation()->result_begin(), valueRange.first),
           std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)};
}

void SetQregOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value qreg) {
  odsState.addOperands(qreg);
}

void SetQregOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value qreg) {
  odsState.addOperands(qreg);
  assert(resultTypes.size() == 0u && "mismatched number of results");
  odsState.addTypes(resultTypes);
}

void SetQregOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
  assert(operands.size() == 1u && "mismatched number of parameters");
  odsState.addOperands(operands);
  odsState.addAttributes(attributes);
  assert(resultTypes.size() == 0u && "mismatched number of return types");
  odsState.addTypes(resultTypes);
}

::mlir::LogicalResult SetQregOp::verify() {
  if (failed(SetQregOpAdaptor(*this).verify(this->getLoc()))) return ::mlir::failure();
  {
    unsigned index = 0; (void)index;
    auto valueGroup0 = getODSOperands(0);
    for (::mlir::Value v : valueGroup0) {
      (void)v;
      if (!((isOpaqueTypeWithName(v.getType(), "quantum", "QregType")))) {
        return emitOpError("operand #") << index << " must be opaque qreg type, but got " << v.getType();
      }
      ++index;
    }
  }
  {
    unsigned index = 0; (void)index;
  }
  return ::mlir::success();
}
}  // namespace quantum
}  // namespace mlir
+32 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ class DeallocOp;
class ExtractQubitOp;
class QRTInitOp;
class QRTFinalizeOp;
class SetQregOp;
}  // namespace quantum
}  // namespace mlir

@@ -325,6 +326,37 @@ class DeallocOp
                    ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
};

class SetQregOpAdaptor {
public:
  SetQregOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr);
  SetQregOpAdaptor(SetQregOp&op);
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::ValueRange getODSOperands(unsigned index);
  ::mlir::Value qreg();
  ::mlir::LogicalResult verify(::mlir::Location loc);

private:
  ::mlir::ValueRange odsOperands;
  ::mlir::DictionaryAttr odsAttrs;
};
class SetQregOp : public ::mlir::Op<SetQregOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::ZeroResult, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::OneOperand> {
public:
  using Op::Op;
  using Op::print;
  using Adaptor = SetQregOpAdaptor;
  static ::llvm::StringRef getOperationName();
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
  ::mlir::Operation::operand_range getODSOperands(unsigned index);
  ::mlir::Value qreg();
  ::mlir::MutableOperandRange qregMutable();
  std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
  ::mlir::Operation::result_range getODSResults(unsigned index);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value qreg);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value qreg);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::mlir::LogicalResult verify();
};
}  // namespace quantum
}  // namespace mlir

+5 −1
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ class QuantumMLIRGenerator {
  mlir::ModuleOp m_module;
  mlir::OpBuilder builder;
  mlir::Block* main_entry_block;
  std::vector<std::string> function_names;

 public:
  QuantumMLIRGenerator(mlir::MLIRContext& ctx) : context(ctx), builder(&ctx) {}
@@ -24,7 +25,8 @@ class QuantumMLIRGenerator {
  // mlir using the quantum dialect. This may also be used for
  // introducing any initialization operations before
  // generation of the rest of the mlir code.
  virtual void initialize_mlirgen() = 0;
  virtual void initialize_mlirgen(bool add_entry_point = true,
                                  const std::string file_name = "") = 0;

  // This method can be implemented by subclasses to map a
  // quantum code in a subclass-specific source language to
@@ -36,6 +38,8 @@ class QuantumMLIRGenerator {
    return mlir::OwningModuleRef(mlir::OwningOpRef<mlir::ModuleOp>(m_module));
  }

  std::vector<std::string> seen_function_names() { return function_names; }

  // Finalize method, override to provide any end operations
  // to the module (like a return_op).
  virtual void finalize_mlirgen() = 0;
+95 −116
Original line number Diff line number Diff line
@@ -16,9 +16,8 @@ static std::vector<std::string> builtins{
    "cy", "swap", "ch",  "ccx", "crz", "cu1", "cu2", "cu3"};

static std::vector<std::string> search_for_inliner{
    "u3", "u2",   "u1",  "cx",  "id",  "u0",  "x",   "y",  "z",
    "h",  "s",    "sdg", "t",   "tdg", "rx",  "ry",  "rz", "cz",
    "cy", "swap"};
    "u3", "u2",  "u1", "cx",  "id", "u0", "x",  "y",  "z",  "h",
    "s",  "sdg", "t",  "tdg", "rx", "ry", "rz", "cz", "cy", "swap"};

void CountGateDecls::visit(GateDecl &g) {
  auto name = g.id();
@@ -59,8 +58,19 @@ void OpenQasmMLIRGenerator::visit(Program &prog) {
  auto int_type = builder.getI32Type();
  auto argv_type =
      mlir::OpaqueType::get(dialect, llvm::StringRef("ArgvType"), &context);
  auto qreg_type =
      mlir::OpaqueType::get(dialect, llvm::StringRef("qreg"), &context);

  if (add_main) {
    std::vector<mlir::Type> arg_types_vec2{};
    auto func_type2 =
        builder.getFunctionType(llvm::makeArrayRef(arg_types_vec2), llvm::None);
    auto proto2 = mlir::FuncOp::create(
        builder.getUnknownLoc(), "__internal_mlir_" + file_name, func_type2);
    mlir::FuncOp function2(proto2);
    auto save_main_entry_block = function2.addEntryBlock();

    if (add_entry_point) {
      std::vector<mlir::Type> arg_types_vec{int_type, argv_type};
      auto func_type =
          builder.getFunctionType(llvm::makeArrayRef(arg_types_vec), int_type);
@@ -70,13 +80,55 @@ void OpenQasmMLIRGenerator::visit(Program &prog) {
      main_entry_block = function.addEntryBlock();
      auto &entryBlock = *main_entry_block;
      builder.setInsertionPointToStart(&entryBlock);

      auto main_args = main_entry_block->getArguments();
      builder.create<mlir::quantum::QRTInitOp>(builder.getUnknownLoc(),
                                               main_args[0], main_args[1]);

      // call the function from main, run finalize, and return 0
      builder.create<mlir::CallOp>(builder.getUnknownLoc(), function2);
      builder.create<mlir::quantum::QRTFinalizeOp>(builder.getUnknownLoc());
      is_first_inst = false;
      auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 0);
      mlir::Value ret_zero = builder.create<mlir::ConstantOp>(
          builder.getUnknownLoc(), integer_attr);
      builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), ret_zero);
      m_module.push_back(function);
      function_names.push_back("main");
    }

    std::vector<mlir::Type> arg_types_vec3{qreg_type};
    auto func_type3 =
        builder.getFunctionType(llvm::makeArrayRef(arg_types_vec3), llvm::None);
    auto proto3 =
        mlir::FuncOp::create(builder.getUnknownLoc(), file_name, func_type3);
    mlir::FuncOp function3(proto3);

    auto tmp = function3.addEntryBlock();
    builder.setInsertionPointToStart(tmp);
    builder.create<mlir::quantum::SetQregOp>(builder.getUnknownLoc(),
                                             tmp->getArguments()[0]);
    builder.create<mlir::CallOp>(builder.getUnknownLoc(), function2);
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::None);
    builder.setInsertionPointToStart(save_main_entry_block);

    m_module.push_back(function2);
    m_module.push_back(function3);
    function_names.push_back("__internal_mlir_" + file_name);
    function_names.push_back(file_name);

    main_entry_block = save_main_entry_block;

    // Create function_name(opaque.qreg q), setup main to call it
  }
  prog.foreach_stmt([this](auto &stmt) { stmt.accept(*this); });
}

void OpenQasmMLIRGenerator::initialize_mlirgen() {}
void OpenQasmMLIRGenerator::initialize_mlirgen(bool _add_entry_point,
                                               const std::string function) {
  file_name = function;
  add_entry_point = _add_entry_point;
}

void OpenQasmMLIRGenerator::mlirgen(const std::string &src) {
  using namespace staq;
@@ -90,57 +142,6 @@ void OpenQasmMLIRGenerator::mlirgen(const std::string &src) {
    std::cout << e.what() << "\n";
  }

  // Replace standard controlled gates with expanded versions
  // First get mapping of gate name to composite gates
  // auto tmp_prog = parser::parse_string(R"(OPENQASM 2.0;
  // include "qelib1.inc";
  // )");

  // class CollectGateDecomps : public Traverse {
  //  public:
  //   std::map<std::string, std::list<std::unique_ptr<Gate>>> gate_decomps;

  //       void
  //       visit(GateDecl &gate) override {
  //     if (gate.id() == "ccx") {
  //       std::cout << "Found CCX\n";
  //       gate.body();
  //     }
  //     gate_decomps.insert({gate.id(), gate.body()});
  //   }
  // };
  // CollectGateDecomps collect;
  // tmp_prog->foreach_stmt([&](auto &stmt) { stmt.accept(collect); });

  // class BuildReplacerMap : public Traverse {
  //  protected:
  //   std::map<std::string, std::list<std::unique_ptr<Gate>>> &gate_decomps;

  //  public:
  //   std::unordered_map<int, std::list<std::unique_ptr<Gate>>> replacer_map;
  //   BuildReplacerMap(
  //       std::map<std::string, std::list<std::unique_ptr<Gate>>> &gd)
  //       : gate_decomps(gd) {}

  //   void visit(DeclaredGate &gate) {
  //     auto name = gate.name();

  //     if (name == "ccx") {
  //       std::cout << "adding to replacer map for ccx\n";
  //       auto uid = gate.uid();
  //       // auto gates = ;
  //       if (!replacer_map.count(uid)) {
  //       replacer_map.insert({uid, std::move(gate_decomps[name])});
  //       }
  //     }
  //   }
  // };

  // BuildReplacerMap replacer_builder(collect.gate_decomps);
  // prog->foreach_stmt([&](auto &stmt) { stmt.accept(replacer_builder); });

  // First, get uid of declared gate to replace

  visit(*prog);

  return;
@@ -154,29 +155,7 @@ void OpenQasmMLIRGenerator::finalize_mlirgen() {
  }

  if (add_main) {
    builder.create<mlir::quantum::QRTFinalizeOp>(builder.getUnknownLoc());

    auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 0);
    mlir::Value ret_zero =
        builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr);

    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), ret_zero);

    std::vector<llvm::StringRef> tmp(function_names.begin(),
                                     function_names.end());

    auto function_names_datatype = mlir::VectorType::get(
        {static_cast<std::int64_t>(function_names.size())},
        builder.getI64Type());
    auto function_names_ref = llvm::makeArrayRef(tmp);
    auto attrs = mlir::DenseStringElementsAttr::get(function_names_datatype,
                                                    function_names_ref);

    mlir::Identifier id = mlir::Identifier::get("quantum.internal_functions",
                                                builder.getContext());

    m_module.setAttrs(
        llvm::makeArrayRef({mlir::NamedAttribute(std::make_pair(id, attrs))}));
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::None);
  }
}

@@ -232,12 +211,12 @@ void OpenQasmMLIRGenerator::visit(RegisterDecl &d) {
    auto location =
        builder.getFileLineColLoc(builder.getIdentifier(fname), line, col);

    if (is_first_inst && !in_sub_kernel) {
      auto main_args = main_entry_block->getArguments();
      builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
                                               main_args[1]);
      is_first_inst = false;
    }
    // if (is_first_inst && !in_sub_kernel) {
    //   auto main_args = main_entry_block->getArguments();
    //   builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
    //                                            main_args[1]);
    //   is_first_inst = false;
    // }
    auto integer_type = builder.getI64Type();
    auto integer_attr = mlir::IntegerAttr::get(integer_type, size);

@@ -298,13 +277,13 @@ void OpenQasmMLIRGenerator::visit(UGate &u) {
  auto location =
      builder.getFileLineColLoc(builder.getIdentifier(fname), line, col);

  if (is_first_inst && !in_sub_kernel) {
    auto main_args = main_entry_block->getArguments();
  // if (is_first_inst && !in_sub_kernel) {
  //   auto main_args = main_entry_block->getArguments();

    builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
                                             main_args[1]);
    is_first_inst = false;
  }
  //   builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
  //                                            main_args[1]);
  //   is_first_inst = false;
  // }

  auto str_attr = builder.getStringAttr("u3");
  // params
@@ -353,13 +332,13 @@ void OpenQasmMLIRGenerator::visit(CNOTGate &g) {
  auto location =
      builder.getFileLineColLoc(builder.getIdentifier(fname), line, col);

  if (is_first_inst && !in_sub_kernel) {
    auto main_args = main_entry_block->getArguments();
  // if (is_first_inst && !in_sub_kernel) {
  //   auto main_args = main_entry_block->getArguments();

    builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
                                             main_args[1]);
    is_first_inst = false;
  }
  //   builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
  //                                            main_args[1]);
  //   is_first_inst = false;
  // }
  auto str_attr = builder.getStringAttr("cx");

  // params
@@ -432,13 +411,13 @@ void OpenQasmMLIRGenerator::visit(DeclaredGate &g) {
  auto location =
      builder.getFileLineColLoc(builder.getIdentifier(fname), line, col);

  if (is_first_inst && !in_sub_kernel) {
    auto main_args = main_entry_block->getArguments();
  // if (is_first_inst && !in_sub_kernel) {
  //   auto main_args = main_entry_block->getArguments();

    builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
                                             main_args[1]);
    is_first_inst = false;
  }
  //   builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
  //                                            main_args[1]);
  //   is_first_inst = false;
  // }

  auto str_attr = builder.getStringAttr(g.name());

+20 −18
Original line number Diff line number Diff line
@@ -5,9 +5,9 @@
#pragma GCC diagnostic ignored "-Wdeprecated-copy"
#pragma GCC diagnostic ignored "-Wunused-function"

#include "mlir/IR/Region.h"
#include "ast/ast.hpp"
#include "ast/traversal.hpp"
#include "mlir/IR/Region.h"
#include "mlir_generator.hpp"
#include "optimization/simplify.hpp"
#include "parser/parser.hpp"
@@ -25,10 +25,10 @@ class OpenQasmMLIRGenerator : public qcor::QuantumMLIRGenerator,
  std::map<std::string, mlir::quantum::QallocOp> qubit_allocations;
  bool in_sub_kernel = false;
  std::map<std::string, mlir::Value> temporary_sub_kernel_args;
  std::vector<std::string> function_names;
  bool is_first_inst = true;
  bool add_main = true;

  std::string file_name = "main";
  bool add_entry_point = true;
  mlir::Type qubit_type;
  mlir::Type array_type;
  mlir::Type result_type;
@@ -36,8 +36,10 @@ class OpenQasmMLIRGenerator : public qcor::QuantumMLIRGenerator,
  std::map<std::pair<std::string, std::uint64_t>, mlir::Value> extracted_qubits;

 public:
  OpenQasmMLIRGenerator(mlir::MLIRContext &context) : QuantumMLIRGenerator(context){}
  void initialize_mlirgen() override;
  OpenQasmMLIRGenerator(mlir::MLIRContext &context)
      : QuantumMLIRGenerator(context) {}
  void initialize_mlirgen(bool add_entry_point = true,
                          const std::string file_name = "") override;
  void mlirgen(const std::string &src) override;
  void finalize_mlirgen() override;

@@ -65,6 +67,7 @@ class OpenQasmMLIRGenerator : public qcor::QuantumMLIRGenerator,
class CountGateDecls : public staq::ast::Visitor {
 private:
  std::size_t &count;

 public:
  std::vector<std::string> gates_to_inline;

@@ -88,6 +91,5 @@ public:
  void visit(UGate &u) override {}
  void visit(CNOTGate &cx) override {}
  void visit(DeclaredGate &g) override {}
  
};
}  // namespace qcor
 No newline at end of file
Loading