Unverified Commit be57381a authored by Dhruv Chawla's avatar Dhruv Chawla Committed by GitHub
Browse files

[InstCombine] Create a class to lazily track computed known bits (#66611)

This patch adds a new class "WithCache" which stores a pointer to
any type passable to computeKnownBits along with KnownBits
information which is computed on-demand when getKnownBits()
is called. This allows reusing the known bits information when it is
passed as an argument to multiple functions.

It also changes a few functions to accept WithCache(s) so that
known bits information computed in some callees can be propagated to
others from the top-level InstCombinerImpl::visitAdd caller.

This gives a compile-time speedup of 0.14% (measured in instructions:u):
https://llvm-compile-time-tracker.com/compare.php?from=499d41cef2e7bbb65804f6a815b9fa8b27efce0f&to=fbea87f1f1e6d5552e2bc309f8e201a3af6d28ec&stat=instructions:u
parent 7b1e6851
Loading
Loading
Loading
Loading
+15 −4
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/SimplifyQuery.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/FMF.h"
@@ -90,6 +91,12 @@ KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
                           const DominatorTree *DT = nullptr,
                           bool UseInstrInfo = true);

KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
                           unsigned Depth, const SimplifyQuery &Q);

KnownBits computeKnownBits(const Value *V, unsigned Depth,
                           const SimplifyQuery &Q);

/// Compute known bits from the range metadata.
/// \p KnownZero the set of bits that are known to be zero
/// \p KnownOne the set of bits that are known to be one
@@ -107,7 +114,8 @@ KnownBits analyzeKnownBitsFromAndXorOr(
    bool UseInstrInfo = true);

/// Return true if LHS and RHS have no common bits set.
bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
bool haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
                         const WithCache<const Value *> &RHSCache,
                         const SimplifyQuery &SQ);

/// Return true if the given value is known to have exactly one bit set when
@@ -847,9 +855,12 @@ OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS,
                                             const SimplifyQuery &SQ);
OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
                                           const SimplifyQuery &SQ);
OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS,
OverflowResult
computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
                              const WithCache<const Value *> &RHS,
                              const SimplifyQuery &SQ);
OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
OverflowResult computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
                                           const WithCache<const Value *> &RHS,
                                           const SimplifyQuery &SQ);
/// This version also leverages the sign bit of Add if known.
OverflowResult computeOverflowForSignedAdd(const AddOperator *Add,
+71 −0
Original line number Diff line number Diff line
//===- llvm/Analysis/WithCache.h - KnownBits cache for pointers -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Store a pointer to any type along with the KnownBits information for it
// that is computed lazily (if required).
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_WITHCACHE_H
#define LLVM_ANALYSIS_WITHCACHE_H

#include "llvm/ADT/PointerIntPair.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/KnownBits.h"
#include <type_traits>

namespace llvm {
struct SimplifyQuery;
KnownBits computeKnownBits(const Value *V, unsigned Depth,
                           const SimplifyQuery &Q);

template <typename Arg> class WithCache {
  static_assert(std::is_pointer_v<Arg>, "WithCache requires a pointer type!");

  using UnderlyingType = std::remove_pointer_t<Arg>;
  constexpr static bool IsConst = std::is_const_v<Arg>;

  template <typename T, bool Const>
  using conditionally_const_t = std::conditional_t<Const, const T, T>;

  using PointerType = conditionally_const_t<UnderlyingType *, IsConst>;
  using ReferenceType = conditionally_const_t<UnderlyingType &, IsConst>;

  // Store the presence of the KnownBits information in one of the bits of
  // Pointer.
  // true  -> present
  // false -> absent
  mutable PointerIntPair<PointerType, 1, bool> Pointer;
  mutable KnownBits Known;

  void calculateKnownBits(const SimplifyQuery &Q) const {
    Known = computeKnownBits(Pointer.getPointer(), 0, Q);
    Pointer.setInt(true);
  }

public:
  WithCache(PointerType Pointer) : Pointer(Pointer, false) {}
  WithCache(PointerType Pointer, const KnownBits &Known)
      : Pointer(Pointer, true), Known(Known) {}

  [[nodiscard]] PointerType getValue() const { return Pointer.getPointer(); }

  [[nodiscard]] const KnownBits &getKnownBits(const SimplifyQuery &Q) const {
    if (!hasKnownBits())
      calculateKnownBits(Q);
    return Known;
  }

  [[nodiscard]] bool hasKnownBits() const { return Pointer.getInt(); }

  operator PointerType() const { return Pointer.getPointer(); }
  PointerType operator->() const { return Pointer.getPointer(); }
  ReferenceType operator*() const { return *Pointer.getPointer(); }
};
} // namespace llvm

#endif
+8 −5
Original line number Diff line number Diff line
@@ -510,14 +510,17 @@ public:
                                             SQ.getWithInstruction(CxtI));
  }

  OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
                                               const Value *RHS,
  OverflowResult
  computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
                                const WithCache<const Value *> &RHS,
                                const Instruction *CxtI) const {
    return llvm::computeOverflowForUnsignedAdd(LHS, RHS,
                                               SQ.getWithInstruction(CxtI));
  }

  OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
  OverflowResult
  computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
                              const WithCache<const Value *> &RHS,
                              const Instruction *CxtI) const {
    return llvm::computeOverflowForSignedAdd(LHS, RHS,
                                             SQ.getWithInstruction(CxtI));
+37 −38
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
@@ -178,17 +179,11 @@ void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
                                  unsigned Depth, const SimplifyQuery &Q);

