Unverified Commit 1e39575a authored by Min-Yih Hsu's avatar Min-Yih Hsu Committed by GitHub
Browse files

[RISCV] CSE by swapping conditional branches (#71111)

DAGCombiner, as well as InstCombine, tend to canonicalize GE/LE into
GT/LT, namely:
```
X >= C --> X > (C - 1)
```
This sometimes generates off-by-one constants that could otherwise have been
CSE'd with surrounding constants.
Instead of changing such canonicalization, this patch tries to swap
those branch conditions post-isel, in the hope of resurfacing more
constant CSE opportunities. More specifically, it performs the following
optimization:

For two constants C0 and C1 from
```
li Y, C0
li Z, C1
```
To remove the redundant `li Y, C0`,
 1. if C1 = C0 + 1 we can turn: 
    (a) blt Y, X -> bge X, Z
    (b) bge Y, X -> blt X, Z
 2. if C1 = C0 - 1 we can turn: 
    (a) blt X, Y -> bge Z, X
    (b) bge X, Y -> blt Z, X

This optimization will be done by PeepholeOptimizer through
RISCVInstrInfo::optimizeCondBranch.
parent 015c06ad
Loading
Loading
Loading
Loading
+119 −0
Original line number Diff line number Diff line
@@ -1159,6 +1159,125 @@ bool RISCVInstrInfo::reverseBranchCondition(
  return false;
}

/// Try to rewrite the conditional branch \p MI so that it compares against a
/// constant that is already materialized earlier in the same basic block,
/// so that the `li` feeding only this branch loses its last use.
///
/// For two constants C0 and C1 from
/// ```
/// li Y, C0
/// li Z, C1
/// ```
/// 1. if C1 = C0 + 1 we can turn:
///     (a) blt Y, X -> bge X, Z
///     (b) bge Y, X -> blt X, Z
/// 2. if C1 = C0 - 1 we can turn:
///     (a) blt X, Y -> bge Z, X
///     (b) bge X, Y -> blt Z, X
/// The same rewrites apply to the unsigned branches, provided that C0 +/- 1
/// does not wrap around in the unsigned domain.
///
/// \returns true if \p MI was replaced by a newly built branch and erased,
/// false if nothing was changed.
bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
  MachineBasicBlock *MBB = MI.getParent();
  MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();

  MachineBasicBlock *TBB, *FBB;
  SmallVector<MachineOperand, 3> Cond;
  if (analyzeBranch(*MBB, TBB, FBB, Cond, /*AllowModify=*/false))
    return false;
  (void)FBB;

  RISCVCC::CondCode CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
  assert(CC != RISCVCC::COND_INVALID);

  // Equality comparisons have no off-by-one equivalent.
  if (CC == RISCVCC::COND_EQ || CC == RISCVCC::COND_NE)
    return false;

  // Only LT/GE/LTU/GEU remain at this point.
  const bool IsSigned = CC == RISCVCC::COND_LT || CC == RISCVCC::COND_GE;

  // To make sure this optimization is really beneficial, we only
  // optimize for cases where Y had only one use (i.e. only used by the branch).

  // Right now we only care about LI (i.e. ADDI x0, imm). The third operand
  // must be a plain immediate; anything else is not a constant we can reason
  // about here.
  auto isLoadImm = [](const MachineInstr *MI, int64_t &Imm) -> bool {
    if (MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
        MI->getOperand(1).getReg() == RISCV::X0 &&
        MI->getOperand(2).isImm()) {
      Imm = MI->getOperand(2).getImm();
      return true;
    }
    return false;
  };
  // Either a load from immediate instruction or X0.
  auto isFromLoadImm = [&](const MachineOperand &Op, int64_t &Imm) -> bool {
    if (!Op.isReg())
      return false;
    Register Reg = Op.getReg();
    if (Reg == RISCV::X0) {
      Imm = 0;
      return true;
    }
    if (!Reg.isVirtual())
      return false;
    return isLoadImm(MRI.getVRegDef(Reg), Imm);
  };

  MachineOperand &LHS = MI.getOperand(0);
  MachineOperand &RHS = MI.getOperand(1);
  // Try to find the register for constant Z by walking backwards from the
  // branch to the top of the block; return invalid register otherwise.
  // Only virtual registers are accepted: the rewritten branch adds a use of Z
  // further down the block, and a physical register is not known to still
  // hold the constant at that point.
  auto searchConst = [&](int64_t C1) -> Register {
    MachineBasicBlock::reverse_iterator II(&MI), E = MBB->rend();
    auto DefC1 = std::find_if(++II, E, [&](const MachineInstr &I) -> bool {
      int64_t Imm;
      return isLoadImm(&I, Imm) && Imm == C1 &&
             I.getOperand(0).getReg().isVirtual();
    });
    if (DefC1 != E)
      return DefC1->getOperand(0).getReg();

    return Register();
  };

  bool Modify = false;
  int64_t C0;
  if (isFromLoadImm(LHS, C0) && MRI.hasOneUse(LHS.getReg())) {
    // Might be case 1.
    // Signed integer overflow is UB, so never form INT64_MAX + 1.
    // For unsigned branches the immediate is the sign-extended bit pattern,
    // so C0 == -1 is the largest unsigned value and C0 + 1 would wrap to 0,
    // turning a never-taken `bltu Y, X` into an always-taken `bgeu X, Z`.
    if (C0 < INT64_MAX && (IsSigned || C0 != -1))
      if (Register RegZ = searchConst(C0 + 1)) {
        reverseBranchCondition(Cond);
        Cond[1] = MachineOperand::CreateReg(RHS.getReg(), /*isDef=*/false);
        Cond[2] = MachineOperand::CreateReg(RegZ, /*isDef=*/false);
        // We might extend the live range of Z, clear its kill flag to
        // account for this.
        MRI.clearKillFlags(RegZ);
        Modify = true;
      }
  } else if (isFromLoadImm(RHS, C0) && MRI.hasOneUse(RHS.getReg())) {
    // Might be case 2.
    // For unsigned cases, we don't want C1 to wrap back to UINT64_MAX
    // when C0 is zero. For signed cases, never form INT64_MIN - 1.
    if (IsSigned ? C0 > INT64_MIN : C0 != 0)
      if (Register RegZ = searchConst(C0 - 1)) {
        reverseBranchCondition(Cond);
        Cond[1] = MachineOperand::CreateReg(RegZ, /*isDef=*/false);
        Cond[2] = MachineOperand::CreateReg(LHS.getReg(), /*isDef=*/false);
        // We might extend the live range of Z, clear its kill flag to
        // account for this.
        MRI.clearKillFlags(RegZ);
        Modify = true;
      }
  }

  if (!Modify)
    return false;

  // Build the new branch and remove the old one.
  BuildMI(*MBB, MI, MI.getDebugLoc(),
          getBrCond(static_cast<RISCVCC::CondCode>(Cond[0].getImm())))
      .add(Cond[1])
      .add(Cond[2])
      .addMBB(TBB);
  MI.eraseFromParent();

  return true;
}

