Unverified Commit 849f963e authored by Igor Kirillov's avatar Igor Kirillov Committed by GitHub
Browse files

[CodeGen] Improve ExpandMemCmp for more efficient non-register aligned sizes handling (#70469)

* Enhance the logic of the ExpandMemCmp pass to merge contiguous
  subsequences in LoadSequence, based on the sizes allowed in
  `AllowedTailExpansions`.
* This enhancement seeks to minimize the number of basic blocks and
  produce optimized code when using memcmp with non-register-aligned sizes.
* Enable this feature for AArch64 with memcmp sizes modulo 8 equal to
  3, 5, and 6.

Reapplication of #69942 after fixing a bug
parent 89564f0b
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -907,6 +907,17 @@ public:
    // be done with two 4-byte compares instead of 4+2+1-byte compares. This
    // requires all loads in LoadSizes to be doable in an unaligned way.
    bool AllowOverlappingLoads = false;

    // Sometimes, the amount of data that needs to be compared is smaller than
    // the standard register size, but it cannot be loaded with just one load
    // instruction. For example, if the size of the memory comparison is 6
    // bytes, we can handle it more efficiently by loading all 6 bytes in a
    // single block and generating an 8-byte number, instead of generating two
    // separate blocks with conditional jumps for 4 and 2 byte loads. This
    // approach simplifies the process and produces the comparison result as
    // normal. This array lists the allowed sizes of memcmp tails that can be
    // merged into one block.
    SmallVector<unsigned, 4> AllowedTailExpansions;
  };
  MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize,
                                               bool IsZeroCmp) const;
+75 −20
Original line number Diff line number Diff line
@@ -117,8 +117,8 @@ class MemCmpExpansion {
    Value *Lhs = nullptr;
    Value *Rhs = nullptr;
  };
  LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
                       unsigned OffsetBytes);
  LoadPair getLoadPair(Type *LoadSizeType, Type *BSwapSizeType,
                       Type *CmpSizeType, unsigned OffsetBytes);

  static LoadEntryVector
  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
@@ -128,6 +128,11 @@ class MemCmpExpansion {
                                 unsigned MaxNumLoads,
                                 unsigned &NumLoadsNonOneByte);

  /// Post-processes \p LoadSequence in place by repeatedly merging its last
  /// two entries into a single entry while they are contiguous (the
  /// second-to-last entry ends exactly where the last one begins) and their
  /// combined size is listed in \p Options.AllowedTailExpansions.
  ///
  /// Does nothing when \p IsUsedForZeroCmp is true or when
  /// \p Options.AllowedTailExpansions is empty.
  static void optimiseLoadSequence(
      LoadEntryVector &LoadSequence,
      const TargetTransformInfo::MemCmpExpansionOptions &Options,
      bool IsUsedForZeroCmp);

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
@@ -210,6 +215,37 @@ MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
  return LoadSequence;
}

void MemCmpExpansion::optimiseLoadSequence(
    LoadEntryVector &LoadSequence,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    bool IsUsedForZeroCmp) {
  // Collapse the tail of LoadSequence: as long as the final two entries are
  // adjacent in memory and their summed size appears in
  // `MemCmpExpansionOptions::AllowedTailExpansions`, replace them with one
  // entry covering both. Nothing is done for zero-comparisons or when the
  // target lists no allowed tail sizes.
  const auto &AllowedSizes = Options.AllowedTailExpansions;
  if (IsUsedForZeroCmp || AllowedSizes.empty())
    return;

  for (;;) {
    const auto NumEntries = LoadSequence.size();
    if (NumEntries < 2)
      return;

    // Take copies: both entries are popped below before the merged entry is
    // appended, so references into the vector would not survive.
    const auto Tail = LoadSequence[NumEntries - 1];
    const auto BeforeTail = LoadSequence[NumEntries - 2];

    // Only adjacent entries can be merged; stop at the first gap or overlap.
    if (BeforeTail.Offset + BeforeTail.LoadSize != Tail.Offset)
      return;

    // Stop as soon as the merged size is not one the target allows.
    const auto MergedSize = BeforeTail.LoadSize + Tail.LoadSize;
    if (find(AllowedSizes, MergedSize) == AllowedSizes.end())
      return;

    // Swap the two trailing entries for the single combined one.
    // NOTE(review): LoadEntry's constructor is not visible here; confirm that
    // its parameter order matches (Offset, LoadSize) as passed below.
    LoadSequence.pop_back();
    LoadSequence.pop_back();
    LoadSequence.emplace_back(BeforeTail.Offset, MergedSize);
  }
}

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
@@ -255,6 +291,7 @@ MemCmpExpansion::MemCmpExpansion(
    }
  }
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
  optimiseLoadSequence(LoadSequence, Options, IsUsedForZeroCmp);
}

