Commit 4a8d6b3b authored by Simon Pilgrim's avatar Simon Pilgrim
Browse files

[X86][SSE] Use the general SMAX/SMIN/UMAX/UMIN pattern matching and remove the X86 implementation

Follow up to D10947 - D9746 added general SMAX/SMIN/UMAX/UMIN pattern matching to SelectionDAGBuilder::visitSelect.

This patch removes the X86 implementation and improves the AVX1/AVX2 support to correctly lower 256-bit integer vectors.

Differential Revision: http://reviews.llvm.org/D12006

llvm-svn: 244949
parent 361231ec
Loading
Loading
Loading
Loading
+24 −116
Original line number Diff line number Diff line
@@ -1198,6 +1198,19 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
      setOperationAction(ISD::MUL,             MVT::v8i32, Custom);
      setOperationAction(ISD::MUL,             MVT::v16i16, Custom);
      setOperationAction(ISD::MUL,             MVT::v32i8, Custom);
      setOperationAction(ISD::SMAX,            MVT::v32i8,  Custom);
      setOperationAction(ISD::SMAX,            MVT::v16i16, Custom);
      setOperationAction(ISD::SMAX,            MVT::v8i32,  Custom);
      setOperationAction(ISD::UMAX,            MVT::v32i8,  Custom);
      setOperationAction(ISD::UMAX,            MVT::v16i16, Custom);
      setOperationAction(ISD::UMAX,            MVT::v8i32,  Custom);
      setOperationAction(ISD::SMIN,            MVT::v32i8,  Custom);
      setOperationAction(ISD::SMIN,            MVT::v16i16, Custom);
      setOperationAction(ISD::SMIN,            MVT::v8i32,  Custom);
      setOperationAction(ISD::UMIN,            MVT::v32i8,  Custom);
      setOperationAction(ISD::UMIN,            MVT::v16i16, Custom);
      setOperationAction(ISD::UMIN,            MVT::v8i32,  Custom);
    }
    // In the customized shift lowering, the legal cases in AVX2 will be
