Commit 619d7dc3 authored by Sanjay Patel

[DAGCombiner] recognize shuffle (shuffle X, Mask0), Mask --> splat X

We get the simple cases of this via demanded elements and other folds,
but that doesn't work if the values have >1 use, so add a dedicated
match for the pattern.

We already have this transform in IR, but it doesn't help the
motivating x86 tests (based on PR42024) because the shuffles don't
exist until after legalization and other combines have happened.
The AArch64 test shows a minimal IR example of the problem.
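
As a rough illustration of the pattern (hypothetical data and masks, not taken
from the patch), applying a shuffle of a shuffle to concrete values shows that
the whole sequence is just a splat of one source element:

#include <array>
#include <cstdio>

// Apply a single-source shuffle mask to a 4-element vector.
static std::array<int, 4> shuffle(const std::array<int, 4> &V,
                                  const std::array<int, 4> &Mask) {
  std::array<int, 4> R{};
  for (int i = 0; i != 4; ++i)
    R[i] = V[Mask[i]];
  return R;
}

int main() {
  std::array<int, 4> X = {10, 20, 30, 40};
  // Inner shuffle swaps adjacent pairs: {20, 10, 40, 30}.
  std::array<int, 4> S1 = shuffle(X, {1, 0, 3, 2});
  // Outer shuffle broadcasts lane 1: {10, 10, 10, 10}.
  std::array<int, 4> S2 = shuffle(S1, {1, 1, 1, 1});
  // S2 is a splat of X[0], so both shuffles can be replaced by a single
  // splat shuffle of X with mask {0, 0, 0, 0}.
  for (int V : S2)
    std::printf("%d ", V);
  std::printf("\n"); // prints: 10 10 10 10
}

Unlike the demanded-elements path, the dedicated combine still fires when the
inner shuffle has other uses.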

Differential Revision: https://reviews.llvm.org/D75348
parent 624dbfcc
+54 −0
@@ -30,6 +30,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
@@ -19259,6 +19260,56 @@ static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
                              NewMask);
}

/// Combine shuffle of shuffle of the form:
/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
                                     SelectionDAG &DAG) {
  if (!OuterShuf->getOperand(1).isUndef())
    return SDValue();
  auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
  if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
    return SDValue();
  ArrayRef<int> OuterMask = OuterShuf->getMask();
  ArrayRef<int> InnerMask = InnerShuf->getMask();
  unsigned NumElts = OuterMask.size();
  assert(NumElts == InnerMask.size() && "Mask length mismatch");
  SmallVector<int, 32> CombinedMask(NumElts, -1);
  int SplatIndex = -1;
  for (unsigned i = 0; i != NumElts; ++i) {
    // Undef lanes remain undef.
    int OuterMaskElt = OuterMask[i];
    if (OuterMaskElt == -1)
      continue;
    // Peek through the shuffle masks to get the underlying source element.
    int InnerMaskElt = InnerMask[OuterMaskElt];
    if (InnerMaskElt == -1)
      continue;
    // Initialize the splatted element.
    if (SplatIndex == -1)
      SplatIndex = InnerMaskElt;
    // Non-matching index - this is not a splat.
    if (SplatIndex != InnerMaskElt)
      return SDValue();
    CombinedMask[i] = InnerMaskElt;
  }
  assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
          getSplatIndex(CombinedMask) != -1) &&
         "Expected a splat mask");
  // TODO: The transform may be a win even if the mask is not legal.
  EVT VT = OuterShuf->getValueType(0);
  assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
  if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
    return SDValue();
  return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
                              InnerShuf->getOperand(1), CombinedMask);
}
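
The core of the new combine is the lane-by-lane composition of the two masks:
for each output lane, the outer mask selects a lane of the inner shuffle, and
the inner mask says which source element that lane really reads. A standalone
sketch of that peek-through, using plain vectors instead of SelectionDAG nodes
(the function and variable names here are made up for illustration):

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Compose OuterMask over InnerMask (both single-source, -1 == undef) and
// return the combined mask if every defined lane reads the same source
// element, i.e. the two shuffles together form a splat. Returns an empty
// vector otherwise.
static std::vector<int> composeToSplatMask(const std::vector<int> &InnerMask,
                                           const std::vector<int> &OuterMask) {
  assert(InnerMask.size() == OuterMask.size() && "Mask length mismatch");
  std::vector<int> Combined(OuterMask.size(), -1);
  int SplatIndex = -1;
  for (std::size_t i = 0; i != OuterMask.size(); ++i) {
    int OuterElt = OuterMask[i];
    if (OuterElt == -1)
      continue; // undef lane stays undef
    int InnerElt = InnerMask[OuterElt]; // peek through the inner shuffle
    if (InnerElt == -1)
      continue;
    if (SplatIndex == -1)
      SplatIndex = InnerElt; // first defined lane fixes the splat element
    else if (SplatIndex != InnerElt)
      return {}; // two different source elements -> not a splat
    Combined[i] = InnerElt;
  }
  return Combined;
}

int main() {
  // Inner shuffle swaps pairs; outer shuffle then broadcasts lane 1.
  // Lane 1 of the inner result is source element 0, so the combined
  // mask is the splat {0, 0, 0, 0}.
  std::vector<int> Splat = composeToSplatMask({1, 0, 3, 2}, {1, 1, 1, 1});
  for (int M : Splat)
    std::printf("%d ", M);
  std::printf("\n"); // prints: 0 0 0 0
}

The real combine additionally rejects the result if the combined mask is not
legal for the target (see the isShuffleMaskLegal check and the TODO above),
which the sketch omits.
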
/// If the shuffle mask is taking exactly one element from the first vector
/// operand and passing through all other elements from the second vector
/// operand, return the index of the mask element that is choosing an element
@@ -19417,6 +19468,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
  if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
    return V;
  if (SDValue V = formSplatFromShuffles(SVN, DAG))
    return V;
  // If it is a splat, check if the argument vector is another splat or a
  // build_vector.
  if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
+0 −2
@@ -449,8 +449,6 @@ define void @disguised_dup(<4 x float> %x, <4 x float>* %p1, <4 x float>* %p2) {
; CHECK-NEXT:    dup.4s v1, v0[0]
; CHECK-NEXT:    ext.16b v0, v0, v0, #12
; CHECK-NEXT:    ext.16b v0, v0, v1, #8
; CHECK-NEXT:    zip2.4s v1, v0, v0
; CHECK-NEXT:    ext.16b v1, v0, v1, #12
; CHECK-NEXT:    str q0, [x0]
; CHECK-NEXT:    str q1, [x1]
; CHECK-NEXT:    ret
+4 −4
@@ -811,7 +811,7 @@ define i32 @test_v4i32(<4 x i32> %a0) {
; SSE2-LABEL: test_v4i32:
; SSE2:       # %bb.0:
; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[3,3,1,1]
; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[3,1,2,3]
; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
; SSE2-NEXT:    pmuludq %xmm2, %xmm3
; SSE2-NEXT:    pmuludq %xmm0, %xmm1
@@ -858,7 +858,7 @@ define i32 @test_v8i32(<8 x i32> %a0) {
; SSE2-NEXT:    pmuludq %xmm1, %xmm0
; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; SSE2-NEXT:    pmuludq %xmm0, %xmm1
; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,2,0,0]
; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,0,2,2]
; SSE2-NEXT:    pmuludq %xmm3, %xmm0
; SSE2-NEXT:    pmuludq %xmm1, %xmm0
; SSE2-NEXT:    movd %xmm0, %eax
@@ -928,7 +928,7 @@ define i32 @test_v16i32(<16 x i32> %a0) {
; SSE2-NEXT:    pmuludq %xmm1, %xmm2
; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; SSE2-NEXT:    pmuludq %xmm0, %xmm1
; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,2,0,0]
; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,0,2,2]
; SSE2-NEXT:    pmuludq %xmm2, %xmm0
; SSE2-NEXT:    pmuludq %xmm1, %xmm0
; SSE2-NEXT:    movd %xmm0, %eax
@@ -1018,7 +1018,7 @@ define i32 @test_v32i32(<32 x i32> %a0) {
; SSE2-NEXT:    pmuludq %xmm0, %xmm1
; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,0,1]
; SSE2-NEXT:    pmuludq %xmm1, %xmm0
; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm11[2,2,0,0]
; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm11[2,0,2,2]
; SSE2-NEXT:    pmuludq %xmm11, %xmm1
; SSE2-NEXT:    pmuludq %xmm0, %xmm1
; SSE2-NEXT:    movd %xmm1, %eax
+46 −56
@@ -1826,45 +1826,40 @@ define void @splat4_v8i32_load_store(<8 x i32>* %s, <32 x i32>* %d) {
define void @splat4_v4f64_load_store(<4 x double>* %s, <16 x double>* %d) {
; AVX1-LABEL: splat4_v4f64_load_store:
; AVX1:       # %bb.0:
; AVX1-NEXT:    vmovupd (%rdi), %ymm0
; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm1 = ymm0[0,1,0,1]
; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm0 = ymm0[2,3,2,3]
; AVX1-NEXT:    vmovddup {{.*#+}} ymm2 = ymm1[0,0,2,2]
; AVX1-NEXT:    vmovddup {{.*#+}} ymm3 = ymm0[0,0,2,2]
; AVX1-NEXT:    vpermilpd {{.*#+}} ymm1 = ymm1[1,1,3,3]
; AVX1-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,1,3,3]
; AVX1-NEXT:    vmovupd %ymm0, 96(%rsi)
; AVX1-NEXT:    vmovupd %ymm3, 64(%rsi)
; AVX1-NEXT:    vmovupd %ymm1, 32(%rsi)
; AVX1-NEXT:    vmovupd %ymm2, (%rsi)
; AVX1-NEXT:    vbroadcastsd (%rdi), %ymm0
; AVX1-NEXT:    vbroadcastsd 16(%rdi), %ymm1
; AVX1-NEXT:    vbroadcastsd 8(%rdi), %ymm2
; AVX1-NEXT:    vbroadcastsd 24(%rdi), %ymm3
; AVX1-NEXT:    vmovups %ymm3, 96(%rsi)
; AVX1-NEXT:    vmovups %ymm1, 64(%rsi)
; AVX1-NEXT:    vmovups %ymm2, 32(%rsi)
; AVX1-NEXT:    vmovups %ymm0, (%rsi)
; AVX1-NEXT:    vzeroupper
; AVX1-NEXT:    retq
;
; AVX2-LABEL: splat4_v4f64_load_store:
; AVX2:       # %bb.0:
; AVX2-NEXT:    vmovups (%rdi), %ymm0
; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm1
; AVX2-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
; AVX2-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
; AVX2-NEXT:    vmovups %ymm0, 96(%rsi)
; AVX2-NEXT:    vmovups %ymm2, 64(%rsi)
; AVX2-NEXT:    vmovups %ymm3, 32(%rsi)
; AVX2-NEXT:    vmovups %ymm1, (%rsi)
; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
; AVX2-NEXT:    vbroadcastsd 16(%rdi), %ymm1
; AVX2-NEXT:    vbroadcastsd 8(%rdi), %ymm2
; AVX2-NEXT:    vbroadcastsd 24(%rdi), %ymm3
; AVX2-NEXT:    vmovups %ymm3, 96(%rsi)
; AVX2-NEXT:    vmovups %ymm1, 64(%rsi)
; AVX2-NEXT:    vmovups %ymm2, 32(%rsi)
; AVX2-NEXT:    vmovups %ymm0, (%rsi)
; AVX2-NEXT:    vzeroupper
; AVX2-NEXT:    retq
;
; AVX512-LABEL: splat4_v4f64_load_store:
; AVX512:       # %bb.0:
; AVX512-NEXT:    vmovups (%rdi), %ymm0
; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm1
; AVX512-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
; AVX512-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
; AVX512-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
; AVX512-NEXT:    vbroadcastsd 16(%rdi), %ymm1
; AVX512-NEXT:    vbroadcastsd 8(%rdi), %ymm2
; AVX512-NEXT:    vbroadcastsd 24(%rdi), %ymm3
; AVX512-NEXT:    vinsertf64x4 $1, %ymm2, %zmm0, %zmm0
; AVX512-NEXT:    vinsertf64x4 $1, %ymm3, %zmm1, %zmm1
; AVX512-NEXT:    vinsertf64x4 $1, %ymm0, %zmm2, %zmm0
; AVX512-NEXT:    vmovups %zmm0, 64(%rsi)
; AVX512-NEXT:    vmovups %zmm1, (%rsi)
; AVX512-NEXT:    vmovups %zmm1, 64(%rsi)
; AVX512-NEXT:    vmovups %zmm0, (%rsi)
; AVX512-NEXT:    vzeroupper
; AVX512-NEXT:    retq
  %x = load <4 x double>, <4 x double>* %s, align 8
@@ -1878,45 +1873,40 @@ define void @splat4_v4f64_load_store(<4 x double>* %s, <16 x double>* %d) {
define void @splat4_v4i64_load_store(<4 x i64>* %s, <16 x i64>* %d) {
; AVX1-LABEL: splat4_v4i64_load_store:
; AVX1:       # %bb.0:
; AVX1-NEXT:    vmovupd (%rdi), %ymm0
; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm1 = ymm0[0,1,0,1]
; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm0 = ymm0[2,3,2,3]
; AVX1-NEXT:    vmovddup {{.*#+}} ymm2 = ymm1[0,0,2,2]
; AVX1-NEXT:    vmovddup {{.*#+}} ymm3 = ymm0[0,0,2,2]
; AVX1-NEXT:    vpermilpd {{.*#+}} ymm1 = ymm1[1,1,3,3]
; AVX1-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,1,3,3]
; AVX1-NEXT:    vmovupd %ymm0, 96(%rsi)
; AVX1-NEXT:    vmovupd %ymm3, 64(%rsi)
; AVX1-NEXT:    vmovupd %ymm1, 32(%rsi)
; AVX1-NEXT:    vmovupd %ymm2, (%rsi)
; AVX1-NEXT:    vbroadcastsd (%rdi), %ymm0
; AVX1-NEXT:    vbroadcastsd 16(%rdi), %ymm1
; AVX1-NEXT:    vbroadcastsd 8(%rdi), %ymm2
; AVX1-NEXT:    vbroadcastsd 24(%rdi), %ymm3
; AVX1-NEXT:    vmovups %ymm3, 96(%rsi)
; AVX1-NEXT:    vmovups %ymm1, 64(%rsi)
; AVX1-NEXT:    vmovups %ymm2, 32(%rsi)
; AVX1-NEXT:    vmovups %ymm0, (%rsi)
; AVX1-NEXT:    vzeroupper
; AVX1-NEXT:    retq
;
; AVX2-LABEL: splat4_v4i64_load_store:
; AVX2:       # %bb.0:
; AVX2-NEXT:    vmovups (%rdi), %ymm0
; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm1
; AVX2-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
; AVX2-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
; AVX2-NEXT:    vmovups %ymm0, 96(%rsi)
; AVX2-NEXT:    vmovups %ymm2, 64(%rsi)
; AVX2-NEXT:    vmovups %ymm3, 32(%rsi)
; AVX2-NEXT:    vmovups %ymm1, (%rsi)
; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
; AVX2-NEXT:    vbroadcastsd 16(%rdi), %ymm1
; AVX2-NEXT:    vbroadcastsd 8(%rdi), %ymm2
; AVX2-NEXT:    vbroadcastsd 24(%rdi), %ymm3
; AVX2-NEXT:    vmovups %ymm3, 96(%rsi)
; AVX2-NEXT:    vmovups %ymm1, 64(%rsi)
; AVX2-NEXT:    vmovups %ymm2, 32(%rsi)
; AVX2-NEXT:    vmovups %ymm0, (%rsi)
; AVX2-NEXT:    vzeroupper
; AVX2-NEXT:    retq
;
; AVX512-LABEL: splat4_v4i64_load_store:
; AVX512:       # %bb.0:
; AVX512-NEXT:    vmovups (%rdi), %ymm0
; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm1
; AVX512-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
; AVX512-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
; AVX512-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
; AVX512-NEXT:    vbroadcastsd 16(%rdi), %ymm1
; AVX512-NEXT:    vbroadcastsd 8(%rdi), %ymm2
; AVX512-NEXT:    vbroadcastsd 24(%rdi), %ymm3
; AVX512-NEXT:    vinsertf64x4 $1, %ymm2, %zmm0, %zmm0
; AVX512-NEXT:    vinsertf64x4 $1, %ymm3, %zmm1, %zmm1
; AVX512-NEXT:    vinsertf64x4 $1, %ymm0, %zmm2, %zmm0
; AVX512-NEXT:    vmovups %zmm0, 64(%rsi)
; AVX512-NEXT:    vmovups %zmm1, (%rsi)
; AVX512-NEXT:    vmovups %zmm1, 64(%rsi)
; AVX512-NEXT:    vmovups %zmm0, (%rsi)
; AVX512-NEXT:    vzeroupper
; AVX512-NEXT:    retq
  %x = load <4 x i64>, <4 x i64>* %s, align 8