static KnownBits computeKnownBits(const Value *V, unsigned Depth,
                                  const SimplifyQuery &Q);

KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
                                 unsigned Depth, AssumptionCache *AC,
                                 const Instruction *CxtI,
                                 const DominatorTree *DT, bool UseInstrInfo) {
  return ::computeKnownBits(
  return computeKnownBits(
      V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

@@ -196,13 +191,17 @@ KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
                                 const DataLayout &DL, unsigned Depth,
                                 AssumptionCache *AC, const Instruction *CxtI,
                                 const DominatorTree *DT, bool UseInstrInfo) {
  return ::computeKnownBits(
  return computeKnownBits(
      V, DemandedElts, Depth,
      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
                               const WithCache<const Value *> &RHSCache,
                               const SimplifyQuery &SQ) {
  const Value *LHS = LHSCache.getValue();
  const Value *RHS = RHSCache.getValue();

  assert(LHS->getType() == RHS->getType() &&
         "LHS and RHS should have the same type");
  assert(LHS->getType()->isIntOrIntVectorTy() &&
@@ -250,12 +249,9 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
        match(LHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
      return true;
  }
  IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
  KnownBits LHSKnown(IT->getBitWidth());
  KnownBits RHSKnown(IT->getBitWidth());
  ::computeKnownBits(LHS, LHSKnown, 0, SQ);
  ::computeKnownBits(RHS, RHSKnown, 0, SQ);
  return KnownBits::haveNoCommonBitsSet(LHSKnown, RHSKnown);

  return KnownBits::haveNoCommonBitsSet(LHSCache.getKnownBits(SQ),
                                        RHSCache.getKnownBits(SQ));
}

bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
@@ -1784,19 +1780,19 @@ static void computeKnownBitsFromOperator(const Operator *I,

/// Determine which bits of V are known to be either zero or one and return
/// them.
KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
                                 unsigned Depth, const SimplifyQuery &Q) {
  KnownBits Known(getBitWidth(V->getType(), Q.DL));
  computeKnownBits(V, DemandedElts, Known, Depth, Q);
  ::computeKnownBits(V, DemandedElts, Known, Depth, Q);
  return Known;
}

/// Determine which bits of V are known to be either zero or one and return
/// them.
KnownBits computeKnownBits(const Value *V, unsigned Depth,
KnownBits llvm::computeKnownBits(const Value *V, unsigned Depth,
                                 const SimplifyQuery &Q) {
  KnownBits Known(getBitWidth(V->getType(), Q.DL));
  computeKnownBits(V, Known, Depth, Q);
  ::computeKnownBits(V, Known, Depth, Q);
  return Known;
}

@@ -6256,10 +6252,11 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {

/// Combine constant ranges from computeConstantRange() and computeKnownBits().
static ConstantRange
computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
                                       bool ForSigned,
                                       const SimplifyQuery &SQ) {
  KnownBits Known = ::computeKnownBits(V, /*Depth=*/0, SQ);
  ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
  ConstantRange CR1 =
      ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned);
  ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
  ConstantRange::PreferredRangeType RangeType =
      ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
@@ -6269,8 +6266,8 @@ computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
                                                   const Value *RHS,
                                                   const SimplifyQuery &SQ) {
  KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
  KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
  KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
  KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
  ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
  ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
  return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
@@ -6307,16 +6304,17 @@ OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
    // product is exactly the minimum negative number.
    // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
    // For simplicity we just check if at least one side is not negative.
    KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
    KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
    KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
    KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
    if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
      return OverflowResult::NeverOverflows;
  }
  return OverflowResult::MayOverflow;
}

OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
                                                   const Value *RHS,
OverflowResult
llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
                                    const WithCache<const Value *> &RHS,
                                    const SimplifyQuery &SQ) {
  ConstantRange LHSRange =
      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
@@ -6325,10 +6323,10 @@ OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
  return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
}

static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
                                                  const Value *RHS,
                                                  const AddOperator *Add,
                                                  const SimplifyQuery &SQ) {
static OverflowResult
computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
                            const WithCache<const Value *> &RHS,
                            const AddOperator *Add, const SimplifyQuery &SQ) {
  if (Add && Add->hasNoSignedWrap()) {
    return OverflowResult::NeverOverflows;
  }
@@ -6944,8 +6942,9 @@ OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
                                       Add, SQ);
}

OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
                                                 const Value *RHS,
OverflowResult
llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
                                  const WithCache<const Value *> &RHS,
                                  const SimplifyQuery &SQ) {
  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
}
+5 −3
Original line number Diff line number Diff line
@@ -1566,7 +1566,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
    return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));

  // A+B --> A|B iff A and B have no bits set in common.
  if (haveNoCommonBitsSet(LHS, RHS, SQ.getWithInstruction(&I)))
  WithCache<const Value *> LHSCache(LHS), RHSCache(RHS);
  if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I)))
    return BinaryOperator::CreateOr(LHS, RHS);

  if (Instruction *Ext = narrowMathIfNoOverflow(I))
@@ -1661,11 +1662,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
  // willNotOverflowUnsignedAdd to reduce the number of invocations of
  // computeKnownBits.
  bool Changed = false;
  if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
  if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) {
    Changed = true;
    I.setHasNoSignedWrap(true);
  }
  if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
  if (!I.hasNoUnsignedWrap() &&
      willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) {
    Changed = true;
    I.setHasNoUnsignedWrap(true);
  }
Loading