Commit 928a41e3 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Added Rz-CX permute pass and test



Also, rotation merging to drop zero-rotation.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 2288ba69
Loading
Loading
Loading
Loading
+28 −0
Original line number Diff line number Diff line
@@ -124,6 +124,34 @@ oracle q[0];
  EXPECT_EQ(countSubstring(llvm, "__quantum__rt__qubit_release_array"), 0);
}

TEST(qasm3PassManagerTester, checkPermuteAndCancel) {
  // Permute rz-cnot ==> gate cancellation
  const std::string src = R"#(OPENQASM 3;
include "qelib1.inc";

qubit q[2];

rz(0.123) q[0];
cx q[0], q[1];
rz(-0.123) q[0];
cx q[0], q[1];
)#";
  auto llvm =
      qcor::mlir_compile("qasm3", src, "test_kernel", qcor::OutputType::LLVMIR, false);
  std::cout << "LLVM:\n" << llvm << "\n";
  
  // Get the main kernel section only (there is the oracle LLVM section as well)
  llvm = llvm.substr(llvm.find("@test_kernel"));
  const auto last = llvm.find_first_of("}");
  llvm = llvm.substr(0, last + 1);
  std::cout << "LLVM:\n" << llvm << "\n";
  // Cancel all => No gates, extract, or alloc/dealloc:
  EXPECT_EQ(countSubstring(llvm, "__quantum__qis"), 0);
  EXPECT_EQ(countSubstring(llvm, "__quantum__rt__array_get_element_ptr_1d"), 0);
  EXPECT_EQ(countSubstring(llvm, "__quantum__rt__qubit_allocate_array"), 0);
  EXPECT_EQ(countSubstring(llvm, "__quantum__rt__qubit_release_array"), 0);
}

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+66 −0
Original line number Diff line number Diff line
#include "PermuteGatePass.hpp"
#include "Quantum/QuantumOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

namespace qcor {
void PermuteGatePass::getDependentDialects(DialectRegistry &registry) const {
  registry.insert<LLVM::LLVMDialect>();
}

void PermuteGatePass::runOnOperation() {
  // Walk the operations within the function.
  std::vector<mlir::quantum::ValueSemanticsInstOp> deadOps;
  getOperation().walk([&](mlir::quantum::ValueSemanticsInstOp op) {
    auto inst_name = op.name();
    // Move "Rz" forward
    if (inst_name.str() == "rz") {
      auto return_value = *op.result().begin();
      if (return_value.hasOneUse()) {
        // get that one user
        auto user = *return_value.user_begin();
        // cast to a inst op
        if (auto next_inst =
                dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(user)) {
          if (next_inst.name() == "cx" || next_inst.name() == "cnot") {
            if (next_inst.getOperand(0) == op.result().front()) {
              mlir::OpBuilder rewriter(op);
              // rz connect to control bit (operand 0)
              // Permute rz:
              rewriter.setInsertionPointAfter(next_inst);
              mlir::Value cx_ctrl_out = next_inst.result().front();
              auto new_rz_inst =
                  rewriter.create<mlir::quantum::ValueSemanticsInstOp>(
                      op.getLoc(),
                      llvm::makeArrayRef({op.getOperand(0).getType()}), "rz",
                      llvm::makeArrayRef(cx_ctrl_out),
                      llvm::makeArrayRef(op.getOperand(1)));

              // Input to original rz => cnot
              next_inst.getOperand(0).replaceAllUsesWith(op.getOperand(0));
              // First output of cx (control line) to output of the new rz
              // except the new rz which is connected to the output of cx
              cx_ctrl_out.replaceAllUsesExcept(
                  new_rz_inst.result().front(),
                  mlir::SmallPtrSet<Operation *, 1>{new_rz_inst});
              deadOps.emplace_back(op);
            }
          }
        }
      }
    }
  });

  for (auto &op : deadOps) {
    op->dropAllUses();
    op.erase();
  }
}
} // namespace qcor
 No newline at end of file
+9 −61
Original line number Diff line number Diff line
#pragma once
#include "Quantum/QuantumOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "utils/gate_matrix.hpp"

