Commit f2fa6ad0 authored by Nikita Popov's avatar Nikita Popov
Browse files

[MergeICmps] Don't reorder unmerged comparisons

MergeICmps will currently sort (by offset) all comparisons in a chain,
including those that do not get merged. This is problematic in two ways:

 * We may end up moving the original first block into the middle of
   the chain, in which case the "extra work" instructions will also
   be in the middle of the chain, resulting in invalid IR
   (reported in https://reviews.llvm.org/D108782#3005583).
 * Reordering branches is generally not legal, because it may
   introduce branch on poison, which is UB (PR51845). The merging
   done by MergeICmps is legal as long as we assume that memcmp()
   works on frozen memory, but the reordering of unmerged comparisons
   is definitely incorrect (without inserting freeze instructions),
   so we should avoid it.

There are easier ways to fix the first issue, but I figured it was
worthwhile to do this properly to also fix the second one. What we
now do is to restore the original relative order of (potentially
merged) comparisons.

I took the liberty of dropping the MERGEICMPS_DOT_ON functionality,
because it would be more awkward to implement now (as the before and
after representation is different) and it doesn't seem terribly
useful nowadays.

Differential Revision: https://reviews.llvm.org/D110024
parent 40e971a2
Loading
Loading
Loading
Loading
+78 −94
Original line number Diff line number Diff line
@@ -229,6 +229,8 @@ class BCECmpBlock {
  InstructionSet BlockInsts;
  // The block requires splitting.
  bool RequireSplit = false;
  // Original order of this block in the chain.
  unsigned OrigOrder = 0;

private:
  BCECmp Cmp;
@@ -380,38 +382,82 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
                    << Comparison.Rhs().BaseId << " + "
                    << Comparison.Rhs().Offset << "\n");
  LLVM_DEBUG(dbgs() << "\n");
  Comparison.OrigOrder = Comparisons.size();
  Comparisons.push_back(std::move(Comparison));
}

// A chain of comparisons.
class BCECmpChain {
public:
  using ContiguousBlocks = std::vector<BCECmpBlock>;

  BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
              AliasAnalysis &AA);

   int size() const { return Comparisons_.size(); }

#ifdef MERGEICMPS_DOT_ON
  void dump() const;
#endif  // MERGEICMPS_DOT_ON

  bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
                DomTreeUpdater &DTU);

  bool atLeastOneMerged() const {
    return any_of(MergedBlocks_,
                  [](const auto &Blocks) { return Blocks.size() > 1; });
  }

private:
  static bool IsContiguous(const BCECmpBlock &First,
                           const BCECmpBlock &Second) {
  PHINode &Phi_;
  // The list of all blocks in the chain, grouped by contiguity.
  std::vector<ContiguousBlocks> MergedBlocks_;
  // The original entry block (before sorting);
  BasicBlock *EntryBlock_;
};

static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
  return First.Lhs().BaseId == Second.Lhs().BaseId &&
         First.Rhs().BaseId == Second.Rhs().BaseId &&
         First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset &&
         First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset;
}

  PHINode &Phi_;
  std::vector<BCECmpBlock> Comparisons_;
  // The original entry block (before sorting);
  BasicBlock *EntryBlock_;
};
static unsigned getMinOrigOrder(const BCECmpChain::ContiguousBlocks &Blocks) {
  unsigned MinOrigOrder = std::numeric_limits<unsigned>::max();
  for (const BCECmpBlock &Block : Blocks)
    MinOrigOrder = std::min(MinOrigOrder, Block.OrigOrder);
  return MinOrigOrder;
}

/// Given a chain of comparison blocks, groups the blocks into contiguous
/// ranges that can be merged together into a single comparison.
static std::vector<BCECmpChain::ContiguousBlocks>
mergeBlocks(std::vector<BCECmpBlock> &&Blocks) {
  std::vector<BCECmpChain::ContiguousBlocks> MergedBlocks;

  // Sort to detect continuous offsets.
  llvm::sort(Blocks,
             [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) {
               return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) <
                      std::tie(RhsBlock.Lhs(), RhsBlock.Rhs());
             });

  BCECmpChain::ContiguousBlocks *LastMergedBlock = nullptr;
  for (BCECmpBlock &Block : Blocks) {
    if (!LastMergedBlock || !areContiguous(LastMergedBlock->back(), Block)) {
      MergedBlocks.emplace_back();
      LastMergedBlock = &MergedBlocks.back();
    } else {
      LLVM_DEBUG(dbgs() << "Merging block " << Block.BB->getName() << " into "
                        << LastMergedBlock->back().BB->getName() << "\n");
    }
    LastMergedBlock->push_back(std::move(Block));
  }

  // While we allow reordering for merging, do not reorder unmerged comparisons.
  // Doing so may introduce branch on poison.
  llvm::sort(MergedBlocks, [](const BCECmpChain::ContiguousBlocks &LhsBlocks,
                              const BCECmpChain::ContiguousBlocks &RhsBlocks) {
    return getMinOrigOrder(LhsBlocks) < getMinOrigOrder(RhsBlocks);
  });

  return MergedBlocks;
}

BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
                         AliasAnalysis &AA)
@@ -492,46 +538,8 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
    return;
  }
  EntryBlock_ = Comparisons[0].BB;
  Comparisons_ = std::move(Comparisons);
#ifdef MERGEICMPS_DOT_ON
  errs() << "BEFORE REORDERING:\n\n";
  dump();
#endif  // MERGEICMPS_DOT_ON
  // Reorder blocks by LHS. We can do that without changing the
  // semantics because we are only accessing dereferencable memory.
  llvm::sort(Comparisons_,
             [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) {
               return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) <
                      std::tie(RhsBlock.Lhs(), RhsBlock.Rhs());
             });
#ifdef MERGEICMPS_DOT_ON
  errs() << "AFTER REORDERING:\n\n";
  dump();
#endif  // MERGEICMPS_DOT_ON
}

#ifdef MERGEICMPS_DOT_ON
void BCECmpChain::dump() const {
  errs() << "digraph dag {\n";
  errs() << " graph [bgcolor=transparent];\n";
  errs() << " node [color=black,style=filled,fillcolor=lightyellow];\n";
  errs() << " edge [color=black];\n";
  for (size_t I = 0; I < Comparisons_.size(); ++I) {
    const auto &Comparison = Comparisons_[I];
    errs() << " \"" << I << "\" [label=\"%"
           << Comparison.Lhs().Base()->getName() << " + "
           << Comparison.Lhs().Offset << " == %"
           << Comparison.Rhs().Base()->getName() << " + "
           << Comparison.Rhs().Offset << " (" << (Comparison.SizeBits() / 8)
           << " bytes)\"];\n";
    const Value *const Val = Phi_.getIncomingValueForBlock(Comparison.BB);
    if (I > 0) errs() << " \"" << (I - 1) << "\" -> \"" << I << "\";\n";
    errs() << " \"" << I << "\" -> \"Phi\" [label=\"" << *Val << "\"];\n";
  }
  errs() << " \"Phi\" [label=\"Phi\"];\n";
  errs() << "}\n\n";
}
#endif  // MERGEICMPS_DOT_ON
  MergedBlocks_ = mergeBlocks(std::move(Comparisons));
}