MachineBasicBlock *
RISCVInstrInfo::getBranchDestBlock(const MachineInstr &MI) const {
  assert(MI.getDesc().isBranch() && "Unexpected opcode!");
+2 −0
Original line number Diff line number Diff line
@@ -121,6 +121,8 @@ public:
  bool
  reverseBranchCondition(SmallVectorImpl<MachineOperand> &Cond) const override;

  // Tries to rewrite the conditional branch MI so that it compares against an
  // off-by-one constant already loaded earlier in the same basic block.
  // Returns true if MI was replaced by a new branch, false if left untouched.
  bool optimizeCondBranch(MachineInstr &MI) const override;

  MachineBasicBlock *getBranchDestBlock(const MachineInstr &MI) const override;

  bool isBranchOffsetInRange(unsigned BranchOpc,
+119 −0
Original line number Diff line number Diff line
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
; RUN: llc -mtriple=riscv32 -O2 -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -O2 -verify-machineinstrs < %s | FileCheck %s

; Unsigned case 1 (C1 = C0 + 1): the compare of %b against 31 is expected to be
; emitted as a bgeu on the register that already holds the stored constant 32,
; so no separate li for 31 appears before the branch.
define void @u_case1_a(ptr %a, i32 signext %b, ptr %c, ptr %d) {
; CHECK-LABEL: u_case1_a:
; CHECK:       # %bb.0:
; CHECK-NEXT:    li a4, 32
; CHECK-NEXT:    sw a4, 0(a0)
; CHECK-NEXT:    bgeu a1, a4, .LBB0_2
; CHECK-NEXT:  # %bb.1: # %block1
; CHECK-NEXT:    sw a1, 0(a2)
; CHECK-NEXT:    ret
; CHECK-NEXT:  .LBB0_2: # %block2
; CHECK-NEXT:    li a0, 87
; CHECK-NEXT:    sw a0, 0(a3)
; CHECK-NEXT:    ret
  store i32 32, ptr %a
  %p = icmp ule i32 %b, 31
  br i1 %p, label %block1, label %block2

block1:                                           ; preds = %0
  store i32 %b, ptr %c
  br label %end_block

block2:                                           ; preds = %0
  store i32 87, ptr %d
  br label %end_block

end_block:                                        ; preds = %block2, %block1
  ret void
}

; Signed case 1 (C1 = C0 + 1): the compare of %b against -2 is expected to be
; emitted as a bge on the register that already holds the stored constant -1,
; so no separate li for -2 appears before the branch.
define void @case1_a(ptr %a, i32 signext %b, ptr %c, ptr %d) {
; CHECK-LABEL: case1_a:
; CHECK:       # %bb.0:
; CHECK-NEXT:    li a4, -1
; CHECK-NEXT:    sw a4, 0(a0)
; CHECK-NEXT:    bge a1, a4, .LBB1_2
; CHECK-NEXT:  # %bb.1: # %block1
; CHECK-NEXT:    sw a1, 0(a2)
; CHECK-NEXT:    ret
; CHECK-NEXT:  .LBB1_2: # %block2
; CHECK-NEXT:    li a0, 87
; CHECK-NEXT:    sw a0, 0(a3)
; CHECK-NEXT:    ret
  store i32 -1, ptr %a
  %p = icmp sle i32 %b, -2
  br i1 %p, label %block1, label %block2

block1:                                           ; preds = %0
  store i32 %b, ptr %c
  br label %end_block

block2:                                           ; preds = %0
  store i32 87, ptr %d
  br label %end_block

end_block:                                        ; preds = %block2, %block1
  ret void
}

; Unsigned case 2 (C1 = C0 - 1): the compare of %b against 33 is expected to be
; emitted as a bgeu with the register that already holds the stored constant 32
; as its first operand, so no separate li for 33 appears before the branch.
define void @u_case2_a(ptr %a, i32 signext %b, ptr %c, ptr %d) {
; CHECK-LABEL: u_case2_a:
; CHECK:       # %bb.0:
; CHECK-NEXT:    li a4, 32
; CHECK-NEXT:    sw a4, 0(a0)
; CHECK-NEXT:    bgeu a4, a1, .LBB2_2
; CHECK-NEXT:  # %bb.1: # %block1
; CHECK-NEXT:    sw a1, 0(a2)
; CHECK-NEXT:    ret
; CHECK-NEXT:  .LBB2_2: # %block2
; CHECK-NEXT:    li a0, 87
; CHECK-NEXT:    sw a0, 0(a3)
; CHECK-NEXT:    ret
  store i32 32, ptr %a
  %p = icmp uge i32 %b, 33
  br i1 %p, label %block1, label %block2

block1:                                           ; preds = %0
  store i32 %b, ptr %c
  br label %end_block

block2:                                           ; preds = %0
  store i32 87, ptr %d
  br label %end_block

end_block:                                        ; preds = %block2, %block1
  ret void
}

; Signed case 2 (C1 = C0 - 1): the compare of %b against -3 is expected to be
; emitted as a bge with the register that already holds the stored constant -4
; as its first operand, so no separate li for -3 appears before the branch.
define void @case2_a(ptr %a, i32 signext %b, ptr %c, ptr %d) {
; CHECK-LABEL: case2_a:
; CHECK:       # %bb.0:
; CHECK-NEXT:    li a4, -4
; CHECK-NEXT:    sw a4, 0(a0)
; CHECK-NEXT:    bge a4, a1, .LBB3_2
; CHECK-NEXT:  # %bb.1: # %block1
; CHECK-NEXT:    sw a1, 0(a2)
; CHECK-NEXT:    ret
; CHECK-NEXT:  .LBB3_2: # %block2
; CHECK-NEXT:    li a0, 87
; CHECK-NEXT:    sw a0, 0(a3)
; CHECK-NEXT:    ret
  store i32 -4, ptr %a
  %p = icmp sge i32 %b, -3
  br i1 %p, label %block1, label %block2

block1:                                           ; preds = %0
  store i32 %b, ptr %c
  br label %end_block

block2:                                           ; preds = %0
  store i32 87, ptr %d
  br label %end_block

end_block:                                        ; preds = %block2, %block1
  ret void
}