Commit ebb181cf authored by Simon Pilgrim's avatar Simon Pilgrim
Browse files

[X86] matchScalarReduction - add support for partial reductions

Add optional support for opt-in partial reduction cases by providing an optional partial mask to indicate which elements have been extracted for the scalar reduction.
parent 2e773626
Loading
Loading
Loading
Loading
+32 −16
Original line number Diff line number Diff line
@@ -20964,9 +20964,12 @@ static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
}
/// Helper for matching OR(EXTRACTELT(X,0),OR(EXTRACTELT(X,1),...))
/// style scalarized (associative) reduction patterns.
/// style scalarized (associative) reduction patterns. Partial reductions
/// are supported when the pointer SrcMask is non-null.
/// TODO - move this to SelectionDAG?
static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
                                 SmallVectorImpl<SDValue> &SrcOps) {
                                 SmallVectorImpl<SDValue> &SrcOps,
                                 SmallVectorImpl<APInt> *SrcMask = nullptr) {
  SmallVector<SDValue, 8> Opnds;
  DenseMap<SDValue, APInt> SrcOpMap;
  EVT VT = MVT::Other;
@@ -21018,6 +21021,11 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
    M->second.setBit(CIdx);
  }
  if (SrcMask) {
    // Collect the source partial masks.
    for (SDValue &SrcOp : SrcOps)
      SrcMask->push_back(SrcOpMap[SrcOp]);
  } else {
    // Quit if not all elements are used.
    for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
                                                  E = SrcOpMap.end();
@@ -21025,6 +21033,7 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
      if (!I->second.isAllOnesValue())
        return false;
    }
  }
  return true;
}
@@ -41210,7 +41219,8 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
  // TODO: Support multiple SrcOps.
  if (VT == MVT::i1) {
    SmallVector<SDValue, 2> SrcOps;
    if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps) &&
    SmallVector<APInt, 2> SrcPartials;
    if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps, &SrcPartials) &&
        SrcOps.size() == 1) {
      SDLoc dl(N);
      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41220,9 +41230,11 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
      if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
        Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
      if (Mask) {
        APInt AllBits = APInt::getAllOnesValue(NumElts);
        return DAG.getSetCC(dl, MVT::i1, Mask,
                            DAG.getConstant(AllBits, dl, MaskVT), ISD::SETEQ);
        assert(SrcPartials[0].getBitWidth() == NumElts &&
               "Unexpected partial reduction mask");
        SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
        Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
        return DAG.getSetCC(dl, MVT::i1, Mask, PartialBits, ISD::SETEQ);
      }
    }
  }
@@ -41685,7 +41697,8 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
  // TODO: Support multiple SrcOps.
  if (VT == MVT::i1) {
    SmallVector<SDValue, 2> SrcOps;
    if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps) &&
    SmallVector<APInt, 2> SrcPartials;
    if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps, &SrcPartials) &&
        SrcOps.size() == 1) {
      SDLoc dl(N);
      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41695,9 +41708,12 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
      if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
        Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
      if (Mask) {
        APInt AllBits = APInt::getNullValue(NumElts);
        return DAG.getSetCC(dl, MVT::i1, Mask,
                            DAG.getConstant(AllBits, dl, MaskVT), ISD::SETNE);
        assert(SrcPartials[0].getBitWidth() == NumElts &&
               "Unexpected partial reduction mask");
        SDValue ZeroBits = DAG.getConstant(0, dl, MaskVT);
        SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
        Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
        return DAG.getSetCC(dl, MVT::i1, Mask, ZeroBits, ISD::SETNE);
      }
    }
  }
+22 −63
Original line number Diff line number Diff line
@@ -4225,40 +4225,25 @@ define i1 @movmsk_v16i8(<16 x i8> %x, <16 x i8> %y) {
  ret i1 %u2
}