namespace {

@@ -655,47 +663,20 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,

bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
                           DomTreeUpdater &DTU) {
  assert(Comparisons_.size() >= 2 && "simplifying trivial BCECmpChain");
  // First pass to check if there is at least one merge. If not, we don't do
  // anything and we keep analysis passes intact.
  const auto AtLeastOneMerged = [this]() {
    for (size_t I = 1; I < Comparisons_.size(); ++I) {
      if (IsContiguous(Comparisons_[I - 1], Comparisons_[I]))
        return true;
    }
    return false;
  };
  if (!AtLeastOneMerged())
    return false;

  assert(atLeastOneMerged() && "simplifying trivial BCECmpChain");
  LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block "
                    << EntryBlock_->getName() << "\n");

  // Effectively merge blocks. We go in the reverse direction from the phi block
  // so that the next block is always available to branch to.
  const auto mergeRange = [this, &TLI, &AA, &DTU](int I, int Num,
                                                  BasicBlock *InsertBefore,
                                                  BasicBlock *Next) {
    return mergeComparisons(makeArrayRef(Comparisons_).slice(I, Num),
                            InsertBefore, Next, Phi_, TLI, AA, DTU);
  };
  int NumMerged = 1;
  BasicBlock *InsertBefore = EntryBlock_;
  BasicBlock *NextCmpBlock = Phi_.getParent();
  for (int I = static_cast<int>(Comparisons_.size()) - 2; I >= 0; --I) {
    if (IsContiguous(Comparisons_[I], Comparisons_[I + 1])) {
      LLVM_DEBUG(dbgs() << "Merging block " << Comparisons_[I].BB->getName()
                        << " into " << Comparisons_[I + 1].BB->getName()
                        << "\n");
      ++NumMerged;
    } else {
      NextCmpBlock = mergeRange(I + 1, NumMerged, NextCmpBlock, NextCmpBlock);
      NumMerged = 1;
    }
  for (const auto &Blocks : reverse(MergedBlocks_)) {
    NumMerged += Blocks.size() - 1;
    InsertBefore = NextCmpBlock = mergeComparisons(
        Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
  }
  // Insert the entry block for the new chain before the old entry block.
  // If the old entry block was the function entry, this ensures that the new
  // entry can become the function entry.
  NextCmpBlock = mergeRange(0, NumMerged, EntryBlock_, NextCmpBlock);

  // Replace the original cmp chain with the new cmp chain by pointing all
  // predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp
@@ -723,13 +704,16 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,

  // Delete merged blocks. This also removes incoming values in phi.
  SmallVector<BasicBlock *, 16> DeadBlocks;
  for (auto &Cmp : Comparisons_) {
    LLVM_DEBUG(dbgs() << "Deleting merged block " << Cmp.BB->getName() << "\n");
    DeadBlocks.push_back(Cmp.BB);
  for (const auto &Blocks : MergedBlocks_) {
    for (const BCECmpBlock &Block : Blocks) {
      LLVM_DEBUG(dbgs() << "Deleting merged block " << Block.BB->getName()
                        << "\n");
      DeadBlocks.push_back(Block.BB);
    }
  }
  DeleteDeadBlocks(DeadBlocks, &DTU);

  Comparisons_.clear();
  MergedBlocks_.clear();
  return true;
}

@@ -829,8 +813,8 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA,
  if (Blocks.empty()) return false;
  BCECmpChain CmpChain(Blocks, Phi, AA);

  if (CmpChain.size() < 2) {
    LLVM_DEBUG(dbgs() << "skip: only one compare block\n");
  if (!CmpChain.atLeastOneMerged()) {
    LLVM_DEBUG(dbgs() << "skip: nothing merged\n");
    return false;
  }

+70 −0
Original line number Diff line number Diff line
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -S -mergeicmps < %s | FileCheck %s

target triple = "x86_64-unknown-linux-gnu"

%"struct.a::c" = type { i32, i32*, i8* }

; The entry block cannot be merged as the comparison is not continuous.
; While it compares the highest address, it should not be moved after the
; other comparisons, as that would make the allocas non-dominating.

define i1 @test() {
; CHECK-LABEL: @test(
; CHECK-NEXT:  "land.lhs.true+entry":
; CHECK-NEXT:    [[H:%.*]] = alloca %"struct.a::c", align 8
; CHECK-NEXT:    [[I:%.*]] = alloca %"struct.a::c", align 8
; CHECK-NEXT:    call void @init(%"struct.a::c"* [[H]])
; CHECK-NEXT:    call void @init(%"struct.a::c"* [[I]])
; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[H]], i64 0, i32 1
; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[I]], i64 0, i32 1
; CHECK-NEXT:    [[CSTR:%.*]] = bitcast i32** [[TMP0]] to i8*
; CHECK-NEXT:    [[CSTR2:%.*]] = bitcast i32** [[TMP1]] to i8*
; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR2]], i64 16)
; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i32 [[MEMCMP]], 0
; CHECK-NEXT:    br i1 [[TMP2]], label [[LAND_RHS1:%.*]], label [[LAND_END:%.*]]
; CHECK:       land.rhs1:
; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[H]], i64 0, i32 0
; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[I]], i64 0, i32 0
; CHECK-NEXT:    [[TMP5:%.*]] = load i32, i32* [[TMP3]], align 4
; CHECK-NEXT:    [[TMP6:%.*]] = load i32, i32* [[TMP4]], align 4
; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[TMP5]], [[TMP6]]
; CHECK-NEXT:    br label [[LAND_END]]
; CHECK:       land.end:
; CHECK-NEXT:    [[V9:%.*]] = phi i1 [ [[TMP7]], [[LAND_RHS1]] ], [ false, %"land.lhs.true+entry" ]
; CHECK-NEXT:    ret i1 [[V9]]
;
entry:
  %h = alloca %"struct.a::c", align 8
  %i = alloca %"struct.a::c", align 8
  call void @init(%"struct.a::c"* %h)
  call void @init(%"struct.a::c"* %i)
  %e = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 2
  %v3 = load i8*, i8** %e, align 8
  %e2 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 2
  %v4 = load i8*, i8** %e2, align 8
  %cmp = icmp eq i8* %v3, %v4
  br i1 %cmp, label %land.lhs.true, label %land.end

