Commit 8c5f2361 authored by Sanjay Patel's avatar Sanjay Patel
Browse files

[InstCombine] enable (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2) for vectors with splat constants

llvm-svn: 293570
parent 6d76b7b4
Loading
Loading
Loading
Loading
+19 −54
Original line number Diff line number Diff line
@@ -341,53 +341,6 @@ foldShiftByConstOfShiftByConst(BinaryOperator &I, const APInt *COp1,
                                  ConstantInt::get(I.getType(), AmtSum));
  }

  // This is a constant shift of a constant shift. Be careful about hiding
  // shl instructions behind bit masks. They are used to represent multiplies
  // by a constant, and it is important that simple arithmetic expressions
  // are still recognizable by scalar evolution.
  //
  // The transforms applied to shl are very similar to the transforms applied
  // to mul by constant. We can be more aggressive about optimizing right
  // shifts.
  //
  // Combinations of right and left shifts will still be optimized in
  // DAGCombine where scalar evolution no longer applies.

  Value *X = ShiftOp->getOperand(0);
  unsigned ShiftAmt1 = ShAmt1->getLimitedValue();
  unsigned ShiftAmt2 = COp1->getLimitedValue();
  assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
  if (ShiftAmt1 == 0)
    return nullptr; // Will be simplified in the future.

  if (ShiftAmt1 == ShiftAmt2)
    return nullptr;

  // FIXME: Everything under here should be extended to work with vector types.

  auto *ShiftAmt1C = dyn_cast<ConstantInt>(ShiftOp->getOperand(1));
  if (!ShiftAmt1C)
    return nullptr;

  IntegerType *Ty = cast<IntegerType>(I.getType());
  if (ShiftAmt2 < ShiftAmt1) {
    uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2;

    // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However,
    // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
    if (I.getOpcode() == Instruction::AShr &&
        ShiftOp->getOpcode() == Instruction::Shl) {
      if (ShiftOp->hasNoSignedWrap()) {
        // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2)
        ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
        BinaryOperator *NewShl =
            BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
        NewShl->setHasNoSignedWrap(true);
        return NewShl;
      }
    }
  }

  return nullptr;
}

@@ -640,6 +593,9 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
    }

    // Be careful about hiding shl instructions behind bit masks. They are used
    // to represent multiplies by a constant, and it is important that simple
    // arithmetic expressions are still recognizable by scalar evolution.
    // The inexact versions are deferred to DAGCombine, so we don't hide shl
    // behind a bit mask.
    const APInt *ShrOp1;
@@ -792,14 +748,23 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
    // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
    // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
    const APInt *ShlAmtAPInt;
    if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShlAmtAPInt))) &&
        ShlAmtAPInt->ult(*ShAmtAPInt)) {
    if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShlAmtAPInt)))) {
      unsigned ShlAmt = ShlAmtAPInt->getZExtValue();
      if (ShlAmt < ShAmt) {
        // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
      Constant *ShiftDiff = ConstantInt::get(Ty, *ShAmtAPInt - *ShlAmtAPInt);
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
        auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
        NewAShr->setIsExact(I.isExact());
        return NewAShr;
      }
      if (ShlAmt > ShAmt) {
        // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
        auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
        NewShl->setHasNoSignedWrap(true);
        return NewShl;
      }
    }

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
+1 −2
Original line number Diff line number Diff line
@@ -1003,8 +1003,7 @@ define i32 @test52(i32 %x) {

define <2 x i32> @test52_splat_vec(<2 x i32> %x) {
; CHECK-LABEL: @test52_splat_vec(
; CHECK-NEXT:    [[A:%.*]] = shl nsw <2 x i32> %x, <i32 3, i32 3>
; CHECK-NEXT:    [[B:%.*]] = ashr exact <2 x i32> [[A]], <i32 1, i32 1>
; CHECK-NEXT:    [[B:%.*]] = shl nsw <2 x i32> %x, <i32 2, i32 2>
; CHECK-NEXT:    ret <2 x i32> [[B]]
;
  %A = shl nsw <2 x i32> %x, <i32 3, i32 3>