Unverified Commit f4cc934d authored by Sander de Smalen's avatar Sander de Smalen Committed by GitHub
Browse files

[LV] NFCI: Create VPExpressions in transformToPartialReductions. (#182863)

With this change, all logic to generate partial reductions and
recognising them as VPExpressions is contained in
`transformToPartialReductions`, without the need for a second transform
pass.
The PR intends to be a non-functional change.
parent 9435160a
Loading
Loading
Loading
Loading
+68 −52
Original line number Diff line number Diff line
@@ -4458,11 +4458,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
  Type *RedTy = Ctx.Types.inferScalarType(Red);
  VPValue *VecOp = Red->getVecOp();

  // For partial reductions, the decision has already been made at the point of
  // transforming reductions -> partial reductions for a given plan, based on
  // the cost-model.
  if (Red->isPartialReduction())
    return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red);
  assert(!Red->isPartialReduction() &&
         "This path does not support partial reductions");

  // Clamp the range if using extended-reduction is profitable.
  auto IsExtendedRedValidAndClampRange =
@@ -4477,10 +4474,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
          InstructionCost RedCost = Red->computeCost(VF, Ctx);

          // TTI::getExtendedReductionCost for in-loop reductions
          // only supports integer types.
          if (RedTy->isFloatingPointTy())
            return false;
          assert(!RedTy->isFloatingPointTy() &&
                 "getExtendedReductionCost only supports integer types");
          ExtRedCost = Ctx.TTI.getExtendedReductionCost(
              Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
              Red->getFastMathFlags(), CostKind);
