Commit b567ff2f authored by Sjoerd Meijer's avatar Sjoerd Meijer
Browse files

[ARM][MVE] Tail-predication: support constant trip count

We had support for runtime trip count values, but not constants, and this adds
support for that.

And added a minor optimisation while I was at it: don't invoke Cleanup when
there's nothing to clean up.

Differential Revision: https://reviews.llvm.org/D73198
parent 6c2df5d1
Loading
Loading
Loading
Loading
+171 −92
Original line number Diff line number Diff line
@@ -55,6 +55,27 @@ DisableTailPredication("disable-mve-tail-predication", cl::Hidden,
                       cl::desc("Disable MVE Tail Predication"));
namespace {

// Bookkeeping for pattern matching the loop trip count and the number of
// elements processed by the loop.
struct TripCountPattern {
  // The predicate used by the masked loads/stores, i.e. an icmp instruction
  // which calculates the active/inactive lanes.
  Instruction *Predicate = nullptr;

  // The trip count of the vector loop as passed to TryConvert(). It is either
  // a ConstantInt (handled by ComputeConstElements) or a runtime value that
  // is analysed through SCEV (handled by ComputeRuntimeElements).
  Value *TripCount = nullptr;

  // The number of elements processed by the vector loop. This is left null by
  // the constructor and filled in once one of the Compute*Elements functions
  // has recognised the pattern.
  Value *NumElements = nullptr;

  // The vector type the predicate applies to; its element count is used as
  // the vectorisation factor (VF).
  VectorType *VecTy = nullptr;

  // Second operand of the predicate icmp in the runtime trip count case: the
  // instruction that is expected to broadcast the element count limit. Only
  // set by ComputeRuntimeElements.
  Instruction *Shuffle = nullptr;

  // First operand of the predicate icmp: the instruction that isTailPredicate
  // verifies to be the vector induction variable.
  Instruction *Induction = nullptr;

  TripCountPattern(Instruction *P, Value *TC, VectorType *VT)
      : Predicate(P), TripCount(TC), VecTy(VT) {}
};

class MVETailPredication : public LoopPass {
  SmallVector<IntrinsicInst*, 4> MaskedInsts;
  Loop *L = nullptr;
@@ -85,7 +106,6 @@ public:
  bool runOnLoop(Loop *L, LPPassManager&) override;

private:

  /// Perform the relevant checks on the loop and convert if possible.
  bool TryConvert(Value *TripCount);

@@ -94,18 +114,16 @@ private:
  bool IsPredicatedVectorLoop();

  /// Compute a value for the total number of elements that the predicated
  /// loop will process.
  Value *ComputeElements(Value *TripCount, VectorType *VecTy);
  /// loop will process if it is a runtime value.
  bool ComputeRuntimeElements(TripCountPattern &TCP);

  /// Is the icmp that generates an i1 vector, based upon a loop counter
  /// and a limit that is defined outside the loop.
  bool isTailPredicate(Instruction *Predicate, Value *NumElements);
  bool isTailPredicate(TripCountPattern &TCP);

  /// Insert the intrinsic to represent the effect of tail predication.
  void InsertVCTPIntrinsic(Instruction *Predicate,
                           DenseMap<Instruction*, Instruction*> &NewPredicates,
                           VectorType *VecTy,
                           Value *NumElements);
  void InsertVCTPIntrinsic(TripCountPattern &TCP,
                           DenseMap<Instruction *, Instruction *> &NewPredicates);

  /// Rematerialize the iteration count in exit blocks, which enables
  /// ARMLowOverheadLoops to better optimise away loop update statements inside
@@ -213,6 +231,7 @@ bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
  if (!Decrement)
    return false;

  ClonedVCTPInExitBlock = false;
  LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
             << *Decrement << "\n");

@@ -225,17 +244,17 @@ bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
  return false;
}

bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) {
  // Look for the following:
// Pattern match predicates/masks and determine if they use the loop induction
// variable to control the number of elements processed by the loop. If so,
// the loop is a candidate for tail-predication.
bool MVETailPredication::isTailPredicate(TripCountPattern &TCP) {
  using namespace PatternMatch;

  // %trip.count.minus.1 = add i32 %N, -1
  // %broadcast.splatinsert10 = insertelement <4 x i32> undef,
  //                                          i32 %trip.count.minus.1, i32 0
  // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
  //                                    <4 x i32> undef,
  //                                    <4 x i32> zeroinitializer
  // ...
  // ...
  // Pattern match the loop body and find the add with takes the index iv
  // and adds a constant vector to it:
  //
  // vector.body:
  // ..
  // %index = phi i32
  // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0
  // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert,
@@ -244,48 +263,10 @@ bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) {
  // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3>
  // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11

  // And return whether V == %pred.

  using namespace PatternMatch;

  CmpInst::Predicate Pred;
  Instruction *Shuffle = nullptr;
  Instruction *Induction = nullptr;

  // The vector icmp
  if (!match(I, m_ICmp(Pred, m_Instruction(Induction),
                       m_Instruction(Shuffle))) ||
      Pred != ICmpInst::ICMP_ULE)
    return false;

  // First find the stuff outside the loop which is setting up the limit
  // vector....
  // The invariant shuffle that broadcast the limit into a vector.
  Instruction *Insert = nullptr;
  if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
                                      m_Zero())))
    return false;

  // Insert the limit into a vector.
  Instruction *BECount = nullptr;
  if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount),
                                     m_Zero())))
    return false;

  // The limit calculation, backedge count.
  Value *TripCount = nullptr;
  if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes())))
    return false;

  if (TripCount != NumElements || !L->isLoopInvariant(BECount))
    return false;

  // Now back to searching inside the loop body...
  // Find the add with takes the index iv and adds a constant vector to it.
  Instruction *BroadcastSplat = nullptr;
  Constant *Const = nullptr;
  if (!match(Induction, m_Add(m_Instruction(BroadcastSplat),
                              m_Constant(Const))))
  if (!match(TCP.Induction,
             m_Add(m_Instruction(BroadcastSplat), m_Constant(Const))))
    return false;

  // Check that we're adding <0, 1, 2, 3...
@@ -297,9 +278,10 @@ bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) {
  } else
    return false;

  Instruction *Insert = nullptr;
  // The shuffle which broadcasts the index iv into a vector.
  if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
                                             m_Zero())))
  if (!match(BroadcastSplat,
             m_ShuffleVector(m_Instruction(Insert), m_Undef(), m_Zero())))
    return false;

  // The insert element which initialises a vector with the index iv.
@@ -361,16 +343,107 @@ bool MVETailPredication::IsPredicatedVectorLoop() {
  return !MaskedInsts.empty();
}

Value* MVETailPredication::ComputeElements(Value *TripCount,
                                           VectorType *VecTy) {
  const SCEV *TripCountSE = SE->getSCEV(TripCount);
  ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()),
                                     VecTy->getNumElements());
// Pattern match the predicate, which is an icmp with a constant vector of this
// form:
//
//   icmp ult <4 x i32> %induction, <i32 32002, i32 32002, i32 32002, i32 32002>
//
// and record the constant, i.e. 32002 in this example, in TCP.NumElements.
// This is assumed to be the scalar loop iteration count: the number of
// elements processed by the vector loop. Further checks are performed in
// function isTailPredicate(), to verify 'induction' behaves as an induction
// variable.
//
// Returns true if the pattern was recognised. TCP.Induction is bound to the
// first icmp operand as a side effect of the match; TCP.NumElements is only
// written on success.
//
static bool ComputeConstElements(TripCountPattern &TCP) {
  using namespace PatternMatch;

  // Runtime trip counts are the business of ComputeRuntimeElements.
  if (!isa<ConstantInt>(TCP.TripCount))
    return false;

  CmpInst::Predicate CC;
  if (!match(TCP.Predicate, m_ICmp(CC, m_Instruction(TCP.Induction),
                                   m_AnyIntegralConstant())) ||
      CC != ICmpInst::ICMP_ULT)
    return false;

  LLVM_DEBUG(dbgs() << "ARM TP: icmp with constants: " << *TCP.Predicate
                    << "\n");

  // The constant operand must have exactly one element per vector lane.
  const int64_t VF = TCP.VecTy->getNumElements();
  auto *CDS = dyn_cast<ConstantDataSequential>(TCP.Predicate->getOperand(1));
  if (!CDS || CDS->getNumElements() != VF)
    return false;

  // All lanes have to be compared against the same limit.
  Constant *Splat = CDS->getSplatValue();
  if (!Splat)
    return false;

  // The operand matched m_AnyIntegralConstant, so its elements are integer
  // constants and an unchecked cast<> is the right tool here.
  assert(cast<ConstantInt>(Splat)->getSExtValue() % VF != 0 &&
         "tail-predication: trip count should not be a multiple of the VF");

  TCP.NumElements = Splat;
  LLVM_DEBUG(dbgs() << "ARM TP: Found const elem count: " << *TCP.NumElements
                    << "\n");
  return true;
}

