Commit 2dd52b45 authored by Noah Goldstein's avatar Noah Goldstein
Browse files

[InstCombine] Improve logic for adding flags to shift instructions.

Instead of relying on constant operands, use known bits to do the
computation.

Proofs: https://alive2.llvm.org/ce/z/M-aBnw

Differential Revision: https://reviews.llvm.org/D157532
parent 968468af
Loading
Loading
Loading
Loading
+65 −28
Original line number Diff line number Diff line
@@ -941,6 +941,60 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
  return new ZExtInst(Overflow, Ty);
}

// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
  assert(I.isShift() && "Expected a shift as input");

  const bool IsShl = I.getOpcode() == Instruction::Shl;
  // Bail out early if every flag we could possibly add is already present.
  if (IsShl) {
    if (I.hasNoUnsignedWrap() && I.hasNoSignedWrap())
      return false;
  } else if (I.isExact()) {
    return false;
  }

  // Derive an upper bound for the shift count from known bits.
  KnownBits CntKnown =
      computeKnownBits(I.getOperand(1), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
  // If nothing is known about the count, or the count may be out of range
  // (a poison shift), no flag can be proven, so skip analyzing the shifted
  // value entirely.
  if (CntKnown.isUnknown())
    return false;
  unsigned BW = CntKnown.getBitWidth();
  APInt CntMax = CntKnown.getMaxValue();
  if (CntMax.uge(BW))
    return false;

  KnownBits ValKnown =
      computeKnownBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);

  if (!IsShl) {
    // lshr/ashr: exact iff no set bit can be shifted out, i.e. the shifted
    // value has at least CntMax trailing zeros.
    bool NowExact = CntMax.ule(ValKnown.countMinTrailingZeros());
    I.setIsExact(NowExact);
    return NowExact;
  }

  bool Changed = false;
  // shl nuw: no set bit is lost off the top, i.e. the shifted value has at
  // least as many leading zeros as the maximum possible count.
  if (!I.hasNoUnsignedWrap() && CntMax.ule(ValKnown.countMinLeadingZeros())) {
    I.setHasNoUnsignedWrap();
    Changed = true;
  }
  // shl nsw: strictly more sign bits than the maximum possible count. Try
  // the cheap knownbits-derived bound first; only fall back to the more
  // expensive ComputeNumSignBits when that fails.
  if (!I.hasNoSignedWrap()) {
    if (CntMax.ult(ValKnown.countMinSignBits()) ||
        CntMax.ult(ComputeNumSignBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC,
                                      Q.CxtI, Q.DT))) {
      I.setHasNoSignedWrap();
      Changed = true;
    }
  }
  return Changed;
}

Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
  const SimplifyQuery Q = SQ.getWithInstruction(&I);

@@ -1121,21 +1175,10 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
      Value *NewShift = Builder.CreateShl(X, Op1);
      return BinaryOperator::CreateSub(NewLHS, NewShift);
    }

    // If the shifted-out value is known-zero, then this is a NUW shift.
    if (!I.hasNoUnsignedWrap() &&
        MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0,
                          &I)) {
      I.setHasNoUnsignedWrap();
      return &I;
  }

    // If the shifted-out value is all signbits, then this is a NSW shift.
    if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) {
      I.setHasNoSignedWrap();
  if (setShiftFlags(I, Q))
    return &I;
    }
  }

  // Transform  (x >> y) << y  to  x & (-1 << y)
  // Valid for any type of right-shift.
@@ -1427,14 +1470,11 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
      Value *And = Builder.CreateAnd(BoolX, BoolY);
      return new ZExtInst(And, Ty);
    }
  }

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) {
      I.setIsExact();
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
  if (setShiftFlags(I, Q))
    return &I;
    }
  }

  // Transform  (x << y) >> y  to  x & (-1 >> y)
  if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
@@ -1594,14 +1634,11 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
      if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
        return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
    }
  }

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
      I.setIsExact();
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
  if (setShiftFlags(I, Q))
    return &I;
    }
  }

  // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)`
  // as the pattern to splat the lowest bit.