@@ -4491,8 +4486,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,

  VPValue *A;
  // Match reduce(ext)).
  if (match(VecOp, m_Isa<VPWidenCastRecipe>(m_CombineOr(
                       m_ZExtOrSExt(m_VPValue(A)), m_FPExt(m_VPValue(A))))) &&
  if (match(VecOp, m_Isa<VPWidenCastRecipe>(m_ZExtOrSExt(m_VPValue(A)))) &&
      IsExtendedRedValidAndClampRange(
          RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()),
          cast<VPWidenCastRecipe>(VecOp)->getOpcode(),
@@ -4519,6 +4513,8 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
      Opcode != Instruction::FAdd)
    return nullptr;

  assert(!Red->isPartialReduction() &&
         "This path does not support partial reductions");
  Type *RedTy = Ctx.Types.inferScalarType(Red);

  // Clamp the range if using multiply-accumulate-reduction is profitable.
@@ -4527,19 +4523,13 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
          VPWidenCastRecipe *OuterExt) -> bool {
    return LoopVectorizationPlanner::getDecisionAndClampRange(
        [&](ElementCount VF) {
          // For partial reductions, the decision has already been made at the
          // point of transforming reductions -> partial reductions for a given
          // plan, based on the cost-model.
          if (Red->isPartialReduction())
            return true;

          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
          Type *SrcTy =
              Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
          InstructionCost MulAccCost;

          // Only partial reductions support mixed or floating-point extends at
          // the moment.
          // getMulAccReductionCost for in-loop reductions does not support
          // mixed or floating-point extends.
          if (Ext0 && Ext1 &&
              (Ext0->getOpcode() != Ext1->getOpcode() ||
               Ext0->getOpcode() == Instruction::CastOps::FPExt))
@@ -4572,23 +4562,6 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
  VPValue *A, *B;
  VPValue *Tmp = nullptr;

  // Try to match reduce.fadd(fmul(fpext(...), fpext(...))).
  if (match(VecOp, m_FMul(m_FPExt(m_VPValue()), m_FPExt(m_VPValue())))) {
    assert(Opcode == Instruction::FAdd &&
           "MulAccumulateReduction from an FMul must accumulate into an FAdd "
           "instruction");
    auto *FMul = dyn_cast<VPWidenRecipe>(VecOp);
    if (!FMul)
      return nullptr;

    auto *RecipeA = dyn_cast<VPWidenCastRecipe>(FMul->getOperand(0));
    auto *RecipeB = dyn_cast<VPWidenCastRecipe>(FMul->getOperand(1));

    if (RecipeA && RecipeB &&
        IsMulAccValidAndClampRange(FMul, RecipeA, RecipeB, nullptr)) {
      return new VPExpressionRecipe(RecipeA, RecipeB, FMul, Red);
    }
  }
  if (RedTy->isFloatingPointTy())
    return nullptr;

@@ -4603,11 +4576,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
  // creates two uniform extends that can more easily be matched by the rest of
  // the bundling code. The ExtB reference, ValB and operand 1 of Mul are all
  // replaced with the new extend of the constant.
  auto ExtendAndReplaceConstantOp = [&Ctx, &Red](VPWidenCastRecipe *ExtA,
  auto ExtendAndReplaceConstantOp = [&Ctx](VPWidenCastRecipe *ExtA,
                                           VPWidenCastRecipe *&ExtB,
                                                 VPValue *&ValB,
                                                 VPWidenRecipe *Mul) {
    if (!ExtA || ExtB || !isa<VPIRValue>(ValB) || Red->isPartialReduction())
                                           VPValue *&ValB, VPWidenRecipe *Mul) {
    if (!ExtA || ExtB || !isa<VPIRValue>(ValB))
      return;
    Type *NarrowTy = Ctx.Types.inferScalarType(ExtA->getOperand(0));
    Instruction::CastOps ExtOpc = ExtA->getOpcode();
@@ -4657,8 +4629,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
    return nullptr;

  // Match reduce.add(ext(mul(A, B))).
  if (!Red->isPartialReduction() &&
      match(VecOp, m_ZExtOrSExt(m_Mul(m_VPValue(A), m_VPValue(B))))) {
  if (match(VecOp, m_ZExtOrSExt(m_Mul(m_VPValue(A), m_VPValue(B))))) {
    auto *Ext = cast<VPWidenCastRecipe>(VecOp);
    auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0));
    auto *Ext0 = dyn_cast<VPWidenCastRecipe>(A);
@@ -4704,6 +4675,11 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
static void tryToCreateAbstractReductionRecipe(VPReductionRecipe *Red,
                                               VPCostContext &Ctx,
                                               VFRange &Range) {
  // Creation of VPExpressions for partial reductions is entirely handled in
  // transformToPartialReduction.
  assert(!Red->isPartialReduction() &&
         "This path does not support partial reductions");

  VPExpressionRecipe *AbstractR = nullptr;
  auto IP = std::next(Red->getIterator());
  auto *VPBB = Red->getParent();
@@ -6136,6 +6112,40 @@ optimizeExtendsForPartialReduction(VPSingleDefRecipe *Op,
  return Op;
}

static VPExpressionRecipe *
createPartialReductionExpression(VPReductionRecipe *Red) {
  VPValue *VecOp = Red->getVecOp();

  // reduce.[f]add(ext(op))
  //  -> VPExpressionRecipe(op, red)
  if (match(VecOp, m_WidenAnyExtend(m_VPValue())))
    return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red);

  // reduce.[f]add([f]mul(ext(a), ext(b)))
  //  -> VPExpressionRecipe(a, b, mul, red)
  if (match(VecOp, m_FMul(m_FPExt(m_VPValue()), m_FPExt(m_VPValue()))) ||
      match(VecOp,
            m_Mul(m_ZExtOrSExt(m_VPValue()), m_ZExtOrSExt(m_VPValue())))) {
    auto *Mul = cast<VPWidenRecipe>(VecOp);
    auto *ExtA = cast<VPWidenCastRecipe>(Mul->getOperand(0));
    auto *ExtB = cast<VPWidenCastRecipe>(Mul->getOperand(1));
    return new VPExpressionRecipe(ExtA, ExtB, Mul, Red);
  }

  // reduce.add(neg(mul(ext(a), ext(b))))
  //  -> VPExpressionRecipe(a, b, mul, sub, red)
  if (match(VecOp, m_Sub(m_ZeroInt(), m_Mul(m_ZExtOrSExt(m_VPValue()),
                                            m_ZExtOrSExt(m_VPValue()))))) {
    auto *Sub = cast<VPWidenRecipe>(VecOp);
    auto *Mul = cast<VPWidenRecipe>(Sub->getOperand(1));
    auto *ExtA = cast<VPWidenCastRecipe>(Mul->getOperand(0));
    auto *ExtB = cast<VPWidenCastRecipe>(Mul->getOperand(1));
    return new VPExpressionRecipe(ExtA, ExtB, Mul, Sub, Red);
  }

  llvm_unreachable("Unsupported expression");
}

// Helper to transform a partial reduction chain into a partial reduction
// recipe. Assumes profitability has been checked.
static void transformToPartialReduction(const VPPartialReductionChain &Chain,
@@ -6203,6 +6213,11 @@ static void transformToPartialReduction(const VPPartialReductionChain &Chain,
    ExitValue->replaceAllUsesWith(PartialRed);
  WidenRecipe->replaceAllUsesWith(PartialRed);

  // For cost-model purposes, fold this into a VPExpression.
  VPExpressionRecipe *E = createPartialReductionExpression(PartialRed);
  E->insertBefore(WidenRecipe);
  PartialRed->replaceAllUsesWith(E);

  // We only need to update the PHI node once, which is when we find the
  // last reduction in the chain.
  if (!IsLastInChain)
@@ -6275,10 +6290,10 @@ static ExtendKind getPartialReductionExtendKind(VPWidenCastRecipe *Cast) {
///
/// Possible forms matched by this function:
///  - UpdateR(PrevValue, ext(...))
///  - UpdateR(PrevValue, BinOp(ext(...), ext(...)))
///  - UpdateR(PrevValue, BinOp(ext(...), Constant))
///  - UpdateR(PrevValue, neg(BinOp(ext(...), ext(...))))
///  - UpdateR(PrevValue, neg(BinOp(ext(...), Constant)))
///  - UpdateR(PrevValue, mul(ext(...), ext(...)))
///  - UpdateR(PrevValue, mul(ext(...), Constant))
///  - UpdateR(PrevValue, neg(mul(ext(...), ext(...))))
///  - UpdateR(PrevValue, neg(mul(ext(...), Constant)))
///  - UpdateR(PrevValue, ext(mul(ext(...), ext(...))))
///  - UpdateR(PrevValue, ext(mul(ext(...), Constant)))
///  - UpdateR(PrevValue, abs(sub(ext(...), ext(...)))
@@ -6345,15 +6360,16 @@ matchExtendedReductionOperand(VPWidenRecipe *UpdateR, VPValue *Op,
  if (!Op->hasOneUse())
    return std::nullopt;

  VPWidenRecipe *BinOp = dyn_cast<VPWidenRecipe>(Op);
  if (!BinOp || !Instruction::isBinaryOp(BinOp->getOpcode()))
  VPWidenRecipe *MulOp = dyn_cast<VPWidenRecipe>(Op);
  if (!MulOp ||
      !is_contained({Instruction::Mul, Instruction::FMul}, MulOp->getOpcode()))
    return std::nullopt;

  // The rest of the matching assumes `Op` is a (possibly extended/negated)
  // binary operation.

  VPValue *LHS = BinOp->getOperand(0);
  VPValue *RHS = BinOp->getOperand(1);
  VPValue *LHS = MulOp->getOperand(0);
  VPValue *RHS = MulOp->getOperand(1);

  // The LHS of the operation must always be an extend.
  if (!match(LHS, m_WidenAnyExtend(m_VPValue())))
@@ -6386,7 +6402,7 @@ matchExtendedReductionOperand(VPWidenRecipe *UpdateR, VPValue *Op,
  }

  return ExtendedReductionOperand{
      BinOp, {LHSInputType, LHSExtendKind}, {RHSInputType, RHSExtendKind}};
      MulOp, {LHSInputType, LHSExtendKind}, {RHSInputType, RHSExtendKind}};
}

/// Examines each operation in the reduction chain corresponding to \p RedPhiR,