Unverified Commit f054947c authored by Valery Pykhtin's avatar Valery Pykhtin Committed by GitHub
Browse files

[SimplifyCFG] Prevent merging cbranch to cbranch if the branch probability...

[SimplifyCFG] Prevent merging cbranch to cbranch if the branch probability from the first to second is too low. (#69375)

The AMDGPU target has run into a situation that can be illustrated with the
following testcase:

define void @dont_merge_cbranches(i32 %V) {
  %divergent_cond = icmp ne i32 %V, 0
  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
  br i1 %uniform_cond, label %bb2, label %exit, !prof !0
bb2:
  br i1 %divergent_cond, label %bb3, label %exit
bb3:
  call void @bar( )
  br label %exit
exit:
  ret void
}
!0 = !{!"branch_weights", i32 1, i32 100000}

SimplifyCFG merges the branches on %uniform_cond and %divergent_cond, which is undesirable because the first branch to bb2 is taken extremely rarely and the second branch is expensive. The merged branch becomes as expensive as the second.

This patch prevents such merging when the first branch is unlikely to reach the block that contains the second branch.
parent dde85f86
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -4347,6 +4347,20 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
  if (PBI->getSuccessor(PBIOp) == BB)
    return false;

  // If predecessor's branch probability to BB is too low don't merge branches.
  SmallVector<uint32_t, 2> PredWeights;
  if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
      extractBranchWeights(*PBI, PredWeights) &&
      (PredWeights[0] + PredWeights[1]) != 0) {

    BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
        PredWeights[PBIOp], PredWeights[0] + PredWeights[1]);

    BranchProbability Likely = TTI.getPredictableBranchThreshold();
    if (CommonDestProb >= Likely)
      return false;
  }

  // Do not perform this transformation if it would require
  // insertion of a large number of select instructions. For targets
  // without predication/cmovs, this is a big pessimization.
+84 −0
Original line number Diff line number Diff line
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=simplifycfg -S | FileCheck %s

declare void @bar()
declare i1 @uniform_result(i1 %c)

; The first branch goes to %bb2 (true successor) with weight 1 and to %exit
; with weight 1000 (see !0), so %exit -- the destination shared by both
; conditional branches -- is by far its likeliest target. The assertions
; below expect the two conditional branches to be left unmerged.
define void @dont_merge_cbranches1(i32 %V) {
; CHECK-LABEL: @dont_merge_cbranches1(
; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
; CHECK-NEXT:    br i1 [[UNIFORM_COND]], label [[BB2:%.*]], label [[EXIT:%.*]], !prof [[PROF0:![0-9]+]]
; CHECK:       bb2:
; CHECK-NEXT:    br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
; CHECK:       bb3:
; CHECK-NEXT:    call void @bar()
; CHECK-NEXT:    br label [[EXIT]]
; CHECK:       exit:
; CHECK-NEXT:    ret void
;
  %divergent_cond = icmp ne i32 %V, 0
  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
  br i1 %uniform_cond, label %bb2, label %exit, !prof !0
bb2:
  br i1 %divergent_cond, label %bb3, label %exit
bb3:
  call void @bar( )
  br label %exit
exit:
  ret void
}

; Same control flow as @dont_merge_cbranches1 but with the successors of the
; first branch swapped: %exit is the true successor with weight 1000 and %bb2
; is the false successor with weight 1 (see !1). The assertions below again
; expect the two conditional branches to be left unmerged.
define void @dont_merge_cbranches2(i32 %V) {
; CHECK-LABEL: @dont_merge_cbranches2(
; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
; CHECK-NEXT:    br i1 [[UNIFORM_COND]], label [[EXIT:%.*]], label [[BB2:%.*]], !prof [[PROF1:![0-9]+]]
; CHECK:       bb2:
; CHECK-NEXT:    br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
; CHECK:       bb3:
; CHECK-NEXT:    call void @bar()
; CHECK-NEXT:    br label [[EXIT]]
; CHECK:       exit:
; CHECK-NEXT:    ret void
;
  %divergent_cond = icmp ne i32 %V, 0
  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
  br i1 %uniform_cond, label %exit, label %bb2, !prof !1
bb2:
  br i1 %divergent_cond, label %bb3, label %exit
bb3:
  call void @bar( )
  br label %exit
exit:
  ret void
}

; Positive case: the first branch has weights 3 (to %exit) and 2 (to %bb2),
; see !2, so neither successor is overwhelmingly likely. The assertions below
; expect the two conditional branches to be folded into a single branch on a
; select of %uniform_cond and the inverted %divergent_cond, with %bb2 removed.
define void @merge_cbranches(i32 %V) {
; CHECK-LABEL: @merge_cbranches(
; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
; CHECK-NEXT:    [[DIVERGENT_COND_NOT:%.*]] = xor i1 [[DIVERGENT_COND]], true
; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[UNIFORM_COND]], i1 true, i1 [[DIVERGENT_COND_NOT]]
; CHECK-NEXT:    br i1 [[BRMERGE]], label [[EXIT:%.*]], label [[BB3:%.*]], !prof [[PROF2:![0-9]+]]
; CHECK:       bb3:
; CHECK-NEXT:    call void @bar()
; CHECK-NEXT:    br label [[EXIT]]
; CHECK:       exit:
; CHECK-NEXT:    ret void
;
  %divergent_cond = icmp ne i32 %V, 0
  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
  br i1 %uniform_cond, label %exit, label %bb2, !prof !2
bb2:
  br i1 %divergent_cond, label %bb3, label %exit
bb3:
  call void @bar( )
  br label %exit
exit:
  ret void
}

!0 = !{!"branch_weights", i32 1, i32 1000}
!1 = !{!"branch_weights", i32 1000, i32 1}
!2 = !{!"branch_weights", i32 3, i32 2}