Commit f1accb94 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

setup openqasm mlir gen to support variable rotation parameters with available...


setup openqasm mlir gen to support variable rotation parameters with available unary and binary arithmetic operations

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 2c896a1a
Loading
Loading
Loading
Loading
+186 −61
Original line number Diff line number Diff line
#include "openqasm_mlir_generator.hpp"

#include "mlir/Dialect/StandardOps/IR/Ops.h"

namespace qcor {
@@ -24,6 +25,96 @@ void CountGateDecls::visit(GateDecl &g) {
  count++;
}

class MapParameterSubExpr : public staq::ast::Visitor {
 protected:
  mlir::OpBuilder &builder;
  mlir::Location &location;
  std::map<std::string, mlir::Value> &symbol_table;

  mlir::Value current_value;

 public:
  MapParameterSubExpr(mlir::OpBuilder &b, mlir::Location &l,
                      std::map<std::string, mlir::Value> &symbols)
      : builder(b), location(l), symbol_table(symbols) {}
  mlir::Value getValue() { return current_value; }
  void visit(VarAccess &) override {}
  void visit(BExpr &expr) override {
    Expr &left = expr.lexp();
    Expr &right = expr.rexp();
    mlir::Value left_value, right_value;
    if (left.constant_eval().has_value()) {
      auto constant_value = left.constant_eval().value();
      left_value = builder.create<mlir::ConstantOp>(
          location, mlir::FloatAttr::get(builder.getF64Type(), constant_value));
    } else {
      left.accept(*this);
      left_value = current_value;
    }

    if (right.constant_eval().has_value()) {
      auto constant_value = right.constant_eval().value();
      right_value = builder.create<mlir::ConstantOp>(
          location, mlir::FloatAttr::get(builder.getF64Type(), constant_value));
    } else {
      right.accept(*this);
      right_value = current_value;
    }

    if (expr.op() == BinaryOp::Divide) {
      current_value =
          builder.create<mlir::DivFOp>(location, left_value, right_value);
    } else if (expr.op() == BinaryOp::Plus) {
      current_value =
          builder.create<mlir::AddFOp>(location, left_value, right_value);
    } else if (expr.op() == BinaryOp::Minus) {
      current_value =
          builder.create<mlir::SubFOp>(location, left_value, right_value);
    } else if (expr.op() == BinaryOp::Times) {
      current_value =
          builder.create<mlir::MulFOp>(location, left_value, right_value);
    } else if (expr.op() == BinaryOp::Pow) {
      std::cout << "[OpenQASM MLIR Gen] pow(x,y) not supported yet.\n";
      exit(1);
    }
  }

  void visit(UExpr &expr) override {
    Expr &sub = expr.subexp();
    sub.accept(*this);
    if (expr.op() == UnaryOp::Neg) {
      current_value = builder.create<mlir::NegFOp>(location, current_value);
    } else {
      std::cout << "[OpenQASM MLIR Gen] no other unary ops supported.\n";
      exit(1);
    }
  }

  void visit(PiExpr &) override {}
  void visit(IntExpr &) override {}
  void visit(RealExpr &r) override {}
  void visit(VarExpr &v) override {
    if (symbol_table.count(v.var())) {
      current_value = symbol_table[v.var()];
    } else {
      std::cout << "[OpenQasm MLIR Gen] Error, " << v.var()
                << " is not a valid var in the symbol table.\n";
    }
  }
  void visit(ResetStmt &) override {}
  void visit(IfStmt &) override {}
  void visit(BarrierGate &) override {}
  void visit(GateDecl &) override {}
  void visit(OracleDecl &) override {}
  void visit(RegisterDecl &) override {}
  void visit(AncillaDecl &) override {}
  void visit(Program &prog) override {}
  void visit(MeasureStmt &m) override {}
  void visit(UGate &u) override {}
  void visit(CNOTGate &cx) override {}
  void visit(DeclaredGate &g) override {}
};

void OpenQasmMLIRGenerator::visit(Program &prog) {
  // How many statements are there (starts with 25)
  auto n_stmts = prog.body().size();
@@ -107,7 +198,8 @@ void OpenQasmMLIRGenerator::visit(Program &prog) {
                                             tmp->getArguments()[0]);
    builder.create<mlir::CallOp>(builder.getUnknownLoc(), function2);
    builder.create<mlir::quantum::QRTFinalizeOp>(builder.getUnknownLoc());
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::ArrayRef<mlir::Value>());
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                   llvm::ArrayRef<mlir::Value>());
    builder.setInsertionPointToStart(save_main_entry_block);

    m_module.push_back(function2);
@@ -154,8 +246,8 @@ void OpenQasmMLIRGenerator::finalize_mlirgen() {

  if (add_main) {
    // builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::None);
        builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::ArrayRef<mlir::Value>());

    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                   llvm::ArrayRef<mlir::Value>());
  }
}