namespace qcor {
// Permute gate: to realize more gate merging opportunity:
// See Fig. 5 in https://arxiv.org/pdf/1710.07345.pdf
// e.g. move Rz on the **control** qubit of CNOT gate back and forth
// which could help to combine with other gates.
class PermuteGatePattern
    : public mlir::OpRewritePattern<mlir::quantum::ValueSemanticsInstOp> {
public:
  PermuteGatePattern(mlir::MLIRContext *context)
      : OpRewritePattern<mlir::quantum::ValueSemanticsInstOp>(context,
                                                              /*benefit=*/10) {}
  mlir::LogicalResult
  matchAndRewrite(mlir::quantum::ValueSemanticsInstOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto inst_name = op.name();
    // Move "Rz" forward
    if (inst_name.str() == "rz") {
      auto return_value = *op.result().begin();
      if (return_value.hasOneUse()) {
        // get that one user
        auto user = *return_value.user_begin();
        // cast to a inst op
        if (auto next_inst =
                dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(user)) {
          if (next_inst.name() == "cx" || next_inst.name() == "cnot") {
            if (next_inst.getOperand(0) == op.result().front()) {
              // rz connect to control bit (operand 0)
              // Permute rz:
              rewriter.setInsertionPointAfter(next_inst);
              mlir::Value cx_ctrl_out = next_inst.result().front();
              auto new_rz_inst =
                  rewriter.create<mlir::quantum::ValueSemanticsInstOp>(
                      op.getLoc(),
                      llvm::makeArrayRef({op.getOperand(0).getType()}), "rz",
                      llvm::makeArrayRef(cx_ctrl_out),
                      llvm::makeArrayRef(op.getOperand(1)));
              
              // Input to original rz => cnot
              next_inst.getOperand(0).replaceAllUsesWith(op.getOperand(0));
              // First output of cx (control line) to output of the new rz
              // except the new rz which is connected to the output of cx
              cx_ctrl_out.replaceAllUsesExcept(
                  new_rz_inst.result().front(),
                  mlir::SmallPtrSet<Operation *, 1>{new_rz_inst});
              // Erase the rz gate
              rewriter.eraseOp(op);
              std::cout << "AFTER PERMUTE gates:\n";
              auto parentModule = op->getParentOfType<mlir::ModuleOp>();
              parentModule->dump();
              return success();
            }
          }
        }
      }
    }
using namespace mlir;

    return failure();
  }
namespace qcor {
struct PermuteGatePass
    : public PassWrapper<PermuteGatePass, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry &registry) const override;
  void runOnOperation() final;
  PermuteGatePass() {}
};
} // namespace qcor
 No newline at end of file
+18 −13
Original line number Diff line number Diff line
#include "RotationMergingPass.hpp"
#include "Quantum/QuantumOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -7,8 +9,6 @@
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include <iostream>

namespace qcor {
@@ -99,6 +99,7 @@ void RotationMergingPass::runOnOperation() {
            // both angles are constant: pre-compute the total angle:
            const double totalAngle =
                first_angle_const.value() + second_angle_const.value();
            if (std::abs(totalAngle) > ZERO_ROTATION_TOLERANCE) {
              mlir::Value totalAngleVal = rewriter.create<mlir::ConstantOp>(
                  op.getLoc(),
                  mlir::FloatAttr::get(rewriter.getF64Type(), totalAngle));
@@ -110,6 +111,10 @@ void RotationMergingPass::runOnOperation() {
              // Input -> Output mapping (this instruction is to be removed)
              (*next_inst.result_begin())
                  .replaceAllUsesWith(*new_inst.result_begin());
            } else {
              // Zero rotation: just map from input -> output
              (*next_inst.result_begin()).replaceAllUsesWith(op.getOperand(0));
            }
          } else {
            // Need to create an AddFOp
            auto add_op = rewriter.create<mlir::AddFOp>(
+2 −0
Original line number Diff line number Diff line
@@ -25,6 +25,8 @@ private:
  static inline const std::vector<std::pair<std::string, std::string>> search_gates{
      {"rx", "rx"}, {"ry", "ry"}, {"rz", "rz"},
      {"x", "rx"},  {"y", "ry"},  {"z", "rz"}};
  // Angle that we considered zero:
  static constexpr double ZERO_ROTATION_TOLERANCE = 1e-9;
  bool should_combine(const std::string &name1, const std::string &name2) const;
};
} // namespace qcor
 No newline at end of file
Loading