Commit 236f98ca authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Adding a CPhase angle combination pass



In case users write CPhase in a loop (rather than compute the total angle), this pass will consolidate the angle automatically.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 7c89e478
Loading
Loading
Loading
Loading
+10 −8
Original line number Diff line number Diff line
@@ -302,14 +302,13 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
          symbol_table.array_qubit_symbol_name(qbit_var_name, idx_str);
      mlir::Value value;
      try {
        if (symbol_table.has_symbol(qubit_symbol_name)) {
          value = symbol_table.get_symbol(qubit_symbol_name);
        } else {
        // try catch is on this std::stoi(), if idx_str is not an integer,
        // then we drop out and try to evaluate the expression.
          value = get_or_extract_qubit(qbit_var_name, std::stoi(idx_str),
                                       location, symbol_table, builder);
        }
        const auto idx_val = std::stoi(idx_str);
        // Note: always use get_or_extract_qubit which has built-in qubit SSA
        // validation/adjust.
        value = get_or_extract_qubit(qbit_var_name, idx_val, location,
                                     symbol_table, builder);
      } catch (...) {
        if (symbol_table.has_symbol(idx_str)) {
          auto qubits = symbol_table.get_symbol(qbit_var_name);
@@ -321,6 +320,9 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
            qbit = builder.create<mlir::IndexCastOp>(
                location, builder.getI64Type(), qbit);
          }

          // This is qubit extract by a variable index:
          symbol_table.invalidate_qubit_extracts(qbit_var_name);
          value = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qubits, qbit);
          if (!symbol_table.has_symbol(qubit_symbol_name))
+162 −0
Original line number Diff line number Diff line
#include "CPhaseRotationMergingPass.hpp"
#include "Quantum/QuantumOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include <iostream>

