Unverified Commit 2ec7bba7 authored by Florian Hahn's avatar Florian Hahn
Browse files

Recommit "[VPlan] Insert Trunc/Exts for reductions directly in VPlan."

This reverts commit e4ea0997.

The recommit fixes a reported crash by adding a missing check to make
sure the cast recipes are only introduced when vectorizing.

Test coverage added in 3cac608f.
Original commit message:
   Update the code to create Trunc/Ext recipes directly in
    adjustRecipesForReductions instead of fixing it up later in
    fixReductions.

    This explicitly models the required conversions and also makes sure they
    are generated at the right place (instead of after the exit condition),
    hence the changes in a few tests.
parent 75d6f508
Loading
Loading
Loading
Loading
+35 −32
Original line number Diff line number Diff line
@@ -3792,8 +3792,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
    State.setDebugLocFrom(I->getDebugLoc());

  VPValue *LoopExitInstDef = PhiR->getBackedgeValue();
  // This is the vector-clone of the value that leaves the loop.
  Type *VecTy = State.get(LoopExitInstDef, 0)->getType();

  // Before each round, move the insertion point right between
  // the PHIs and the values we are going to write.
@@ -3805,10 +3803,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
  State.setDebugLocFrom(LoopExitInst->getDebugLoc());

  Type *PhiTy = OrigPhi->getType();

  VPBasicBlock *LatchVPBB =
      PhiR->getParent()->getEnclosingLoopRegion()->getExitingBasicBlock();
  BasicBlock *VectorLoopLatch = State.CFG.VPBB2IRBB[LatchVPBB];
  // If tail is folded by masking, the vector value to leave the loop should be
  // a Select choosing between the vectorized LoopExitInst and vectorized Phi,
  // instead of the former. For an inloop reduction the reduction will already
@@ -3834,24 +3828,13 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
  // then extend the loop exit value to enable InstCombine to evaluate the
  // entire expression in the smaller type.
  if (VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
    assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
    Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
    Builder.SetInsertPoint(VectorLoopLatch->getTerminator());
    for (unsigned Part = 0; Part < UF; ++Part) {
      Value *Trunc = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
      Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy)
                                        : Builder.CreateZExt(Trunc, VecTy);
      for (User *U : llvm::make_early_inc_range(RdxParts[Part]->users()))
        if (U != Trunc) {
          U->replaceUsesOfWith(RdxParts[Part], Extnd);
          RdxParts[Part] = Extnd;
        }
    }
    Builder.SetInsertPoint(LoopMiddleBlock,
                           LoopMiddleBlock->getFirstInsertionPt());
    for (unsigned Part = 0; Part < UF; ++Part)
    Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
    for (unsigned Part = 0; Part < UF; ++Part) {
      RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
    }
  }

  // Reduce all of the unrolled parts into a single vector.
  Value *ReducedPartRdx = RdxParts[0];
@@ -9155,18 +9138,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
      PreviousLink = RedRecipe;
    }
  }

  // If tail is folded by masking, introduce selects between the phi
  // and the live-out instruction of each reduction, at the beginning of the
  // dedicated latch block.
  if (CM.foldTailByMasking()) {
    Builder.setInsertPoint(&*LatchVPBB->begin());
    for (VPRecipeBase &R :
         Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
    VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
    if (!PhiR || PhiR->isInLoop())
      continue;

    const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
    auto *Result = PhiR->getBackedgeValue()->getDefiningRecipe();
    // If tail is folded by masking, introduce selects between the phi
    // and the live-out instruction of each reduction, at the beginning of the
    // dedicated latch block.
    if (CM.foldTailByMasking()) {
      VPValue *Cond =
          RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), *Plan);
      VPValue *Red = PhiR->getBackedgeValue();
@@ -9174,16 +9158,35 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
             "reduction recipe must be defined before latch");
      FastMathFlags FMFs = RdxDesc.getFastMathFlags();
      Type *PhiTy = PhiR->getOperand(0)->getLiveInIRValue()->getType();
      auto *Select =
      Result =
          PhiTy->isFloatingPointTy()
              ? new VPInstruction(Instruction::Select, {Cond, Red, PhiR}, FMFs)
              : new VPInstruction(Instruction::Select, {Cond, Red, PhiR});
      Select->insertBefore(&*Builder.getInsertPoint());
      Result->insertBefore(&*Builder.getInsertPoint());
      if (PreferPredicatedReductionSelect ||
          TTI.preferPredicatedReductionSelect(
              PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy,
              TargetTransformInfo::ReductionFlags()))
        PhiR->setOperand(1, Select);
        PhiR->setOperand(1, Result->getVPSingleValue());
    }
    // If the vector reduction can be performed in a smaller type, we truncate
    // then extend the loop exit value to enable InstCombine to evaluate the
    // entire expression in the smaller type.
    Type *PhiTy = PhiR->getStartValue()->getLiveInIRValue()->getType();
    if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
      assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
      Type *RdxTy = RdxDesc.getRecurrenceType();
      auto *Trunc = new VPWidenCastRecipe(Instruction::Trunc,
                                          Result->getVPSingleValue(), RdxTy);
      auto *Extnd =
          RdxDesc.isSigned()
              ? new VPWidenCastRecipe(Instruction::SExt, Trunc, PhiTy)
              : new VPWidenCastRecipe(Instruction::ZExt, Trunc, PhiTy);

      Trunc->insertAfter(Result);
      Extnd->insertAfter(Trunc);
      Result->getVPSingleValue()->replaceAllUsesWith(Extnd);
      Trunc->setOperand(0, Result->getVPSingleValue());
    }
  }