@@ -16881,6 +16894,13 @@ static SDValue LowerSUB(SDValue Op, SelectionDAG &DAG) {
  return Lower256IntArith(Op, DAG);
}
static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) {
  assert(Op.getSimpleValueType().is256BitVector() &&
         Op.getSimpleValueType().isInteger() &&
         "Only handle AVX 256-bit vector integer operation");
  return Lower256IntArith(Op, DAG);
}
static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget,
                        SelectionDAG &DAG) {
  SDLoc dl(Op);
@@ -18772,6 +18792,10 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  case ISD::SUBE:               return LowerADDC_ADDE_SUBC_SUBE(Op, DAG);
  case ISD::ADD:                return LowerADD(Op, DAG);
  case ISD::SUB:                return LowerSUB(Op, DAG);
  case ISD::SMAX:
  case ISD::SMIN:
  case ISD::UMAX:
  case ISD::UMIN:               return LowerMINMAX(Op, DAG);
  case ISD::FSINCOS:            return LowerFSINCOS(Op, Subtarget, DAG);
  case ISD::MGATHER:            return LowerMGATHER(Op, Subtarget, DAG);
  case ISD::MSCATTER:           return LowerMSCATTER(Op, Subtarget, DAG);
@@ -22365,96 +22389,6 @@ static SDValue PerformEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
  return SDValue();
}
/// \brief Matches a VSELECT onto min/max or return 0 if the node doesn't match.
static std::pair<unsigned, bool>
matchIntegerMINMAX(SDValue Cond, EVT VT, SDValue LHS, SDValue RHS,
                   SelectionDAG &DAG, const X86Subtarget *Subtarget) {
  if (!VT.isVector())
    return std::make_pair(0, false);
  bool NeedSplit = false;
  switch (VT.getSimpleVT().SimpleTy) {
  default: return std::make_pair(0, false);
  case MVT::v4i64:
  case MVT::v2i64:
    if (!Subtarget->hasVLX())
      return std::make_pair(0, false);
    break;
  case MVT::v64i8:
  case MVT::v32i16:
    if (!Subtarget->hasBWI())
      return std::make_pair(0, false);
    break;
  case MVT::v16i32:
  case MVT::v8i64:
    if (!Subtarget->hasAVX512())
      return std::make_pair(0, false);
    break;
  case MVT::v32i8:
  case MVT::v16i16:
  case MVT::v8i32:
    if (!Subtarget->hasAVX2())
      NeedSplit = true;
    if (!Subtarget->hasAVX())
      return std::make_pair(0, false);
    break;
  case MVT::v16i8:
  case MVT::v8i16:
  case MVT::v4i32:
    if (!Subtarget->hasSSE2())
      return std::make_pair(0, false);
  }
  // SSE2 has only a small subset of the operations.
  bool hasUnsigned = Subtarget->hasSSE41() ||
                     (Subtarget->hasSSE2() && VT == MVT::v16i8);
  bool hasSigned = Subtarget->hasSSE41() ||
                   (Subtarget->hasSSE2() && VT == MVT::v8i16);
  ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
  unsigned Opc = 0;
  // Check for x CC y ? x : y.
  if (DAG.isEqualTo(LHS, Cond.getOperand(0)) &&
      DAG.isEqualTo(RHS, Cond.getOperand(1))) {
    switch (CC) {
    default: break;
    case ISD::SETULT:
    case ISD::SETULE:
      Opc = hasUnsigned ? ISD::UMIN : 0; break;
    case ISD::SETUGT:
    case ISD::SETUGE:
      Opc = hasUnsigned ? ISD::UMAX : 0; break;
    case ISD::SETLT:
    case ISD::SETLE:
      Opc = hasSigned ? ISD::SMIN : 0; break;
    case ISD::SETGT:
    case ISD::SETGE:
      Opc = hasSigned ? ISD::SMAX : 0; break;
    }
  // Check for x CC y ? y : x -- a min/max with reversed arms.
  } else if (DAG.isEqualTo(LHS, Cond.getOperand(1)) &&
             DAG.isEqualTo(RHS, Cond.getOperand(0))) {
    switch (CC) {
    default: break;
    case ISD::SETULT:
    case ISD::SETULE:
      Opc = hasUnsigned ? ISD::UMAX : 0; break;
    case ISD::SETUGT:
    case ISD::SETUGE:
      Opc = hasUnsigned ? ISD::UMIN : 0; break;
    case ISD::SETLT:
    case ISD::SETLE:
      Opc = hasSigned ? ISD::SMAX : 0; break;
    case ISD::SETGT:
    case ISD::SETGE:
      Opc = hasSigned ? ISD::SMIN : 0; break;
    }
  }
  return std::make_pair(Opc, NeedSplit);
}
static SDValue
transformVSELECTtoBlendVECTOR_SHUFFLE(SDNode *N, SelectionDAG &DAG,
                                      const X86Subtarget *Subtarget) {
@@ -22864,32 +22798,6 @@ static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
    }
  }
  // Try to match a min/max vector operation.
  if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC) {
    std::pair<unsigned, bool> ret = matchIntegerMINMAX(Cond, VT, LHS, RHS, DAG, Subtarget);
    unsigned Opc = ret.first;
    bool NeedSplit = ret.second;
    if (Opc && NeedSplit) {
      unsigned NumElems = VT.getVectorNumElements();
      // Extract the LHS vectors
      SDValue LHS1 = Extract128BitVector(LHS, 0, DAG, DL);
      SDValue LHS2 = Extract128BitVector(LHS, NumElems/2, DAG, DL);
      // Extract the RHS vectors
      SDValue RHS1 = Extract128BitVector(RHS, 0, DAG, DL);
      SDValue RHS2 = Extract128BitVector(RHS, NumElems/2, DAG, DL);
      // Create min/max for each subvector
      LHS = DAG.getNode(Opc, DL, LHS1.getValueType(), LHS1, RHS1);
      RHS = DAG.getNode(Opc, DL, LHS2.getValueType(), LHS2, RHS2);
      // Merge the result
      return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LHS, RHS);
    } else if (Opc)
      return DAG.getNode(Opc, DL, VT, LHS, RHS);
  }
  // Simplify vector selection if condition value type matches vselect
  // operand type
  if (N->getOpcode() == ISD::VSELECT && CondVT == VT) {