Unverified Commit fd311126 authored by Florian Hahn's avatar Florian Hahn
Browse files

[VPlan] Insert Trunc/Exts for reductions directly in VPlan.

Update the code to create Trunc/Ext recipes directly in
adjustRecipesForReductions instead of fixing it up later in
fixReductions.

This explicitly models the required conversions and also makes sure they
are generated at the right place (instead of after the exit condition),
hence the changes in a few tests.
parent dd64c82c
Loading
Loading
Loading
Loading
+35 −32
Original line number Diff line number Diff line
@@ -3792,8 +3792,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
    State.setDebugLocFrom(I->getDebugLoc());

  VPValue *LoopExitInstDef = PhiR->getBackedgeValue();
  // This is the vector-clone of the value that leaves the loop.
  Type *VecTy = State.get(LoopExitInstDef, 0)->getType();

  // Before each round, move the insertion point right between
  // the PHIs and the values we are going to write.
@@ -3805,10 +3803,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
  State.setDebugLocFrom(LoopExitInst->getDebugLoc());

  Type *PhiTy = OrigPhi->getType();

  VPBasicBlock *LatchVPBB =
      PhiR->getParent()->getEnclosingLoopRegion()->getExitingBasicBlock();
  BasicBlock *VectorLoopLatch = State.CFG.VPBB2IRBB[LatchVPBB];
  // If tail is folded by masking, the vector value to leave the loop should be
  // a Select choosing between the vectorized LoopExitInst and vectorized Phi,
  // instead of the former. For an inloop reduction the reduction will already
@@ -3834,24 +3828,13 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
  // then extend the loop exit value to enable InstCombine to evaluate the
  // entire expression in the smaller type.
  if (VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
    assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
    Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
    Builder.SetInsertPoint(VectorLoopLatch->getTerminator());
    for (unsigned Part = 0; Part < UF; ++Part) {
      Value *Trunc = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
      Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy)
                                        : Builder.CreateZExt(Trunc, VecTy);
      for (User *U : llvm::make_early_inc_range(RdxParts[Part]->users()))
        if (U != Trunc) {
          U->replaceUsesOfWith(RdxParts[Part], Extnd);
          RdxParts[Part] = Extnd;
        }
    }
    Builder.SetInsertPoint(LoopMiddleBlock,
                           LoopMiddleBlock->getFirstInsertionPt());
    for (unsigned Part = 0; Part < UF; ++Part)
    Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
    for (unsigned Part = 0; Part < UF; ++Part) {
      RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
    }
  }

  // Reduce all of the unrolled parts into a single vector.
  Value *ReducedPartRdx = RdxParts[0];
@@ -9155,18 +9138,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
      PreviousLink = RedRecipe;
    }
  }

  // If tail is folded by masking, introduce selects between the phi
  // and the live-out instruction of each reduction, at the beginning of the
  // dedicated latch block.
  if (CM.foldTailByMasking()) {
    Builder.setInsertPoint(&*LatchVPBB->begin());
    for (VPRecipeBase &R :
         Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
    VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
    if (!PhiR || PhiR->isInLoop())
      continue;

    const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
    auto *Result = PhiR->getBackedgeValue()->getDefiningRecipe();
    // If tail is folded by masking, introduce selects between the phi
    // and the live-out instruction of each reduction, at the beginning of the
    // dedicated latch block.
    if (CM.foldTailByMasking()) {
      VPValue *Cond =
          RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), *Plan);
      VPValue *Red = PhiR->getBackedgeValue();
@@ -9174,16 +9158,35 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
             "reduction recipe must be defined before latch");
      FastMathFlags FMFs = RdxDesc.getFastMathFlags();
      Type *PhiTy = PhiR->getOperand(0)->getLiveInIRValue()->getType();
      auto *Select =
      Result =
          PhiTy->isFloatingPointTy()
              ? new VPInstruction(Instruction::Select, {Cond, Red, PhiR}, FMFs)
              : new VPInstruction(Instruction::Select, {Cond, Red, PhiR});
      Select->insertBefore(&*Builder.getInsertPoint());
      Result->insertBefore(&*Builder.getInsertPoint());
      if (PreferPredicatedReductionSelect ||
          TTI.preferPredicatedReductionSelect(
              PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy,
              TargetTransformInfo::ReductionFlags()))
        PhiR->setOperand(1, Select);
        PhiR->setOperand(1, Result->getVPSingleValue());
    }
    // If the vector reduction can be performed in a smaller type, we truncate
    // then extend the loop exit value to enable InstCombine to evaluate the
    // entire expression in the smaller type.
    Type *PhiTy = PhiR->getStartValue()->getLiveInIRValue()->getType();
    if (PhiTy != RdxDesc.getRecurrenceType()) {
      assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
      Type *RdxTy = RdxDesc.getRecurrenceType();
      auto *Trunc = new VPWidenCastRecipe(Instruction::Trunc,
                                          Result->getVPSingleValue(), RdxTy);
      auto *Extnd =
          RdxDesc.isSigned()
              ? new VPWidenCastRecipe(Instruction::SExt, Trunc, PhiTy)
              : new VPWidenCastRecipe(Instruction::ZExt, Trunc, PhiTy);

      Trunc->insertAfter(Result);
      Extnd->insertAfter(Trunc);
      Result->getVPSingleValue()->replaceAllUsesWith(Extnd);
      Trunc->setOperand(0, Result->getVPSingleValue());
    }
  }