land.lhs.true:                                    ; preds = %entry
  %d = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 1
  %v5 = load i32*, i32** %d, align 8
  %d3 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 1
  %v6 = load i32*, i32** %d3, align 8
  %cmp4 = icmp eq i32* %v5, %v6
  br i1 %cmp4, label %land.rhs, label %land.end

land.rhs:                                         ; preds = %land.lhs.true
  %j = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 0
  %v7 = load i32, i32* %j, align 8
  %j5 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 0
  %v8 = load i32, i32* %j5, align 8
  %cmp6 = icmp eq i32 %v7, %v8
  br label %land.end

land.end:                                         ; preds = %land.rhs, %land.lhs.true, %entry
  %v9 = phi i1 [ false, %land.lhs.true ], [ false, %entry ], [ %cmp6, %land.rhs ]
  ret i1 %v9
}

declare void @init(%"struct.a::c"*)
+15 −15
Original line number Diff line number Diff line
@@ -9,20 +9,20 @@

define zeroext i1 @opeq1(
; CHECK-LABEL: @opeq1(
; CHECK-NEXT:  "land.rhs.i+land.rhs.i.2":
; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [[S:%.*]], %S* [[A:%.*]], i64 0, i32 0
; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [[S]], %S* [[B:%.*]], i64 0, i32 0
; CHECK-NEXT:    [[CSTR:%.*]] = bitcast i32* [[TMP0]] to i8*
; CHECK-NEXT:    [[CSTR3:%.*]] = bitcast i32* [[TMP1]] to i8*
; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR3]], i64 8)
; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i32 [[MEMCMP]], 0
; CHECK-NEXT:    br i1 [[TMP2]], label [[ENTRY2:%.*]], label [[OPEQ1_EXIT:%.*]]
; CHECK:       entry2:
; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 3
; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[S]], %S* [[B]], i64 0, i32 2
; CHECK-NEXT:    [[TMP5:%.*]] = load i32, i32* [[TMP3]], align 4
; CHECK-NEXT:    [[TMP6:%.*]] = load i32, i32* [[TMP4]], align 4
; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[TMP5]], [[TMP6]]
; CHECK-NEXT:  entry3:
; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [[S:%.*]], %S* [[A:%.*]], i64 0, i32 3
; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [[S]], %S* [[B:%.*]], i64 0, i32 2
; CHECK-NEXT:    [[TMP2:%.*]] = load i32, i32* [[TMP0]], align 4
; CHECK-NEXT:    [[TMP3:%.*]] = load i32, i32* [[TMP1]], align 4
; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP2]], [[TMP3]]
; CHECK-NEXT:    br i1 [[TMP4]], label %"land.rhs.i+land.rhs.i.2", label [[OPEQ1_EXIT:%.*]]
; CHECK:       "land.rhs.i+land.rhs.i.2":
; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 0
; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[S]], %S* [[B]], i64 0, i32 0
; CHECK-NEXT:    [[CSTR:%.*]] = bitcast i32* [[TMP5]] to i8*
; CHECK-NEXT:    [[CSTR2:%.*]] = bitcast i32* [[TMP6]] to i8*
; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR2]], i64 8)
; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[MEMCMP]], 0
; CHECK-NEXT:    br i1 [[TMP7]], label [[LAND_RHS_I_31:%.*]], label [[OPEQ1_EXIT]]
; CHECK:       land.rhs.i.31:
; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 3
@@ -32,7 +32,7 @@ define zeroext i1 @opeq1(
; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[TMP10]], [[TMP11]]
; CHECK-NEXT:    br label [[OPEQ1_EXIT]]
; CHECK:       opeq1.exit:
; CHECK-NEXT:    [[TMP13:%.*]] = phi i1 [ [[TMP12]], [[LAND_RHS_I_31]] ], [ false, [[ENTRY2]] ], [ false, %"land.rhs.i+land.rhs.i.2" ]
; CHECK-NEXT:    [[TMP13:%.*]] = phi i1 [ [[TMP12]], [[LAND_RHS_I_31]] ], [ false, %"land.rhs.i+land.rhs.i.2" ], [ false, [[ENTRY3:%.*]] ]
; CHECK-NEXT:    ret i1 [[TMP13]]
;
  %S* nocapture readonly dereferenceable(16) %a,