// Recognise the code that sets up the element count limit ahead of the loop:
//
// %trip.count.minus.1 = add i32 %N, -1
// %broadcast.splatinsert10 = insertelement <4 x i32> undef,
//                                          i32 %trip.count.minus.1, i32 0
// %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
//                                    <4 x i32> undef,
//                                    <4 x i32> zeroinitializer
// ..
// vector.body:
// ..
//
// Returns true when Shuffle is such a broadcast, the value being decremented
// (%N above) is NumElements, and the decrement is invariant in loop L.
//
static bool MatchElemCountLoopSetup(Loop *L, Instruction *Shuffle,
                                    Value *NumElements) {
  using namespace PatternMatch;

  Instruction *SplatInsert = nullptr;   // insertelement feeding the shuffle
  Instruction *BackedgeCount = nullptr; // the "limit - 1" computation
  Value *Limit = nullptr;               // the value that gets decremented

  // Walk the chain from the broadcast back to the limit; the && chain stops
  // at the first link that does not have the expected shape.
  const bool IsLimitBroadcast =
      match(Shuffle, m_ShuffleVector(m_Instruction(SplatInsert), m_Undef(),
                                     m_Zero())) &&
      match(SplatInsert, m_InsertElement(m_Undef(),
                                         m_Instruction(BackedgeCount),
                                         m_Zero())) &&
      match(BackedgeCount, m_Add(m_Value(Limit), m_AllOnes()));
  if (!IsLimitBroadcast)
    return false;

  // The limit has to be the element count that was computed for this loop,
  // and the backedge count must not change inside the loop.
  return Limit == NumElements && L->isLoopInvariant(BackedgeCount);
}

bool MVETailPredication::ComputeRuntimeElements(TripCountPattern &TCP) {
  using namespace PatternMatch;
  const SCEV *TripCountSE = SE->getSCEV(TCP.TripCount);
  ConstantInt *VF = ConstantInt::get(
      cast<IntegerType>(TCP.TripCount->getType()), TCP.VecTy->getNumElements());

  if (VF->equalsInt(1))
    return nullptr;
    return false;

  CmpInst::Predicate Pred;
  if (!match(TCP.Predicate, m_ICmp(Pred, m_Instruction(TCP.Induction),
                                   m_Instruction(TCP.Shuffle))) ||
      Pred != ICmpInst::ICMP_ULE)
    return false;

  // TODO: Support constant trip counts.
  LLVM_DEBUG(dbgs() << "Computing number of elements for vector trip count: ";
             TCP.TripCount->dump());

  // Otherwise, continue and try to pattern match the vector iteration
  // count expression
  auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr * {
    if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
      if (Const->getAPInt() != -VF->getValue())
@@ -426,15 +499,20 @@ Value* MVETailPredication::ComputeElements(Value *TripCount,
              Elems = Res;

  if (!Elems)
    return nullptr;
    return false;

  Instruction *InsertPt = L->getLoopPreheader()->getTerminator();
  if (!isSafeToExpandAt(Elems, InsertPt, *SE))
    return nullptr;
    return false;

  auto DL = L->getHeader()->getModule()->getDataLayout();
  SCEVExpander Expander(*SE, DL, "elements");
  return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt);
  TCP.NumElements = Expander.expandCodeFor(Elems, Elems->getType(), InsertPt);

  if (!MatchElemCountLoopSetup(L, TCP.Shuffle, TCP.NumElements))
    return false;

  return true;
}

// Look through the exit block to see whether there's a duplicate predicate
@@ -499,24 +577,23 @@ static bool Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates,
  return ClonedVCTPInExitBlock;
}

