Commit db231ebd authored by Sanjay Patel
Browse files

[InstCombine] fold fake vector extract to shift+trunc

We already handle more complicated cases like:
extelt (bitcast (inselt poison, X, 0)) --> trunc (lshr X)

But we missed this simpler pattern:
https://alive2.llvm.org/ce/z/D55h64 / https://alive2.llvm.org/ce/z/GKzzRq

This is part of solving:
https://llvm.org/PR52057

I made the transform depend on a legal/desirable integer type to avoid
creating a shift of an illegal type (for example, i128). I'm not sure whether
that restriction is actually necessary, but we can change it as a follow-up
if the backend can deal with integer ops on too-wide illegal types.

The pile of AVX512 test changes is neutral AFAICT - the x86 backend seems
to know how to turn the new shift+trunc into the expected "kmov" instructions.

Differential Revision: https://reviews.llvm.org/D111082
parent 02e690ba
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -317,6 +317,7 @@ private:

  Value *EmitGEPOffset(User *GEP);
  Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN);
  Instruction *foldBitcastExtElt(ExtractElementInst &ExtElt);
  Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
  Instruction *narrowBinOp(TruncInst &Trunc);
  Instruction *narrowMaskedBinOp(BinaryOperator &And);
+30 −8
Original line number Diff line number Diff line
@@ -178,24 +178,46 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
  return &EI;
}

static Instruction *foldBitcastExtElt(ExtractElementInst &Ext,
                                      InstCombiner::BuilderTy &Builder,
                                      bool IsBigEndian) {
Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
  Value *X;
  uint64_t ExtIndexC;
  if (!match(Ext.getVectorOperand(), m_BitCast(m_Value(X))) ||
      !X->getType()->isVectorTy() ||
      !match(Ext.getIndexOperand(), m_ConstantInt(ExtIndexC)))
    return nullptr;

  ElementCount NumElts =
      cast<VectorType>(Ext.getVectorOperandType())->getElementCount();
  Type *DestTy = Ext.getType();
  bool IsBigEndian = DL.isBigEndian();

  // If we are casting an integer to vector and extracting a portion, that is
  // a shift-right and truncate.
  // TODO: If no shift is needed, allow extra use?
  // TODO: Allow FP dest type by casting the trunc to FP?
  if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() &&
      isDesirableIntType(X->getType()->getPrimitiveSizeInBits()) &&
      Ext.getVectorOperand()->hasOneUse()) {
    assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) &&
           "Expected fixed vector type for bitcast from scalar integer");

    // Big endian requires adjusting the extract index since MSB is at index 0.
    // LittleEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 X to i8
    // BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8
    if (IsBigEndian)
      ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC;
    unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits();
    Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset");
    return new TruncInst(Lshr, DestTy);
  }

  if (!X->getType()->isVectorTy())
    return nullptr;

  // If this extractelement is using a bitcast from a vector of the same number
  // of elements, see if we can find the source element from the source vector:
  // extelt (bitcast VecX), IndexC --> bitcast X[IndexC]
  auto *SrcTy = cast<VectorType>(X->getType());
  Type *DestTy = Ext.getType();
  ElementCount NumSrcElts = SrcTy->getElementCount();
  ElementCount NumElts =
      cast<VectorType>(Ext.getVectorOperandType())->getElementCount();
  if (NumSrcElts == NumElts)
    if (Value *Elt = findScalarElement(X, ExtIndexC))
      return new BitCastInst(Elt, DestTy);
@@ -410,7 +432,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
      }
    }

    if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian()))
    if (Instruction *I = foldBitcastExtElt(EI))
      return I;

    // If there's a vector PHI feeding a scalar use through this extractelement
+138 −138

File changed.

Preview size limit exceeded, changes collapsed.

+138 −138

File changed.

Preview size limit exceeded, changes collapsed.

+40 −12
Original line number Diff line number Diff line
@@ -330,11 +330,17 @@ define <4 x double> @invalid_extractelement(<2 x double> %a, <4 x double> %b, do
  ret <4 x double> %r
}

; i32 is a desirable/supported type independent of data layout.