+4 −4
Original line number Diff line number Diff line
@@ -207,10 +207,10 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP3]], align 2
; CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
; CHECK-NEXT:    [[TMP5:%.*]] = or <4 x i32> [[TMP1]], [[TMP4]]
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
; CHECK-NEXT:    [[TMP7:%.*]] = trunc <4 x i32> [[TMP5]] to <4 x i16>
; CHECK-NEXT:    [[TMP8]] = zext <4 x i16> [[TMP7]] to <4 x i32>
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
; CHECK-NEXT:    br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK:       middle.block:
; CHECK-NEXT:    [[TMP9:%.*]] = trunc <4 x i32> [[TMP8]] to <4 x i16>
@@ -234,10 +234,10 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <4 x i16>, ptr [[TMP16]], align 2
; CHECK-NEXT:    [[TMP17:%.*]] = zext <4 x i16> [[WIDE_LOAD4]] to <4 x i32>
; CHECK-NEXT:    [[TMP18:%.*]] = or <4 x i32> [[TMP14]], [[TMP17]]
; CHECK-NEXT:    [[INDEX_NEXT5]] = add nuw i32 [[INDEX2]], 4
; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i32 [[INDEX_NEXT5]], 256
; CHECK-NEXT:    [[TMP20:%.*]] = trunc <4 x i32> [[TMP18]] to <4 x i16>
; CHECK-NEXT:    [[TMP21]] = zext <4 x i16> [[TMP20]] to <4 x i32>
; CHECK-NEXT:    [[INDEX_NEXT5]] = add nuw i32 [[INDEX2]], 4
; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i32 [[INDEX_NEXT5]], 256
; CHECK-NEXT:    br i1 [[TMP19]], label [[VEC_EPILOG_MIDDLE_BLOCK:%.*]], label [[VEC_EPILOG_VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
; CHECK:       vec.epilog.middle.block:
; CHECK-NEXT:    [[TMP22:%.*]] = trunc <4 x i32> [[TMP21]] to <4 x i16>
+4 −4
Original line number Diff line number Diff line
@@ -22,10 +22,10 @@ define i8 @PR34687(i1 %c, i32 %x, i32 %n) {
; CHECK-NEXT:    [[TMP0:%.*]] = select <4 x i1> [[BROADCAST_SPLAT]], <4 x i32> undef, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[VEC_PHI]], <i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT:    [[TMP2:%.*]] = add <4 x i32> [[TMP1]], [[BROADCAST_SPLAT2]]
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT:    [[TMP4:%.*]] = trunc <4 x i32> [[TMP2]] to <4 x i8>
; CHECK-NEXT:    [[TMP5]] = zext <4 x i8> [[TMP4]] to <4 x i32>
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT:    br i1 [[TMP3]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK:       middle.block:
; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i32> [[TMP5]] to <4 x i8>
@@ -99,10 +99,10 @@ define i32 @PR35734(i32 %x, i32 %y) {
; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ [[TMP2]], [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT:    [[TMP3:%.*]] = and <4 x i32> [[VEC_PHI]], <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT:    [[TMP4:%.*]] = add <4 x i32> [[TMP3]], <i32 -1, i32 -1, i32 -1, i32 -1>
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i32> [[TMP4]] to <4 x i1>
; CHECK-NEXT:    [[TMP7]] = sext <4 x i1> [[TMP6]] to <4 x i32>
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT:    br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK:       middle.block:
; CHECK-NEXT:    [[TMP8:%.*]] = trunc <4 x i32> [[TMP7]] to <4 x i1>
+4 −4
Original line number Diff line number Diff line
@@ -17,14 +17,14 @@ define i8 @reduction_add_trunc(ptr noalias nocapture %A) {
; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
; CHECK-NEXT:    [[TMP28:%.*]] = add <vscale x 8 x i32> [[TMP14]], [[TMP26]]
; CHECK-NEXT:    [[TMP29:%.*]] = add <vscale x 8 x i32> [[TMP15]], [[TMP27]]
; CHECK-NEXT:    [[TMP33:%.*]] = trunc <vscale x 8 x i32> [[TMP28]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP35:%.*]] = trunc <vscale x 8 x i32> [[TMP29]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP34]] = zext <vscale x 8 x i8> [[TMP33]] to <vscale x 8 x i32>
; CHECK-NEXT:    [[TMP36]] = zext <vscale x 8 x i8> [[TMP35]] to <vscale x 8 x i32>
; CHECK-NEXT:    [[TMP30:%.*]] = call i32 @llvm.vscale.i32()
; CHECK-NEXT:    [[TMP31:%.*]] = mul i32 [[TMP30]], 16
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP31]]
; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i32 [[INDEX_NEXT]], {{%.*}}
; CHECK-NEXT:    [[TMP33:%.*]] = trunc <vscale x 8 x i32> [[TMP28]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP34]] = zext <vscale x 8 x i8> [[TMP33]] to <vscale x 8 x i32>
; CHECK-NEXT:    [[TMP35:%.*]] = trunc <vscale x 8 x i32> [[TMP29]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP36]] = zext <vscale x 8 x i8> [[TMP35]] to <vscale x 8 x i32>
; CHECK:       middle.block:
; CHECK-NEXT:    [[TMP37:%.*]] = trunc <vscale x 8 x i32> [[TMP34]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP38:%.*]] = trunc <vscale x 8 x i32> [[TMP36]] to <vscale x 8 x i8>