Commit bab99345 authored by Jessica Paquette's avatar Jessica Paquette
Browse files

[AArch64][GlobalISel][NFC] Factor out TB(N)Z emission code into its own function

Factor it out into `emitTestBit` and add some asserts to the new function.

This will be useful for implementing TB(N)Z emission for SLT/SGT compares.

Differential Revision: https://reviews.llvm.org/D74080
parent 7212f657
Loading
Loading
Loading
Loading
+43 −20
Original line number Diff line number Diff line
@@ -171,6 +171,13 @@ private:
  MachineInstr *emitCSetForICMP(Register DefReg, unsigned Pred,
                                MachineIRBuilder &MIRBuilder) const;

  /// Emit a TB(N)Z instruction which tests \p Bit in \p TestReg.
  /// \p IsNegative is true if the test should be "not zero".
  /// This will also optimize the test bit instruction when possible.
  MachineInstr *emitTestBit(Register TestReg, uint64_t Bit, bool IsNegative,
                            MachineBasicBlock *DstMBB,
                            MachineIRBuilder &MIB) const;

  // Equivalent to the i32shift_a and friends from AArch64InstrInfo.td.
  // We use these manually instead of using the importer since it doesn't
  // support SDNodeXForm.
@@ -1114,6 +1121,40 @@ static Register getTestBitReg(Register Reg, uint64_t &Bit, bool &Invert,
  return Reg;
}

MachineInstr *AArch64InstructionSelector::emitTestBit(
    Register TestReg, uint64_t Bit, bool IsNegative, MachineBasicBlock *DstMBB,
    MachineIRBuilder &MIB) const {
  MachineRegisterInfo &MRI = *MIB.getMRI();
#ifndef NDEBUG
  assert(ProduceNonFlagSettingCondBr &&
         "Cannot emit TB(N)Z with speculation tracking!");
  assert(TestReg.isValid());
  LLT Ty = MRI.getType(TestReg);
  unsigned Size = Ty.getSizeInBits();
  assert(Bit < Size &&
         "Bit to test must be smaler than the size of a test register!");
  assert(Ty.isScalar() && "Expected a scalar!");
  assert(Size >= 32 && "Expected at least a 32-bit register!");
#endif

  // Attempt to optimize the test bit by walking over instructions.
  TestReg = getTestBitReg(TestReg, Bit, IsNegative, MRI);
  bool UseWReg = Bit < 32;

  // When the test register is a 64-bit register, we have to narrow to make
  // TBNZW work.
  if (UseWReg)
    TestReg = narrowExtendRegIfNeeded(TestReg, MIB);

  static const unsigned OpcTable[2][2] = {{AArch64::TBZX, AArch64::TBNZX},
                                          {AArch64::TBZW, AArch64::TBNZW}};
  unsigned Opc = OpcTable[UseWReg][IsNegative];
  auto TestBitMI =
      MIB.buildInstr(Opc).addReg(TestReg).addImm(Bit).addMBB(DstMBB);
  constrainSelectedInstRegOperands(*TestBitMI, TII, TRI, RBI);
  return &*TestBitMI;
}

bool AArch64InstructionSelector::tryOptAndIntoCompareBranch(
    MachineInstr *AndInst, int64_t CmpConstant, const CmpInst::Predicate &Pred,
    MachineBasicBlock *DstMBB, MachineIRBuilder &MIB) const {
@@ -1158,30 +1199,12 @@ bool AArch64InstructionSelector::tryOptAndIntoCompareBranch(
  if (!MaybeBit || !isPowerOf2_64(MaybeBit->Value))
    return false;

  // Try to optimize the TB(N)Z.
  uint64_t Bit = Log2_64(static_cast<uint64_t>(MaybeBit->Value));
  Register TestReg = AndInst->getOperand(1).getReg();
  bool Invert = Pred == CmpInst::Predicate::ICMP_NE;
  TestReg = getTestBitReg(TestReg, Bit, Invert, MRI);

  // Choose the correct TB(N)Z opcode to use.
  unsigned Opc = 0;
  if (Bit < 32) {
    // When the bit is less than 32, we have to use a TBZW even if we're on a 64
    // bit register.
    Opc = Invert ? AArch64::TBNZW : AArch64::TBZW;
    TestReg = narrowExtendRegIfNeeded(TestReg, MIB);
  } else {
    // Same idea for when Bit >= 32. We don't have to narrow here, because if
    // Bit > 32, then the G_CONSTANT must be outside the range of valid 32-bit
    // values. So, we must have a s64.
    Opc = Invert ? AArch64::TBNZX : AArch64::TBZX;
  }

  // Construct the branch.
  auto BranchMI =
      MIB.buildInstr(Opc).addReg(TestReg).addImm(Bit).addMBB(DstMBB);
  constrainSelectedInstRegOperands(*BranchMI, TII, TRI, RBI);
  // Emit a TB(N)Z.
  emitTestBit(TestReg, Bit, Invert, DstMBB, MIB);
  return true;
}