Commit 599d0709 authored by Craig Topper's avatar Craig Topper
Browse files

[X86] Remove dyn_casts to ConstantSDNode for operand 1 of...

[X86] Remove dyn_casts to ConstantSDNode for operand 1 of X86ISD::VSRLI/VSRAI/VSRLI. Use getConstantOperandVal and APInt operations.

These nodes should only ever be formed with an i8 TargetConstant
so we don't need to check for it to be a constant. It's also
always 8-bits so we don't need to use APInt compare functions.
parent 5edb40c0
Loading
Loading
Loading
Loading
+99 −108
Original line number Diff line number Diff line
@@ -32369,14 +32369,13 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
  case X86ISD::VSRAI:
  case X86ISD::VSHLI:
  case X86ISD::VSRLI: {
    if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
      if (ShiftImm->getAPIntValue().uge(VT.getScalarSizeInBits())) {
    unsigned ShAmt = Op.getConstantOperandVal(1);
    if (ShAmt >= VT.getScalarSizeInBits()) {
      Known.setAllZero();
      break;
    }
    Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
      unsigned ShAmt = ShiftImm->getZExtValue();
    if (Opc == X86ISD::VSHLI) {
      Known.Zero <<= ShAmt;
      Known.One <<= ShAmt;
@@ -32391,7 +32390,6 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
      Known.Zero.ashrInPlace(ShAmt);
      Known.One.ashrInPlace(ShAmt);
    }
    }
    break;
  }
  case X86ISD::PACKUS: {
@@ -35656,13 +35654,11 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
  }
  case X86ISD::VSHLI: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op1)) {
      if (ShiftImm->getAPIntValue().uge(BitWidth))
    unsigned ShAmt = Op.getConstantOperandVal(1);
    if (ShAmt >= BitWidth)
      break;
      unsigned ShAmt = ShiftImm->getZExtValue();
    APInt DemandedMask = OriginalDemandedBits.lshr(ShAmt);
    // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
@@ -35670,9 +35666,9 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
    // out) are never demanded.
    if (Op0.getOpcode() == X86ISD::VSRLI &&
        OriginalDemandedBits.countTrailingZeros() >= ShAmt) {
        if (auto *Shift2Imm = dyn_cast<ConstantSDNode>(Op0.getOperand(1))) {
          if (Shift2Imm->getAPIntValue().ult(BitWidth)) {
            int Diff = ShAmt - Shift2Imm->getZExtValue();
      unsigned Shift2Amt = Op0.getConstantOperandVal(1);
      if (Shift2Amt < BitWidth) {
        int Diff = ShAmt - Shift2Amt;
        if (Diff == 0)
          return TLO.CombineTo(Op, Op0.getOperand(0));
@@ -35683,7 +35679,6 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
        return TLO.CombineTo(Op, NewShift);
      }
    }
      }
    if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known,
                             TLO, Depth + 1))
@@ -35695,15 +35690,13 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
    // Low bits known zero.
    Known.Zero.setLowBits(ShAmt);
    }
    break;
  }
  case X86ISD::VSRLI: {
    if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
      if (ShiftImm->getAPIntValue().uge(BitWidth))
    unsigned ShAmt = Op.getConstantOperandVal(1);
    if (ShAmt >= BitWidth)
      break;
      unsigned ShAmt = ShiftImm->getZExtValue();
    APInt DemandedMask = OriginalDemandedBits << ShAmt;
    if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask,
@@ -35716,18 +35709,16 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
    // High bits known zero.
    Known.Zero.setHighBits(ShAmt);
    }
    break;
  }
  case X86ISD::VSRAI: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op1)) {
      if (ShiftImm->getAPIntValue().uge(BitWidth))
    unsigned ShAmt = cast<ConstantSDNode>(Op1)->getZExtValue();
    if (ShAmt >= BitWidth)
      break;
      unsigned ShAmt = ShiftImm->getZExtValue();
    APInt DemandedMask = OriginalDemandedBits << ShAmt;
    // If we just want the sign bit then we don't need to shift it.
@@ -35735,7 +35726,8 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
      return TLO.CombineTo(Op, Op0);
    // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1
      if (Op0.getOpcode() == X86ISD::VSHLI && Op1 == Op0.getOperand(1)) {
    if (Op0.getOpcode() == X86ISD::VSHLI &&
        Op.getOperand(1) == Op0.getOperand(1)) {
      SDValue Op00 = Op0.getOperand(0);
      unsigned NumSignBits =
          TLO.DAG.ComputeNumSignBits(Op00, OriginalDemandedElts);
@@ -35766,7 +35758,6 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
    // High bits are known one.
    if (Known.One[BitWidth - ShAmt - 1])
      Known.One.setHighBits(ShAmt);
    }
    break;
  }
  case X86ISD::PEXTRB:
@@ -39347,15 +39338,15 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG,
  bool LogicalShift = X86ISD::VSHLI == Opcode || X86ISD::VSRLI == Opcode;
  EVT VT = N->getValueType(0);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  unsigned NumBitsPerElt = VT.getScalarSizeInBits();
  assert(VT == N0.getValueType() && (NumBitsPerElt % 8) == 0 &&
         "Unexpected value type");
  assert(N1.getValueType() == MVT::i8 && "Unexpected shift amount type");
  assert(N->getOperand(1).getValueType() == MVT::i8 &&
         "Unexpected shift amount type");
  // Out of range logical bit shifts are guaranteed to be zero.
  // Out of range arithmetic bit shifts splat the sign bit.
  unsigned ShiftVal = cast<ConstantSDNode>(N1)->getZExtValue();
  unsigned ShiftVal = N->getConstantOperandVal(1);
  if (ShiftVal >= NumBitsPerElt) {
    if (LogicalShift)
      return DAG.getConstant(0, SDLoc(N), VT);