@@ -167,6 +259,11 @@ void OpenQasmMLIRGenerator::visit(GateDecl &gate_function) {

    function_names.push_back(name);

    auto cn_args = gate_function.c_params().size();
    for (std::size_t i = 0; i < cn_args; i++) {
      arg_types.push_back(builder.getF64Type());
    }

    auto n_args = gate_function.q_params().size();
    for (std::size_t i = 0; i < n_args; i++) {
      arg_types.push_back(qubit_type);
@@ -180,8 +277,14 @@ void OpenQasmMLIRGenerator::visit(GateDecl &gate_function) {

    auto arguments = entryBlock.getArguments();

    for (std::size_t i = 0; i < n_args; i++) {
    for (std::size_t i = 0; i < cn_args; i++) {
      auto argument = arguments[i];
      auto arg_name = gate_function.c_params()[i];
      temporary_sub_kernel_args.insert({arg_name, argument});
    }

    for (std::size_t i = 0; i < n_args; i++) {
      auto argument = arguments[cn_args + i];
      auto arg_name = gate_function.q_params()[i];
      temporary_sub_kernel_args.insert({arg_name, argument});
    }
@@ -211,12 +314,6 @@ void OpenQasmMLIRGenerator::visit(RegisterDecl &d) {
    auto location =
        builder.getFileLineColLoc(builder.getIdentifier(fname), line, col);

    // if (is_first_inst && !in_sub_kernel) {
    //   auto main_args = main_entry_block->getArguments();
    //   builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
    //                                            main_args[1]);
    //   is_first_inst = false;
    // }
    auto integer_type = builder.getI64Type();
    auto integer_attr = mlir::IntegerAttr::get(integer_type, size);

@@ -238,9 +335,6 @@ void OpenQasmMLIRGenerator::visit(MeasureStmt &m) {

  auto str_attr = builder.getStringAttr("mz");

  // params
  mlir::DenseElementsAttr params_dataAttribute;

  std::vector<mlir::Value> qubits_for_inst;
  auto qreg_var_name = m.q_arg().var();
  if (!qubit_allocations.count(qreg_var_name)) {
@@ -264,10 +358,11 @@ void OpenQasmMLIRGenerator::visit(MeasureStmt &m) {
  }
  qubits_for_inst.push_back(qbit_value);

  builder.create<mlir::quantum::InstOp>(location, result_type, str_attr,
                                        llvm::makeArrayRef(qubits_for_inst),
                                        params_dataAttribute);
  builder.create<mlir::quantum::InstOp>(
      location, result_type, str_attr, llvm::makeArrayRef(qubits_for_inst),
      llvm::makeArrayRef(std::vector<mlir::Value>{}));
}

void OpenQasmMLIRGenerator::visit(UGate &u) {
  auto pos = u.pos();
  auto line = pos.get_linenum();
@@ -277,23 +372,46 @@ void OpenQasmMLIRGenerator::visit(UGate &u) {
  auto location =
      builder.getFileLineColLoc(builder.getIdentifier(fname), line, col);

  // if (is_first_inst && !in_sub_kernel) {
  //   auto main_args = main_entry_block->getArguments();
  auto str_attr = builder.getStringAttr("u3");

  //   builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
  //                                            main_args[1]);
  //   is_first_inst = false;
  // }
  std::vector<mlir::Value> params_for_inst;
  auto &theta_expr = u.theta();

  auto str_attr = builder.getStringAttr("u3");
  // params
  auto dataType = mlir::VectorType::get({3}, builder.getF64Type());
  std::vector<double> v{u.theta().constant_eval().value(),
                        u.phi().constant_eval().value(),
                        u.lambda().constant_eval().value()};
  auto params_arr_ref = llvm::makeArrayRef(v);
  auto params_dataAttribute =
      mlir::DenseElementsAttr::get(dataType, params_arr_ref);
  MapParameterSubExpr visitor(builder, location, temporary_sub_kernel_args);
  if (theta_expr.constant_eval().has_value()) {
    double val = theta_expr.constant_eval().value();
    auto float_attr = mlir::FloatAttr::get(builder.getF64Type(), val);
    mlir::Value val_val =
        builder.create<mlir::ConstantOp>(location, float_attr);
    params_for_inst.push_back(val_val);
  } else {
    theta_expr.accept(visitor);
    params_for_inst.push_back(visitor.getValue());
  }

  auto &phi_expr = u.phi();
  if (phi_expr.constant_eval().has_value()) {
    double val = phi_expr.constant_eval().value();
    auto float_attr = mlir::FloatAttr::get(builder.getF64Type(), val);
    mlir::Value val_val =
        builder.create<mlir::ConstantOp>(location, float_attr);
    params_for_inst.push_back(val_val);
  } else {
    phi_expr.accept(visitor);
    params_for_inst.push_back(visitor.getValue());
  }

  auto &lambda_expr = u.lambda();
  if (lambda_expr.constant_eval().has_value()) {
    double val = lambda_expr.constant_eval().value();
    auto float_attr = mlir::FloatAttr::get(builder.getF64Type(), val);
    mlir::Value val_val =
        builder.create<mlir::ConstantOp>(location, float_attr);
    params_for_inst.push_back(val_val);
  } else {
    lambda_expr.accept(visitor);
    params_for_inst.push_back(visitor.getValue());
  }

  std::vector<mlir::Value> qubits_for_inst;
  auto qreg_var_name = u.arg().var();
@@ -301,26 +419,33 @@ void OpenQasmMLIRGenerator::visit(UGate &u) {
    // throw an error
  }

  mlir::Value qbit_value;
  if (u.arg().offset().has_value()) {
    std::uint64_t qidx = u.arg().offset().value();
    auto qbit_key = std::make_pair(qreg_var_name, qidx);
  mlir::Value qbit_value;
    if (extracted_qubits.count(qbit_key)) {
      qbit_value = extracted_qubits[qbit_key];
    } else {
      auto qubits = qubit_allocations[qreg_var_name].qubits();

      auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), qidx);
    mlir::Value pos = builder.create<mlir::ConstantOp>(location, integer_attr);
      mlir::Value pos =
          builder.create<mlir::ConstantOp>(location, integer_attr);
      qbit_value = builder.create<mlir::quantum::ExtractQubitOp>(
          location, qubit_type, qubits, pos);

    extracted_qubits.insert({std::make_pair(qreg_var_name, qidx), qbit_value});
      extracted_qubits.insert(
          {std::make_pair(qreg_var_name, qidx), qbit_value});
    }
  } else {
    auto var_name = u.arg().var();
    qbit_value = temporary_sub_kernel_args[var_name];
  }
  qubits_for_inst.push_back(qbit_value);

  builder.create<mlir::quantum::InstOp>(
      location, mlir::NoneType::get(builder.getContext()), str_attr,
      llvm::makeArrayRef(qubits_for_inst), params_dataAttribute);
      llvm::makeArrayRef(qubits_for_inst), llvm::makeArrayRef(params_for_inst));
}

void OpenQasmMLIRGenerator::visit(CNOTGate &g) {
@@ -399,7 +524,8 @@ void OpenQasmMLIRGenerator::visit(CNOTGate &g) {

  builder.create<mlir::quantum::InstOp>(
      location, mlir::NoneType::get(builder.getContext()), str_attr,
      llvm::makeArrayRef(qubits_for_inst), params_dataAttribute);
      llvm::makeArrayRef(qubits_for_inst),
      llvm::makeArrayRef(std::vector<mlir::Value>{}));
}

void OpenQasmMLIRGenerator::visit(DeclaredGate &g) {
@@ -411,28 +537,27 @@ void OpenQasmMLIRGenerator::visit(DeclaredGate &g) {
  auto location =
      builder.getFileLineColLoc(builder.getIdentifier(fname), line, col);

  // if (is_first_inst && !in_sub_kernel) {
  //   auto main_args = main_entry_block->getArguments();

  //   builder.create<mlir::quantum::QRTInitOp>(location, main_args[0],
  //                                            main_args[1]);
  //   is_first_inst = false;
  // }

  auto str_attr = builder.getStringAttr(g.name());

  // params
  mlir::DenseElementsAttr params_dataAttribute;
  std::vector<mlir::Value> params_for_inst;
  if (g.num_cargs()) {
    auto dataType =
        mlir::VectorType::get({g.num_cargs()}, builder.getF64Type());
    std::vector<double> v;  //{0.0};
    for (int i = 0; i < g.num_cargs(); i++) {
      v.push_back(g.carg(i).constant_eval().value());
      if (g.carg(i).constant_eval().has_value()) {
        double val = g.carg(i).constant_eval().value();

        auto float_attr = mlir::FloatAttr::get(builder.getF64Type(), val);
        mlir::Value val_val =
            builder.create<mlir::ConstantOp>(location, float_attr);
        params_for_inst.push_back(val_val);
      } else {
        MapParameterSubExpr visitor(builder, location,
                                    temporary_sub_kernel_args);
        Expr &theta_expr = g.carg(i);
        theta_expr.accept(visitor);
        params_for_inst.push_back(visitor.getValue());
      }
    }
    auto params_arr_ref = llvm::makeArrayRef(v);
    params_dataAttribute =
        mlir::DenseElementsAttr::get(dataType, params_arr_ref);
  }

  // qbits
@@ -470,7 +595,7 @@ void OpenQasmMLIRGenerator::visit(DeclaredGate &g) {

  builder.create<mlir::quantum::InstOp>(
      location, mlir::NoneType::get(builder.getContext()), str_attr,
      llvm::makeArrayRef(qubits_for_inst), params_dataAttribute);
      llvm::makeArrayRef(qubits_for_inst), llvm::makeArrayRef(params_for_inst));
}

}  // namespace qcor
 No newline at end of file