+4 −4
Original line number Diff line number Diff line
@@ -207,10 +207,10 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP3]], align 2
; CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
; CHECK-NEXT:    [[TMP5:%.*]] = or <4 x i32> [[TMP1]], [[TMP4]]
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
; CHECK-NEXT:    [[TMP7:%.*]] = trunc <4 x i32> [[TMP5]] to <4 x i16>
; CHECK-NEXT:    [[TMP8]] = zext <4 x i16> [[TMP7]] to <4 x i32>
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
; CHECK-NEXT:    br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK:       middle.block:
; CHECK-NEXT:    [[TMP9:%.*]] = trunc <4 x i32> [[TMP8]] to <4 x i16>
@@ -234,10 +234,10 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <4 x i16>, ptr [[TMP16]], align 2
; CHECK-NEXT:    [[TMP17:%.*]] = zext <4 x i16> [[WIDE_LOAD4]] to <4 x i32>
; CHECK-NEXT:    [[TMP18:%.*]] = or <4 x i32> [[TMP14]], [[TMP17]]
; CHECK-NEXT:    [[INDEX_NEXT5]] = add nuw i32 [[INDEX2]], 4
; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i32 [[INDEX_NEXT5]], 256
; CHECK-NEXT:    [[TMP20:%.*]] = trunc <4 x i32> [[TMP18]] to <4 x i16>
; CHECK-NEXT:    [[TMP21]] = zext <4 x i16> [[TMP20]] to <4 x i32>
; CHECK-NEXT:    [[INDEX_NEXT5]] = add nuw i32 [[INDEX2]], 4
; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i32 [[INDEX_NEXT5]], 256
; CHECK-NEXT:    br i1 [[TMP19]], label [[VEC_EPILOG_MIDDLE_BLOCK:%.*]], label [[VEC_EPILOG_VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
; CHECK:       vec.epilog.middle.block:
; CHECK-NEXT:    [[TMP22:%.*]] = trunc <4 x i32> [[TMP21]] to <4 x i16>
+0 −2
Original line number Diff line number Diff line
@@ -267,5 +267,3 @@ exit:
  %count.0 = trunc i32 %add.lcssa to i16
  ret i16 %count.0
}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; DBG: {{.*}}
+4 −4
Original line number Diff line number Diff line
@@ -22,10 +22,10 @@ define i8 @PR34687(i1 %c, i32 %x, i32 %n) {
; CHECK-NEXT:    [[TMP0:%.*]] = select <4 x i1> [[BROADCAST_SPLAT]], <4 x i32> undef, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[VEC_PHI]], <i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT:    [[TMP2:%.*]] = add <4 x i32> [[TMP1]], [[BROADCAST_SPLAT2]]
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT:    [[TMP4:%.*]] = trunc <4 x i32> [[TMP2]] to <4 x i8>
; CHECK-NEXT:    [[TMP5]] = zext <4 x i8> [[TMP4]] to <4 x i32>
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT:    br i1 [[TMP3]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK:       middle.block:
; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i32> [[TMP5]] to <4 x i8>
@@ -99,10 +99,10 @@ define i32 @PR35734(i32 %x, i32 %y) {
; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ [[TMP2]], [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT:    [[TMP3:%.*]] = and <4 x i32> [[VEC_PHI]], <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT:    [[TMP4:%.*]] = add <4 x i32> [[TMP3]], <i32 -1, i32 -1, i32 -1, i32 -1>
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i32> [[TMP4]] to <4 x i1>
; CHECK-NEXT:    [[TMP7]] = sext <4 x i1> [[TMP6]] to <4 x i32>
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT:    br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK:       middle.block:
; CHECK-NEXT:    [[TMP8:%.*]] = trunc <4 x i32> [[TMP7]] to <4 x i1>
+4 −4
Original line number Diff line number Diff line
@@ -17,14 +17,14 @@ define i8 @reduction_add_trunc(ptr noalias nocapture %A) {
; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
; CHECK-NEXT:    [[TMP28:%.*]] = add <vscale x 8 x i32> [[TMP14]], [[TMP26]]
; CHECK-NEXT:    [[TMP29:%.*]] = add <vscale x 8 x i32> [[TMP15]], [[TMP27]]
; CHECK-NEXT:    [[TMP33:%.*]] = trunc <vscale x 8 x i32> [[TMP28]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP35:%.*]] = trunc <vscale x 8 x i32> [[TMP29]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP34]] = zext <vscale x 8 x i8> [[TMP33]] to <vscale x 8 x i32>
; CHECK-NEXT:    [[TMP36]] = zext <vscale x 8 x i8> [[TMP35]] to <vscale x 8 x i32>
; CHECK-NEXT:    [[TMP30:%.*]] = call i32 @llvm.vscale.i32()
; CHECK-NEXT:    [[TMP31:%.*]] = mul i32 [[TMP30]], 16
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP31]]
; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i32 [[INDEX_NEXT]], {{%.*}}
; CHECK-NEXT:    [[TMP33:%.*]] = trunc <vscale x 8 x i32> [[TMP28]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP34]] = zext <vscale x 8 x i8> [[TMP33]] to <vscale x 8 x i32>
; CHECK-NEXT:    [[TMP35:%.*]] = trunc <vscale x 8 x i32> [[TMP29]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP36]] = zext <vscale x 8 x i8> [[TMP35]] to <vscale x 8 x i32>
; CHECK:       middle.block:
; CHECK-NEXT:    [[TMP37:%.*]] = trunc <vscale x 8 x i32> [[TMP34]] to <vscale x 8 x i8>
; CHECK-NEXT:    [[TMP38:%.*]] = trunc <vscale x 8 x i32> [[TMP36]] to <vscale x 8 x i8>