define i8 @bitcast_scalar_supported_type_index0(i32 %x) {
; ANY-LABEL: @bitcast_scalar_supported_type_index0(
; ANY-NEXT:    [[V:%.*]] = bitcast i32 [[X:%.*]] to <4 x i8>
; ANY-NEXT:    [[R:%.*]] = extractelement <4 x i8> [[V]], i8 0
; ANY-NEXT:    ret i8 [[R]]
; LE-LABEL: @bitcast_scalar_supported_type_index0(
; LE-NEXT:    [[R:%.*]] = trunc i32 [[X:%.*]] to i8
; LE-NEXT:    ret i8 [[R]]
;
; BE-LABEL: @bitcast_scalar_supported_type_index0(
; BE-NEXT:    [[EXTELT_OFFSET:%.*]] = lshr i32 [[X:%.*]], 24
; BE-NEXT:    [[R:%.*]] = trunc i32 [[EXTELT_OFFSET]] to i8
; BE-NEXT:    ret i8 [[R]]
;
  %v = bitcast i32 %x to <4 x i8>
  %r = extractelement <4 x i8> %v, i8 0
@@ -342,27 +348,41 @@ define i8 @bitcast_scalar_supported_type_index0(i32 %x) {
}

; Extracts element 2 (constant index) of the <4 x i8> bitcast of an i32.
; Per the checks below, the little-endian run expects an lshr by 16
; (2 elements * 8 bits) before the trunc, while the big-endian run expects
; an lshr by 8.
define i8 @bitcast_scalar_supported_type_index2(i32 %x) {
; ANY-LABEL: @bitcast_scalar_supported_type_index2(
; ANY-NEXT:    [[V:%.*]] = bitcast i32 [[X:%.*]] to <4 x i8>
; ANY-NEXT:    [[R:%.*]] = extractelement <4 x i8> [[V]], i64 2
; ANY-NEXT:    ret i8 [[R]]
; LE-LABEL: @bitcast_scalar_supported_type_index2(
; LE-NEXT:    [[EXTELT_OFFSET:%.*]] = lshr i32 [[X:%.*]], 16
; LE-NEXT:    [[R:%.*]] = trunc i32 [[EXTELT_OFFSET]] to i8
; LE-NEXT:    ret i8 [[R]]
;
; BE-LABEL: @bitcast_scalar_supported_type_index2(
; BE-NEXT:    [[EXTELT_OFFSET:%.*]] = lshr i32 [[X:%.*]], 8
; BE-NEXT:    [[R:%.*]] = trunc i32 [[EXTELT_OFFSET]] to i8
; BE-NEXT:    ret i8 [[R]]
;
  %v = bitcast i32 %x to <4 x i8>
  %r = extractelement <4 x i8> %v, i64 2
  ret i8 %r
}

; i64 is legal based on data layout.

; Extracts element 3 (constant index) of the <16 x i4> bitcast of an i64.
; Per the checks below, the little-endian run expects an lshr by 12
; (3 elements * 4 bits) before the trunc, while the big-endian run expects
; an lshr by 48.
define i4 @bitcast_scalar_legal_type_index3(i64 %x) {
; ANY-LABEL: @bitcast_scalar_legal_type_index3(
; ANY-NEXT:    [[V:%.*]] = bitcast i64 [[X:%.*]] to <16 x i4>
; ANY-NEXT:    [[R:%.*]] = extractelement <16 x i4> [[V]], i64 3
; ANY-NEXT:    ret i4 [[R]]
; LE-LABEL: @bitcast_scalar_legal_type_index3(
; LE-NEXT:    [[EXTELT_OFFSET:%.*]] = lshr i64 [[X:%.*]], 12
; LE-NEXT:    [[R:%.*]] = trunc i64 [[EXTELT_OFFSET]] to i4
; LE-NEXT:    ret i4 [[R]]
;
; BE-LABEL: @bitcast_scalar_legal_type_index3(
; BE-NEXT:    [[EXTELT_OFFSET:%.*]] = lshr i64 [[X:%.*]], 48
; BE-NEXT:    [[R:%.*]] = trunc i64 [[EXTELT_OFFSET]] to i4
; BE-NEXT:    ret i4 [[R]]
;
  %v = bitcast i64 %x to <16 x i4>
  %r = extractelement <16 x i4> %v, i64 3
  ret i4 %r
}

; negative test - don't create a shift for an illegal type.

define i8 @bitcast_scalar_illegal_type_index1(i128 %x) {
; ANY-LABEL: @bitcast_scalar_illegal_type_index1(
; ANY-NEXT:    [[V:%.*]] = bitcast i128 [[X:%.*]] to <16 x i8>
@@ -374,6 +394,8 @@ define i8 @bitcast_scalar_illegal_type_index1(i128 %x) {
  ret i8 %r
}

; negative test - can't use shift/trunc on FP

define i8 @bitcast_fp_index0(float %x) {
; ANY-LABEL: @bitcast_fp_index0(
; ANY-NEXT:    [[V:%.*]] = bitcast float [[X:%.*]] to <4 x i8>
@@ -385,6 +407,8 @@ define i8 @bitcast_fp_index0(float %x) {
  ret i8 %r
}

; negative test - can't have FP dest type without a cast

define half @bitcast_fpvec_index0(i32 %x) {
; ANY-LABEL: @bitcast_fpvec_index0(
; ANY-NEXT:    [[V:%.*]] = bitcast i32 [[X:%.*]] to <2 x half>
@@ -396,6 +420,8 @@ define half @bitcast_fpvec_index0(i32 %x) {
  ret half %r
}

; negative test - need constant index

define i8 @bitcast_scalar_index_variable(i32 %x, i64 %y) {
; ANY-LABEL: @bitcast_scalar_index_variable(
; ANY-NEXT:    [[V:%.*]] = bitcast i32 [[X:%.*]] to <4 x i8>
@@ -407,6 +433,8 @@ define i8 @bitcast_scalar_index_variable(i32 %x, i64 %y) {
  ret i8 %r
}

; negative test - no extra uses

define i8 @bitcast_scalar_index0_use(i64 %x) {
; ANY-LABEL: @bitcast_scalar_index0_use(
; ANY-NEXT:    [[V:%.*]] = bitcast i64 [[X:%.*]] to <8 x i8>