+15 −15
Original line number Diff line number Diff line
@@ -413,11 +413,11 @@ define i1 @mul_is_pow2(i16 %x, i16 %y, i16 %z) {
; CHECK-SAME: (i16 [[X:%.*]], i16 [[Y:%.*]], i16 [[Z:%.*]]) {
; CHECK-NEXT:    [[XSMALL:%.*]] = and i16 [[X]], 3
; CHECK-NEXT:    [[ZSMALL:%.*]] = and i16 [[Z]], 3
; CHECK-NEXT:    [[XP2:%.*]] = shl i16 4, [[XSMALL]]
; CHECK-NEXT:    [[ZP2:%.*]] = shl i16 2, [[ZSMALL]]
; CHECK-NEXT:    [[XX:%.*]] = mul nuw nsw i16 [[XP2]], [[ZP2]]
; CHECK-NEXT:    [[ZP2:%.*]] = shl nuw nsw i16 2, [[ZSMALL]]
; CHECK-NEXT:    [[TMP1:%.*]] = add nuw nsw i16 [[XSMALL]], 2
; CHECK-NEXT:    [[XX:%.*]] = shl nuw nsw i16 [[ZP2]], [[TMP1]]
; CHECK-NEXT:    [[AND:%.*]] = and i16 [[XX]], [[Y]]
; CHECK-NEXT:    [[R:%.*]] = icmp eq i16 [[AND]], [[XX]]
; CHECK-NEXT:    [[R:%.*]] = icmp ne i16 [[AND]], 0
; CHECK-NEXT:    ret i1 [[R]]
;
  %xsmall = and i16 %x, 3
@@ -436,9 +436,9 @@ define i1 @mul_is_pow2_fail(i16 %x, i16 %y, i16 %z) {
; CHECK-SAME: (i16 [[X:%.*]], i16 [[Y:%.*]], i16 [[Z:%.*]]) {
; CHECK-NEXT:    [[XSMALL:%.*]] = and i16 [[X]], 7
; CHECK-NEXT:    [[ZSMALL:%.*]] = and i16 [[Z]], 7
; CHECK-NEXT:    [[XP2:%.*]] = shl i16 4, [[XSMALL]]
; CHECK-NEXT:    [[ZP2:%.*]] = shl i16 2, [[ZSMALL]]
; CHECK-NEXT:    [[XX:%.*]] = mul i16 [[XP2]], [[ZP2]]
; CHECK-NEXT:    [[ZP2:%.*]] = shl nuw nsw i16 2, [[ZSMALL]]
; CHECK-NEXT:    [[TMP1:%.*]] = add nuw nsw i16 [[XSMALL]], 2
; CHECK-NEXT:    [[XX:%.*]] = shl i16 [[ZP2]], [[TMP1]]
; CHECK-NEXT:    [[AND:%.*]] = and i16 [[XX]], [[Y]]
; CHECK-NEXT:    [[R:%.*]] = icmp eq i16 [[AND]], [[XX]]
; CHECK-NEXT:    ret i1 [[R]]
@@ -459,9 +459,9 @@ define i1 @mul_is_pow2_fail2(i16 %x, i16 %y, i16 %z) {
; CHECK-SAME: (i16 [[X:%.*]], i16 [[Y:%.*]], i16 [[Z:%.*]]) {
; CHECK-NEXT:    [[XSMALL:%.*]] = and i16 [[X]], 3
; CHECK-NEXT:    [[ZSMALL:%.*]] = and i16 [[Z]], 3
; CHECK-NEXT:    [[XP2:%.*]] = shl i16 3, [[XSMALL]]
; CHECK-NEXT:    [[ZP2:%.*]] = shl i16 2, [[ZSMALL]]
; CHECK-NEXT:    [[XX:%.*]] = mul nuw nsw i16 [[XP2]], [[ZP2]]
; CHECK-NEXT:    [[XP2:%.*]] = shl nuw nsw i16 3, [[XSMALL]]
; CHECK-NEXT:    [[TMP1:%.*]] = add nuw nsw i16 [[ZSMALL]], 1
; CHECK-NEXT:    [[XX:%.*]] = shl nuw nsw i16 [[XP2]], [[TMP1]]
; CHECK-NEXT:    [[AND:%.*]] = and i16 [[XX]], [[Y]]
; CHECK-NEXT:    [[R:%.*]] = icmp eq i16 [[AND]], [[XX]]
; CHECK-NEXT:    ret i1 [[R]]
@@ -481,9 +481,9 @@ define i1 @shl_is_pow2(i16 %x, i16 %y) {
; CHECK-LABEL: define i1 @shl_is_pow2
; CHECK-SAME: (i16 [[X:%.*]], i16 [[Y:%.*]]) {
; CHECK-NEXT:    [[XSMALL:%.*]] = and i16 [[X]], 7
; CHECK-NEXT:    [[XX:%.*]] = shl i16 4, [[XSMALL]]
; CHECK-NEXT:    [[XX:%.*]] = shl nuw nsw i16 4, [[XSMALL]]
; CHECK-NEXT:    [[AND:%.*]] = and i16 [[XX]], [[Y]]
; CHECK-NEXT:    [[R:%.*]] = icmp eq i16 [[AND]], [[XX]]
; CHECK-NEXT:    [[R:%.*]] = icmp ne i16 [[AND]], 0
; CHECK-NEXT:    ret i1 [[R]]
;
  %xsmall = and i16 %x, 7
@@ -515,7 +515,7 @@ define i1 @shl_is_pow2_fail2(i16 %x, i16 %y) {
; CHECK-LABEL: define i1 @shl_is_pow2_fail2
; CHECK-SAME: (i16 [[X:%.*]], i16 [[Y:%.*]]) {
; CHECK-NEXT:    [[XSMALL:%.*]] = and i16 [[X]], 7
; CHECK-NEXT:    [[XX:%.*]] = shl i16 5, [[XSMALL]]
; CHECK-NEXT:    [[XX:%.*]] = shl nuw nsw i16 5, [[XSMALL]]
; CHECK-NEXT:    [[AND:%.*]] = and i16 [[XX]], [[Y]]
; CHECK-NEXT:    [[R:%.*]] = icmp eq i16 [[AND]], [[XX]]
; CHECK-NEXT:    ret i1 [[R]]
@@ -532,9 +532,9 @@ define i1 @lshr_is_pow2(i16 %x, i16 %y) {
; CHECK-LABEL: define i1 @lshr_is_pow2
; CHECK-SAME: (i16 [[X:%.*]], i16 [[Y:%.*]]) {
; CHECK-NEXT:    [[XSMALL:%.*]] = and i16 [[X]], 7
; CHECK-NEXT:    [[XX:%.*]] = lshr i16 512, [[XSMALL]]
; CHECK-NEXT:    [[XX:%.*]] = lshr exact i16 512, [[XSMALL]]
; CHECK-NEXT:    [[AND:%.*]] = and i16 [[XX]], [[Y]]
; CHECK-NEXT:    [[R:%.*]] = icmp eq i16 [[AND]], [[XX]]
; CHECK-NEXT:    [[R:%.*]] = icmp ne i16 [[AND]], 0
; CHECK-NEXT:    ret i1 [[R]]
;
  %xsmall = and i16 %x, 7
+1 −1
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ define i8 @and_not_shl(i8 %x) {
; CHECK-SAME: (i8 [[X:%.*]]) {
; CHECK-NEXT:    [[OP1_P2:%.*]] = icmp ult i8 [[X]], 6
; CHECK-NEXT:    call void @llvm.assume(i1 [[OP1_P2]])
; CHECK-NEXT:    [[SHIFT:%.*]] = shl i8 -1, [[X]]
; CHECK-NEXT:    [[SHIFT:%.*]] = shl nsw i8 -1, [[X]]
; CHECK-NEXT:    [[NOT:%.*]] = and i8 [[SHIFT]], 32
; CHECK-NEXT:    [[R:%.*]] = xor i8 [[NOT]], 32
; CHECK-NEXT:    ret i8 [[R]]
+2 −2
Original line number Diff line number Diff line
@@ -5,10 +5,10 @@
define i32 @src(i1 %x2) {
; CHECK-LABEL: @src(
; CHECK-NEXT:    [[X13:%.*]] = zext i1 [[X2:%.*]] to i32
; CHECK-NEXT:    [[_7:%.*]] = shl i32 -1, [[X13]]
; CHECK-NEXT:    [[_7:%.*]] = shl nsw i32 -1, [[X13]]
; CHECK-NEXT:    [[MASK:%.*]] = xor i32 [[_7]], -1
; CHECK-NEXT:    [[_8:%.*]] = and i32 [[MASK]], [[X13]]
; CHECK-NEXT:    [[_9:%.*]] = shl i32 [[_8]], [[X13]]
; CHECK-NEXT:    [[_9:%.*]] = shl nuw nsw i32 [[_8]], [[X13]]
; CHECK-NEXT:    ret i32 [[_9]]
;
  %x13 = zext i1 %x2 to i32
+1 −1
Original line number Diff line number Diff line
@@ -705,7 +705,7 @@ define i9 @rotateleft_9_neg_mask_wide_amount_commute(i9 %v, i33 %shamt) {
; CHECK-NEXT:    [[LSHAMT:%.*]] = and i33 [[SHAMT]], 8
; CHECK-NEXT:    [[RSHAMT:%.*]] = and i33 [[NEG]], 8
; CHECK-NEXT:    [[CONV:%.*]] = zext i9 [[V:%.*]] to i33
; CHECK-NEXT:    [[SHL:%.*]] = shl i33 [[CONV]], [[LSHAMT]]
; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i33 [[CONV]], [[LSHAMT]]
; CHECK-NEXT:    [[SHR:%.*]] = lshr i33 [[CONV]], [[RSHAMT]]
; CHECK-NEXT:    [[OR:%.*]] = or i33 [[SHL]], [[SHR]]
; CHECK-NEXT:    [[RET:%.*]] = trunc i33 [[OR]] to i9
Loading