unsigned MemCmpExpansion::getNumBlocks() {
@@ -278,7 +315,7 @@ void MemCmpExpansion::createResultBlock() {
}

MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
                                                       bool NeedsBSwap,
                                                       Type *BSwapSizeType,
                                                       Type *CmpSizeType,
                                                       unsigned OffsetBytes) {
  // Get the memory source at offset `OffsetBytes`.
@@ -307,16 +344,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
  if (!Rhs)
    Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);

  // Zero extend if Byte Swap intrinsic has different type
  if (BSwapSizeType && LoadSizeType != BSwapSizeType) {
    Lhs = Builder.CreateZExt(Lhs, BSwapSizeType);
    Rhs = Builder.CreateZExt(Rhs, BSwapSizeType);
  }

  // Swap bytes if required.
  if (NeedsBSwap) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
  if (BSwapSizeType) {
    Function *Bswap = Intrinsic::getDeclaration(
        CI->getModule(), Intrinsic::bswap, BSwapSizeType);
    Lhs = Builder.CreateCall(Bswap, Lhs);
    Rhs = Builder.CreateCall(Bswap, Rhs);
  }

  // Zero extend if required.
  if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
  if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) {
    Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
    Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
  }
@@ -332,7 +375,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
  BasicBlock *BB = LoadCmpBlocks[BlockIndex];
  Builder.SetInsertPoint(BB);
  const LoadPair Loads =
      getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
      getLoadPair(Type::getInt8Ty(CI->getContext()), nullptr,
                  Type::getInt32Ty(CI->getContext()), OffsetBytes);
  Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);

@@ -385,11 +428,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);

  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
    const LoadPair Loads = getLoadPair(
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
        /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), nullptr,
        MaxLoadType, CurLoadEntry.Offset);

    if (NumLoads != 1) {
      // If we have multiple loads per block, we need to generate a composite
@@ -475,13 +519,19 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Type *BSwapSizeType =
      DL.isLittleEndian()
          ? IntegerType::get(CI->getContext(),
                             PowerOf2Ceil(CurLoadEntry.LoadSize * 8))
          : nullptr;
  Type *MaxLoadType = IntegerType::get(
      CI->getContext(),
      std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  const LoadPair Loads =
      getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
  const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
                                     CurLoadEntry.Offset);

  // Add the loaded values to the phi nodes for calculating memcmp result only
@@ -587,19 +637,24 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Type *BSwapSizeType =
      NeedsBSwap ? IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8))
                 : nullptr;
  Type *MaxLoadType =
      IntegerType::get(CI->getContext(),
                       std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8);

  // The i8 and i16 cases don't need compares. We zext the loaded values and
  // subtract them to get the suitable negative, zero, or positive i32 result.
  if (Size < 4) {
    const LoadPair Loads =
        getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
                    /*Offset*/ 0);
  if (Size == 1 || Size == 2) {
    const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType,
                                       Builder.getInt32Ty(), /*Offset*/ 0);
    return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
  }

  const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
  const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
                                     /*Offset*/ 0);
  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
+1 −0
Original line number Diff line number Diff line
@@ -2994,6 +2994,7 @@ AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
  // they may wake up the FP unit, which raises the power consumption.  Perhaps
  // they could be used with no holds barred (-O3).
  Options.LoadSizes = {8, 4, 2, 1};
  Options.AllowedTailExpansions = {3, 5, 6};
  return Options;
}

+3005 −0

File added.

Preview size limit exceeded, changes collapsed.

+881 −0

File added.

Preview size limit exceeded, changes collapsed.