namespace qcor {
void CPhaseRotationMergingPass::getDependentDialects(
    DialectRegistry &registry) const {
  registry.insert<LLVM::LLVMDialect>();
}

void CPhaseRotationMergingPass::runOnOperation() {
  // Walk the operations within the function.
  std::vector<mlir::quantum::ValueSemanticsInstOp> deadOps;
  getOperation().walk([&](mlir::quantum::ValueSemanticsInstOp op) {
    if (std::find(deadOps.begin(), deadOps.end(), op) != deadOps.end()) {
      // Skip this op since it was merged (forward search)
      return;
    }
    mlir::OpBuilder rewriter(op);
    auto inst_name = op.name();
    if (inst_name != "cphase") {
      return;
    }
    // Get the src ret qubit and the tgt ret qubit
    auto src_return_val = op.result().front();
    auto tgt_return_val = op.result().back();

    // Make sure they are used
    if (src_return_val.hasOneUse() && tgt_return_val.hasOneUse()) {
      // get the users of these values
      auto src_user = *src_return_val.user_begin();
      auto tgt_user = *tgt_return_val.user_begin();

      // Cast them to InstOps
      auto next_inst =
          dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(src_user);
      auto tmp_tgt =
          dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(tgt_user);

      if (!next_inst || !tmp_tgt) {
        // not inst ops
        return;
      }

      // We want the case where src_user and tgt_user are the same
      if (next_inst.getOperation() != tmp_tgt.getOperation()) {
        return;
      }

      // Ctrl -> ctrl; target -> target connections
      if (next_inst.getOperand(0) != src_return_val &&
          next_inst.getOperand(1) != tgt_return_val) {
        return;
      }

      if (next_inst.name() != "cphase") {
        return;
      }

      // They are the same operation, a cphase
      // so we have cphase src, tgt | cphase src, tgt
      // Combine the angles:
      const auto tryGetConstAngle =
          [](mlir::Value theta_var) -> std::optional<double> {
        if (!theta_var.getType().isa<mlir::FloatType>()) {
          return std::nullopt;
        }
        // Find the defining op:
        auto def_op = theta_var.getDefiningOp();
        if (def_op) {
          // Try cast:
          if (auto const_def_op =
                  dyn_cast_or_null<mlir::ConstantFloatOp>(def_op)) {
            llvm::APFloat theta_var_const_cal = const_def_op.getValue();
            return theta_var_const_cal.convertToDouble();
          }
        }
        return std::nullopt;
      };

      mlir::Value first_angle = op.getOperand(2);
      mlir::Value second_angle = next_inst.getOperand(2);
      rewriter.setInsertionPointAfter(next_inst);

      const auto first_angle_const = tryGetConstAngle(first_angle);
      const auto second_angle_const = tryGetConstAngle(second_angle);

      // Create a new instruction:
      // Return type: qubit; qubit
      std::vector<mlir::Type> ret_types{op.getOperand(0).getType(),
                                        op.getOperand(1).getType()};
      const std::string result_inst_name = "cphase";
      if (first_angle_const.has_value() && second_angle_const.has_value()) {
        // both angles are constant: pre-compute the total angle:
        const double totalAngle =
            first_angle_const.value() + second_angle_const.value();
        if (std::abs(totalAngle) > ZERO_ROTATION_TOLERANCE) {
          mlir::Value totalAngleVal = rewriter.create<mlir::ConstantOp>(
              op.getLoc(),
              mlir::FloatAttr::get(rewriter.getF64Type(), totalAngle));
          auto new_inst = rewriter.create<mlir::quantum::ValueSemanticsInstOp>(
              op.getLoc(), llvm::makeArrayRef(ret_types), result_inst_name,
              llvm::makeArrayRef({op.getOperand(0), op.getOperand(1)}),
              llvm::makeArrayRef({totalAngleVal}));
          // Input -> Output mapping (this instruction is to be removed)
          auto next_inst_result_0 = next_inst.result().front();
          auto next_inst_result_1 = next_inst.result().back();

          auto new_inst_result_0 = new_inst.result().front();
          auto new_inst_result_1 = new_inst.result().back();

          next_inst_result_0.replaceAllUsesWith(new_inst_result_0);
          next_inst_result_1.replaceAllUsesWith(new_inst_result_1);
        } else {
          // Zero rotation: just map from input -> output
          auto next_inst_result_0 = next_inst.result().front();
          auto next_inst_result_1 = next_inst.result().back();
          next_inst_result_0.replaceAllUsesWith(op.getOperand(0));
          next_inst_result_1.replaceAllUsesWith(op.getOperand(1));
        }
      } else {
        // Need to create an AddFOp
        auto add_op = rewriter.create<mlir::AddFOp>(op.getLoc(), first_angle,
                                                    second_angle);
        assert(add_op.result().getType().isa<mlir::FloatType>());

        auto new_inst = rewriter.create<mlir::quantum::ValueSemanticsInstOp>(
            op.getLoc(), llvm::makeArrayRef(ret_types), result_inst_name,
            llvm::makeArrayRef(op.getOperand(0)),
            llvm::makeArrayRef({add_op.result()}));
        // Input -> Output mapping (this instruction is to be removed)
        auto next_inst_result_0 = next_inst.result().front();
        auto next_inst_result_1 = next_inst.result().back();

        auto new_inst_result_0 = new_inst.result().front();
        auto new_inst_result_1 = new_inst.result().back();

        next_inst_result_0.replaceAllUsesWith(new_inst_result_0);
        next_inst_result_1.replaceAllUsesWith(new_inst_result_1);
      }

      // Cache instructions for delete.
      deadOps.emplace_back(op);
      deadOps.emplace_back(next_inst);
    }
  });

  for (auto &op : deadOps) {
    op->dropAllUses();
    op.erase();
  }
}
} // namespace qcor
+23 −0
Original line number Diff line number Diff line
#pragma once
#include "Quantum/QuantumOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

namespace qcor {
// Merging 2 consecutive CPhase gates
// on the same control and target qubits.
struct CPhaseRotationMergingPass
    : public PassWrapper<CPhaseRotationMergingPass, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry &registry) const override;
  void runOnOperation() final;
  CPhaseRotationMergingPass() {}

private:
  static constexpr double ZERO_ROTATION_TOLERANCE = 1e-9;
};
} // namespace qcor
 No newline at end of file
+3 −1
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@
#include "optimizations/RotationMergingPass.hpp"
#include "optimizations/SimplifyQubitExtractPass.hpp"
#include "optimizations/SingleQubitGateMergingPass.hpp"

#include "optimizations/CphaseRotationMergingPass.hpp"
#include "quantum_to_llvm.hpp"
// Construct QCOR MLIR pass manager:
// Make sure we use the same set of passes and configs
@@ -34,6 +34,8 @@ void configureOptimizationPasses(mlir::PassManager &passManager) {
    
    // Rotation merging
    passManager.addPass(std::make_unique<RotationMergingPass>());
    passManager.addPass(std::make_unique<CPhaseRotationMergingPass>());

    // General gate sequence re-synthesize
    passManager.addPass(std::make_unique<SingleQubitGateMergingPass>());
    // Try permute gates to realize more merging opportunities