Unverified Commit 3e58dd19 authored by Florian Hahn's avatar Florian Hahn
Browse files

[LV] Move reduction PHI node fixup to VPlan::execute (NFC).

All information to fix-up the reduction phi nodes in the vectorized loop
is available in VPlan now. This patch moves the code to do so, to make
this clearer. Fixing up the loop exit value still relies on other
information and remains outside of VPlan for now.

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D100113
parent 835cbfa8
Loading
Loading
Loading
Loading
+1 −18
Original line number Diff line number Diff line
@@ -594,8 +594,7 @@ protected:
  /// update their users.
  void fixFirstOrderRecurrence(VPWidenPHIRecipe *PhiR, VPTransformState &State);

  /// Fix a reduction cross-iteration phi. This is the second phase of
  /// vectorizing this phi node.
  /// Create code for the loop exit value of the reduction.
  void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State);

  /// Clear NSW/NUW flags from reduction instructions if necessary.
@@ -4303,22 +4302,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
  // Wrap flags are in general invalid after vectorization, clear them.
  clearReductionWrapFlags(RdxDesc, State);

  // Fix the vector-loop phi.

  // Reductions do not have to start at zero. They can start with
  // any loop invariant values.
  BasicBlock *VectorLoopLatch = LI->getLoopFor(LoopVectorBody)->getLoopLatch();

  unsigned LastPartForNewPhi = PhiR->isOrdered() ? 1 : UF;
  for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) {
    Value *VecRdxPhi = State.get(PhiR->getVPSingleValue(), Part);
    Value *Val = State.get(PhiR->getBackedgeValue(), Part);
    if (PhiR->isOrdered())
      Val = State.get(PhiR->getBackedgeValue(), UF - 1);

    cast<PHINode>(VecRdxPhi)->addIncoming(Val, VectorLoopLatch);
  }

  // Before each round, move the insertion point right between
  // the PHIs and the values we are going to write.
  // This allows us to write both PHINodes and the extractelement
+20 −8
Original line number Diff line number Diff line
@@ -815,16 +815,25 @@ void VPlan::execute(VPTransformState *State) {
  for (VPBlockBase *Block : depth_first(Entry))
    Block->execute(State);

  // Fix the latch value of the first-order recurrences in the vector loop. Only
  // a single part is generated, regardless of the UF.
  // Fix the latch value of reduction and first-order recurrences phis in the
  // vector loop.
  VPBasicBlock *Header = Entry->getEntryBasicBlock();
  for (VPRecipeBase &R : Header->phis()) {
    if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) {
      auto *VecPhi = cast<PHINode>(State->get(FOR, 0));

      VPValue *PreviousDef = FOR->getBackedgeValue();
      Value *Incoming = State->get(PreviousDef, State->UF - 1);
      VecPhi->addIncoming(Incoming, VectorLatchBB);
    auto *PhiR = dyn_cast<VPWidenPHIRecipe>(&R);
    if (!PhiR || !(isa<VPFirstOrderRecurrencePHIRecipe>(&R) ||
                   isa<VPReductionPHIRecipe>(&R)))
      continue;
    // For first-order recurrences and in-order reduction phis, only a single
    // part is generated, which provides the last part from the previous
    // iteration. Otherwise all UF parts are generated.
    bool SinglePartNeeded = isa<VPFirstOrderRecurrencePHIRecipe>(&R) ||
                            cast<VPReductionPHIRecipe>(&R)->isOrdered();
    unsigned LastPartForNewPhi = SinglePartNeeded ? 1 : State->UF;
    for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) {
      Value *VecPhi = State->get(PhiR, Part);
      Value *Val = State->get(PhiR->getBackedgeValue(),
                              SinglePartNeeded ? State->UF - 1 : Part);
      cast<PHINode>(VecPhi)->addIncoming(Val, VectorLatchBB);
    }
  }

@@ -1319,6 +1328,9 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
        PHINode::Create(VecTy, 2, "vec.phi", &*HeaderBB->getFirstInsertionPt());
    State.set(this, EntryPart, Part);
  }

  // Reductions do not have to start at zero. They can start with
  // any loop invariant values.
  VPValue *StartVPV = getStartValue();
  Value *StartV = StartVPV->getLiveInIRValue();