Unverified Commit 4f40fe10 authored by vporpo's avatar vporpo Committed by GitHub
Browse files

[SandboxIR][Tracker] Support nested checkpoints (#191097)

This patch implements nested checkpointing, i.e., you can now save the
IR state more than once and revert more than once.
For example, after two saves: save(1) and save(2), a revert() will bring
you back to the IR state of save(2), one more revert will bring you back
to the IR state of save(1).
parent 6e3ab87e
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -105,5 +105,10 @@ Internally this will go through the changes and run any finalization required.
Please note that after a call to `revert()` or `accept()` tracking will stop.
To start tracking again, the user needs to call `save()`.


Sandbox IR supports nested checkpoints, meaning that you can save more than once and revert more than once.
Conceptually each `save()` adds a new checkpoint to a stack and each `revert()` rolls back the IR state to that of the checkpoint at the top of the stack and pops the checkpoint off the stack.
A call to `accept()` will clear the stack.

## Users of Sandbox IR
- [The Sandbox Vectorizer](project:SandboxVectorizer.md)
+8 −9
Original line number Diff line number Diff line
@@ -452,10 +452,14 @@ private:
  SmallVector<std::unique_ptr<IRChangeBase>> Changes;
  /// The current state of the tracker.
  TrackerState State = TrackerState::Disabled;
  /// Nested snapshots require us to track the index of each snapshot in the
  /// `Changes` vector.
  SmallVector<unsigned, 8> Snapshots;
  Context &Ctx;

#ifndef NDEBUG
  IRSnapshotChecker SnapshotChecker;
  /// One checker per nested snapshot.
  SmallVector<IRSnapshotChecker> SnapshotChecker;
#endif

public:
@@ -465,14 +469,7 @@ public:
  bool InMiddleOfCreatingChange = false;
#endif // NDEBUG

  explicit Tracker(Context &Ctx)
      : Ctx(Ctx)
#ifndef NDEBUG
        ,
        SnapshotChecker(Ctx)
#endif
  {
  }
  explicit Tracker(Context &Ctx) : Ctx(Ctx) {}

  LLVM_ABI ~Tracker();
  Context &getContext() const { return Ctx; }
@@ -513,6 +510,8 @@ public:
  LLVM_ABI void accept();
  /// Stops tracking and reverts to saved state.
  LLVM_ABI void revert();
  /// \returns the number of nested (outstanding) checkpoints.
  unsigned nestingDepth() const { return Snapshots.size(); }

#ifndef NDEBUG
  void dump(raw_ostream &OS) const;
+23 −5
Original line number Diff line number Diff line
@@ -330,21 +330,32 @@ void CmpSwapOperands::dump() const {

void Tracker::save() {
  State = TrackerState::Record;
  // Record the last index in `Changes` that we will revert.
  Snapshots.push_back(Changes.size());
#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
  SnapshotChecker.save();
  SnapshotChecker.emplace_back(Ctx);
  SnapshotChecker.back().save();
#endif
}

void Tracker::revert() {
  assert(State == TrackerState::Record && "Forgot to save()!");
  State = TrackerState::Reverting;
  for (auto &Change : reverse(Changes))
  const unsigned ToRevert = Changes.size() - Snapshots.back();
  unsigned CntReverts = 0;
  for (auto &Change : reverse(Changes)) {
    // Stop reverting if we reach the index of the last snapshot.
    if (CntReverts++ == ToRevert)
      break;
    Change->revert(*this);
  Changes.clear();
  }
  Changes.erase(Changes.end() - ToRevert, Changes.end());
  Snapshots.pop_back();
#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
  SnapshotChecker.expectNoDiff();
  SnapshotChecker.back().expectNoDiff();
  SnapshotChecker.pop_back();
#endif
  State = TrackerState::Disabled;
  State = Snapshots.empty() ? TrackerState::Disabled : TrackerState::Record;
}

void Tracker::accept() {
@@ -353,13 +364,20 @@ void Tracker::accept() {
  for (auto &Change : Changes)
    Change->accept();
  Changes.clear();
  Snapshots.clear();
#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
  SnapshotChecker.clear();
#endif
}

#ifndef NDEBUG
void Tracker::dump(raw_ostream &OS) const {
  unsigned SnapshotCnt = 0;
  for (auto [Idx, ChangePtr] : enumerate(Changes)) {
    OS << Idx << ". ";
    ChangePtr->dump(OS);
    if (find(Snapshots, Idx) != Snapshots.end())
      OS << " [Snapshot " << SnapshotCnt++ << "]";
    OS << "\n";
  }
}
+67 −7
Original line number Diff line number Diff line
@@ -1970,11 +1970,11 @@ define i32 @foo(i32 %arg) {
  EXPECT_DEATH(Checker.expectNoDiff(), "Found IR difference");
}

TEST_F(TrackerTest, IRSnapshotCheckerSaveMultipleTimes) {
TEST_F(TrackerTest, NestedCheckpoints) {
  parseIR(C, R"IR(
define i32 @foo(i32 %arg) {
  %add0 = add i32 %arg, %arg
  %add1 = add i32 %add0, %arg
define i32 @foo(i32 %arg0, i32 %arg1) {
  %add0 = add i32 %arg0, %arg0
  %add1 = add i32 %add0, %add0
  ret i32 %add1
}
)IR");
@@ -1984,14 +1984,74 @@ define i32 @foo(i32 %arg) {
  auto *F = Ctx.createFunction(&LLVMF);
  auto *BB = &*F->begin();
  auto It = BB->begin();
  sandboxir::Argument *Arg0 = F->getArg(0);
  sandboxir::Argument *Arg1 = F->getArg(1);
  sandboxir::Instruction *Add0 = &*It++;
  sandboxir::Instruction *Add1 = &*It++;
  sandboxir::IRSnapshotChecker Checker(Ctx);
  Checker.save();
  Add1->setOperand(1, Add0);
  // Now IR differs from the last snapshot. Let's take a new snapshot.
  Ctx.save();
  // Check that revert() works even with no changes
  Ctx.revert();

  // Check multiple save(), revert().
  Ctx.save();
  Ctx.save();
  Ctx.revert();
  Ctx.revert();
  Checker.expectNoDiff();

  // Check nested checkpoint: save,save,revert,revert.
  EXPECT_EQ(Add1->getOperand(0), Add0);
  Ctx.save();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 1u);
  Add1->setOperand(0, Arg0);
  Ctx.save();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 2u);
  Add1->setOperand(0, Arg1);
  Ctx.revert();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 1u);
  EXPECT_EQ(Add1->getOperand(0), Arg0);
  Ctx.revert();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 0u);
  EXPECT_EQ(Add1->getOperand(0), Add0);

  Checker.expectNoDiff();

  // Check nested checkpoint: save,revert,save,revert
  EXPECT_EQ(Add1->getOperand(0), Add0);
  Ctx.save();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 1u);
  Add1->setOperand(0, Arg0);
  Ctx.revert();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 0u);
  EXPECT_EQ(Add1->getOperand(0), Add0);
  Checker.expectNoDiff();

  Ctx.save();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 1u);
  Add1->setOperand(0, Arg1);
  Ctx.revert();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 0u);
  EXPECT_EQ(Add1->getOperand(0), Add0);
  Checker.expectNoDiff();

  // Check nested checkpoint: save,accept,save,revert
  EXPECT_EQ(Add1->getOperand(0), Add0);
  Ctx.save();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 1u);
  Add1->setOperand(0, Arg0);
  Ctx.accept();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 0u);
  EXPECT_EQ(Add1->getOperand(0), Arg0);

  Checker.save();
  // The new snapshot should have replaced the old one, so this should succeed.
  Ctx.save();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 1u);
  Add1->setOperand(0, Arg1);
  Ctx.revert();
  EXPECT_EQ(Ctx.getTracker().nestingDepth(), 0u);
  EXPECT_EQ(Add1->getOperand(0), Arg0);
  Checker.expectNoDiff();
}