Commit 02baa500 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

more proper way to do op replacement during lowering



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent dc79abd1
Loading
Loading
Loading
Loading
+25 −30
Original line number Diff line number Diff line
@@ -75,13 +75,6 @@ LogicalResult PowURegionOpLowering::matchAndRewrite(
  mlir::Block &powBlock = casted.body().getBlocks().front();
  for (auto &subOp : powBlock.getOperations()) {
    rewriter.insert(subOp.clone());
    if (mlir::dyn_cast_or_null<mlir::quantum::ModifierEndOp>(&subOp)) {
      mlir::quantum::ModifierEndOp terminator = mlir::cast<mlir::quantum::ModifierEndOp>(&subOp);
      assert(terminator.qubits().size() == casted.qubits().size());
      for (size_t i = 0; i < terminator.qubits().size(); ++i) {
        casted.result()[i].replaceAllUsesWith(casted.qubits()[i]);
      }
    }
  }

  // End
@@ -113,7 +106,14 @@ LogicalResult PowURegionOpLowering::matchAndRewrite(
                                  qir_operands);
  }

  rewriter.eraseOp(op);
  {
    mlir::SmallVector<mlir::Value> chained_values;
    for (const auto &targetQubit : casted.qubits()) {
      chained_values.push_back(targetQubit);
    }
    rewriter.replaceOp(op, chained_values);
  }

  return success();
}

@@ -154,21 +154,11 @@ LogicalResult CtrlURegionOpLowering::matchAndRewrite(
  }

  // Inline
  {
  auto casted = cast<mlir::quantum::CtrlURegion>(op);
  {
    mlir::Block &ctrlBlock = casted.body().getBlocks().front();
    for (auto &subOp : ctrlBlock.getOperations()) {
      rewriter.insert(subOp.clone());
      if (mlir::dyn_cast_or_null<mlir::quantum::ModifierEndOp>(&subOp)) {
        mlir::quantum::ModifierEndOp terminator =
            mlir::cast<mlir::quantum::ModifierEndOp>(&subOp);
        assert(terminator.qubits().size() == casted.qubits().size() + 1);
        assert(terminator.qubits().size() == casted.result().size());
        casted.result()[0].replaceAllUsesWith(casted.ctrl_qubit());
        for (size_t i = 1; i < terminator.qubits().size(); ++i) {
          casted.result()[i].replaceAllUsesWith(casted.qubits()[i - 1]);
        }
      }
    }
  }
  // End
@@ -208,7 +198,13 @@ LogicalResult CtrlURegionOpLowering::matchAndRewrite(
        llvm::makeArrayRef(std::vector<mlir::Value>{ctrl_bit}));
  }

  rewriter.eraseOp(op);
  {
    mlir::SmallVector<mlir::Value> chained_values{casted.ctrl_qubit()};
    for (const auto &targetQubit : casted.qubits()) {
      chained_values.push_back(targetQubit);
    }
    rewriter.replaceOp(op, chained_values);
  }
  return success();
}

@@ -253,14 +249,6 @@ LogicalResult AdjURegionOpLowering::matchAndRewrite(
    mlir::Block &adjBlock = casted.body().getBlocks().front();
    for (auto &subOp : adjBlock.getOperations()) {
      rewriter.insert(subOp.clone());
      if (mlir::dyn_cast_or_null<mlir::quantum::ModifierEndOp>(&subOp)) {
        mlir::quantum::ModifierEndOp terminator =
            mlir::cast<mlir::quantum::ModifierEndOp>(&subOp);
        assert(terminator.qubits().size() == casted.qubits().size());
        for (size_t i = 0; i < terminator.qubits().size(); ++i) {
          casted.result()[i].replaceAllUsesWith(casted.qubits()[i]);
        }
      }
    }
  }
  // End
@@ -290,7 +278,14 @@ LogicalResult AdjURegionOpLowering::matchAndRewrite(
        llvm::makeArrayRef(std::vector<mlir::Value>{}));
  }

  rewriter.eraseOp(op);
  {
    mlir::SmallVector<mlir::Value> chained_values;
    for (const auto &targetQubit :
         cast<mlir::quantum::AdjURegion>(op).qubits()) {
      chained_values.push_back(targetQubit);
    }
    rewriter.replaceOp(op, chained_values);
  }

  return success();
}