void MVETailPredication::InsertVCTPIntrinsic(Instruction *Predicate,
    DenseMap<Instruction*, Instruction*> &NewPredicates,
    VectorType *VecTy, Value *NumElements) {
void MVETailPredication::InsertVCTPIntrinsic(TripCountPattern &TCP,
    DenseMap<Instruction*, Instruction*> &NewPredicates) {
  IRBuilder<> Builder(L->getHeader()->getFirstNonPHI());
  Module *M = L->getHeader()->getModule();
  Type *Ty = IntegerType::get(M->getContext(), 32);

  // Insert a phi to count the number of elements processed by the loop.
  PHINode *Processed = Builder.CreatePHI(Ty, 2);
  Processed->addIncoming(NumElements, L->getLoopPreheader());
  Processed->addIncoming(TCP.NumElements, L->getLoopPreheader());

  // Insert the intrinsic to represent the effect of tail predication.
  Builder.SetInsertPoint(cast<Instruction>(Predicate));
  Builder.SetInsertPoint(cast<Instruction>(TCP.Predicate));
  ConstantInt *Factor =
    ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements());
    ConstantInt::get(cast<IntegerType>(Ty), TCP.VecTy->getNumElements());

  Intrinsic::ID VCTPID;
  switch (VecTy->getNumElements()) {
  switch (TCP.VecTy->getNumElements()) {
  default:
    llvm_unreachable("unexpected number of lanes");
  case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
@@ -531,8 +608,8 @@ void MVETailPredication::InsertVCTPIntrinsic(Instruction *Predicate,
  }
  Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
  Value *TailPredicate = Builder.CreateCall(VCTP, Processed);
  Predicate->replaceAllUsesWith(TailPredicate);
  NewPredicates[Predicate] = cast<Instruction>(TailPredicate);
  TCP.Predicate->replaceAllUsesWith(TailPredicate);
  NewPredicates[TCP.Predicate] = cast<Instruction>(TailPredicate);

  // Add the incoming value to the new phi.
  // TODO: This add likely already exists in the loop.
@@ -545,7 +622,7 @@ void MVETailPredication::InsertVCTPIntrinsic(Instruction *Predicate,

bool MVETailPredication::TryConvert(Value *TripCount) {
  if (!IsPredicatedVectorLoop()) {
    LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop");
    LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop.\n");
    return false;
  }

@@ -563,22 +640,24 @@ bool MVETailPredication::TryConvert(Value *TripCount) {
    if (!Predicate || Predicates.count(Predicate))
      continue;

    VectorType *VecTy = getVectorType(I);
    Value *NumElements = ComputeElements(TripCount, VecTy);
    if (!NumElements)
    TripCountPattern TCP(Predicate, TripCount, getVectorType(I));

    if (!(ComputeConstElements(TCP) || ComputeRuntimeElements(TCP)))
      continue;

    if (!isTailPredicate(Predicate, NumElements)) {
    if (!isTailPredicate(TCP)) {
      LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate << "\n");
      continue;
    }

    LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate << "\n");
    Predicates.insert(Predicate);

    InsertVCTPIntrinsic(Predicate, NewPredicates, VecTy, NumElements);
    InsertVCTPIntrinsic(TCP, NewPredicates);
  }

  if (!NewPredicates.size())
    return false;

  // Now clean up.
  ClonedVCTPInExitBlock = Cleanup(NewPredicates, Predicates, L);
  return true;
+329 −0

File added.

Preview size limit exceeded, changes collapsed.