; TODO: Replace shift+mask chain with NOT+TEST+SETE
define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
; SSE2-LABEL: movmsk_v8i16:
; SSE2:       # %bb.0:
; SSE2-NEXT:    pcmpgtw %xmm1, %xmm0
; SSE2-NEXT:    packsswb %xmm0, %xmm0
; SSE2-NEXT:    pmovmskb %xmm0, %ecx
; SSE2-NEXT:    movl %ecx, %eax
; SSE2-NEXT:    shrb $7, %al
; SSE2-NEXT:    movl %ecx, %edx
; SSE2-NEXT:    andb $16, %dl
; SSE2-NEXT:    shrb $4, %dl
; SSE2-NEXT:    andb %al, %dl
; SSE2-NEXT:    movl %ecx, %eax
; SSE2-NEXT:    shrb %al
; SSE2-NEXT:    andb %dl, %al
; SSE2-NEXT:    andb %cl, %al
; SSE2-NEXT:    pmovmskb %xmm0, %eax
; SSE2-NEXT:    andb $-109, %al
; SSE2-NEXT:    cmpb $-109, %al
; SSE2-NEXT:    sete %al
; SSE2-NEXT:    retq
;
; AVX-LABEL: movmsk_v8i16:
; AVX:       # %bb.0:
; AVX-NEXT:    vpcmpgtw %xmm1, %xmm0, %xmm0
; AVX-NEXT:    vpacksswb %xmm0, %xmm0, %xmm0
; AVX-NEXT:    vpmovmskb %xmm0, %ecx
; AVX-NEXT:    movl %ecx, %eax
; AVX-NEXT:    shrb $7, %al
; AVX-NEXT:    movl %ecx, %edx
; AVX-NEXT:    andb $16, %dl
; AVX-NEXT:    shrb $4, %dl
; AVX-NEXT:    andb %al, %dl
; AVX-NEXT:    movl %ecx, %eax
; AVX-NEXT:    shrb %al
; AVX-NEXT:    andb %dl, %al
; AVX-NEXT:    andb %cl, %al
; AVX-NEXT:    vpmovmskb %xmm0, %eax
; AVX-NEXT:    andb $-109, %al
; AVX-NEXT:    cmpb $-109, %al
; AVX-NEXT:    sete %al
; AVX-NEXT:    retq
;
; KNL-LABEL: movmsk_v8i16:
@@ -4266,34 +4251,20 @@ define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
; KNL-NEXT:    vpcmpgtw %xmm1, %xmm0, %xmm0
; KNL-NEXT:    vpmovsxwq %xmm0, %zmm0
; KNL-NEXT:    vptestmq %zmm0, %zmm0, %k0
; KNL-NEXT:    kshiftrw $4, %k0, %k1
; KNL-NEXT:    kmovw %k1, %ecx
; KNL-NEXT:    kshiftrw $7, %k0, %k1
; KNL-NEXT:    kmovw %k1, %eax
; KNL-NEXT:    kshiftrw $1, %k0, %k1
; KNL-NEXT:    kmovw %k1, %edx
; KNL-NEXT:    kmovw %k0, %esi
; KNL-NEXT:    andb %cl, %al
; KNL-NEXT:    andb %dl, %al
; KNL-NEXT:    andb %sil, %al
; KNL-NEXT:    # kill: def $al killed $al killed $eax
; KNL-NEXT:    kmovw %k0, %eax
; KNL-NEXT:    andb $-109, %al
; KNL-NEXT:    cmpb $-109, %al
; KNL-NEXT:    sete %al
; KNL-NEXT:    vzeroupper
; KNL-NEXT:    retq
;
; SKX-LABEL: movmsk_v8i16:
; SKX:       # %bb.0:
; SKX-NEXT:    vpcmpgtw %xmm1, %xmm0, %k0
; SKX-NEXT:    kshiftrb $4, %k0, %k1
; SKX-NEXT:    kmovd %k1, %ecx
; SKX-NEXT:    kshiftrb $7, %k0, %k1
; SKX-NEXT:    kmovd %k1, %eax
; SKX-NEXT:    kshiftrb $1, %k0, %k1
; SKX-NEXT:    kmovd %k1, %edx
; SKX-NEXT:    kmovd %k0, %esi
; SKX-NEXT:    andb %cl, %al
; SKX-NEXT:    andb %dl, %al
; SKX-NEXT:    andb %sil, %al
; SKX-NEXT:    # kill: def $al killed $al killed $eax
; SKX-NEXT:    kmovd %k0, %eax
; SKX-NEXT:    andb $-109, %al
; SKX-NEXT:    cmpb $-109, %al
; SKX-NEXT:    sete %al
; SKX-NEXT:    retq
  %cmp = icmp sgt <8 x i16> %x, %y
  %e1 = extractelement <8 x i1> %cmp, i32 0
@@ -4478,30 +4449,18 @@ define i1 @movmsk_v4f32(<4 x float> %x, <4 x float> %y) {
; KNL-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
; KNL-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
; KNL-NEXT:    vcmpeq_uqps %zmm1, %zmm0, %k0
; KNL-NEXT:    kshiftrw $3, %k0, %k1
; KNL-NEXT:    kmovw %k1, %ecx
; KNL-NEXT:    kshiftrw $2, %k0, %k1
; KNL-NEXT:    kmovw %k1, %eax
; KNL-NEXT:    kshiftrw $1, %k0, %k0
; KNL-NEXT:    kmovw %k0, %edx
; KNL-NEXT:    orb %cl, %al
; KNL-NEXT:    orb %dl, %al
; KNL-NEXT:    # kill: def $al killed $al killed $eax
; KNL-NEXT:    kmovw %k0, %eax
; KNL-NEXT:    testb $14, %al
; KNL-NEXT:    setne %al
; KNL-NEXT:    vzeroupper
; KNL-NEXT:    retq
;
; SKX-LABEL: movmsk_v4f32:
; SKX:       # %bb.0:
; SKX-NEXT:    vcmpeq_uqps %xmm1, %xmm0, %k0
; SKX-NEXT:    kshiftrb $3, %k0, %k1
; SKX-NEXT:    kmovd %k1, %ecx
; SKX-NEXT:    kshiftrb $2, %k0, %k1
; SKX-NEXT:    kmovd %k1, %eax
; SKX-NEXT:    kshiftrb $1, %k0, %k0
; SKX-NEXT:    kmovd %k0, %edx
; SKX-NEXT:    orb %cl, %al
; SKX-NEXT:    orb %dl, %al
; SKX-NEXT:    # kill: def $al killed $al killed $eax
; SKX-NEXT:    kmovd %k0, %eax
; SKX-NEXT:    testb $14, %al
; SKX-NEXT:    setne %al
; SKX-NEXT:    retq
  %cmp = fcmp ueq <4 x float> %x, %y
  %e1 = extractelement <4 x i1> %cmp, i32 1