Commit 4e3c0055 authored by Simon Pilgrim
Browse files

[TTI] getScalarizationOverhead - use explicit VectorType operand

getScalarizationOverhead is only ever called with vectors (and we already had a load of cast<VectorType> calls immediately inside the functions).

Followup to D78357

Reviewed By: @samparker

Differential Revision: https://reviews.llvm.org/D79341
parent e78ef938
......@@ -620,7 +620,7 @@ public:
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the demanded result elements need to be inserted and/or
/// extracted from vectors.
unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract) const;
/// Estimate the overhead of scalarizing an instructions unique
......@@ -1261,7 +1261,8 @@ public:
virtual bool shouldBuildLookupTables() = 0;
virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
virtual bool useColdCCForColdCall(Function &F) = 0;
virtual unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
virtual unsigned getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract) = 0;
virtual unsigned
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
......@@ -1609,7 +1610,7 @@ public:
return Impl.useColdCCForColdCall(F);
}
unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
// Concrete Model override: forwards straight to the wrapped TTI
// implementation (Impl) without adding any logic of its own.
unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract) override {
return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
}
......
......@@ -240,7 +240,7 @@ public:
bool useColdCCForColdCall(Function &F) { return false; }
unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
// Default base-class implementation: models scalarization as free (cost 0).
// Targets that want a real estimate override this.
unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract) {
return 0;
}
......
......@@ -552,32 +552,30 @@ public:
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the demanded result elements need to be inserted and/or
/// extracted from vectors.
unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract) {
auto *VTy = cast<VectorType>(Ty);
assert(DemandedElts.getBitWidth() == VTy->getNumElements() &&
assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
"Vector size mismatch");
unsigned Cost = 0;
for (int i = 0, e = VTy->getNumElements(); i < e; ++i) {
for (int i = 0, e = Ty->getNumElements(); i < e; ++i) {
if (!DemandedElts[i])
continue;
if (Insert)
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::InsertElement, VTy, i);
Instruction::InsertElement, Ty, i);
if (Extract)
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::ExtractElement, VTy, i);
Instruction::ExtractElement, Ty, i);
}
return Cost;
}
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) {
auto *VTy = cast<VectorType>(Ty);
APInt DemandedElts = APInt::getAllOnesValue(VTy->getNumElements());
// Convenience overload: demands every element of Ty (all-ones mask) and
// defers to the CRTP-derived class's DemandedElts variant.
unsigned getScalarizationOverhead(VectorType *Ty, bool Insert, bool Extract) {
APInt DemandedElts = APInt::getAllOnesValue(Ty->getNumElements());
return static_cast<T *>(this)->getScalarizationOverhead(Ty, DemandedElts,
Insert, Extract);
}
......@@ -591,11 +589,10 @@ public:
SmallPtrSet<const Value*, 4> UniqueOperands;
for (const Value *A : Args) {
if (!isa<Constant>(A) && UniqueOperands.insert(A).second) {
Type *VecTy = nullptr;
if (A->getType()->isVectorTy()) {
VecTy = A->getType();
auto *VecTy = dyn_cast<VectorType>(A->getType());
if (VecTy) {
// If A is a vector operand, VF should be 1 or correspond to A.
assert((VF == 1 || VF == cast<VectorType>(VecTy)->getNumElements()) &&
assert((VF == 1 || VF == VecTy->getNumElements()) &&
"Vector argument does not match VF");
}
else
......@@ -608,17 +605,16 @@ public:
return Cost;
}
unsigned getScalarizationOverhead(Type *VecTy, ArrayRef<const Value *> Args) {
unsigned getScalarizationOverhead(VectorType *Ty, ArrayRef<const Value *> Args) {
unsigned Cost = 0;
auto *VecVTy = cast<VectorType>(VecTy);
Cost += getScalarizationOverhead(VecVTy, true, false);
Cost += getScalarizationOverhead(Ty, true, false);
if (!Args.empty())
Cost += getOperandsScalarizationOverhead(Args, VecVTy->getNumElements());
Cost += getOperandsScalarizationOverhead(Args, Ty->getNumElements());
else
// When no information on arguments is provided, we add the cost
// associated with one argument as a heuristic.
Cost += getScalarizationOverhead(VecVTy, false, true);
Cost += getScalarizationOverhead(Ty, false, true);
return Cost;
}
......@@ -742,13 +738,16 @@ public:
break;
}
auto *SrcVTy = dyn_cast<VectorType>(Src);
auto *DstVTy = dyn_cast<VectorType>(Dst);
// If the cast is marked as legal (or promote) then assume low cost.
if (SrcLT.first == DstLT.first &&
TLI->isOperationLegalOrPromote(ISD, DstLT.second))
return SrcLT.first;
// Handle scalar conversions.
if (!Src->isVectorTy() && !Dst->isVectorTy()) {
if (!SrcVTy && !DstVTy) {
// Scalar bitcasts are usually free.
if (Opcode == Instruction::BitCast)
return 0;
......@@ -763,9 +762,7 @@ public:
}
// Check vector-to-vector casts.
if (Dst->isVectorTy() && Src->isVectorTy()) {
auto *SrcVTy = cast<VectorType>(Src);
auto *DstVTy = cast<VectorType>(Dst);
if (DstVTy && SrcVTy) {
// If the cast is between same-sized registers, then the check is simple.
if (SrcLT.first == DstLT.first &&
SrcLT.second.getSizeInBits() == DstLT.second.getSizeInBits()) {
......@@ -819,19 +816,18 @@ public:
// Return the cost of multiple scalar invocation plus the cost of
// inserting and extracting the values.
return getScalarizationOverhead(Dst, true, true) + Num * Cost;
return getScalarizationOverhead(DstVTy, true, true) + Num * Cost;
}
// We already handled vector-to-vector and scalar-to-scalar conversions.
// This
// is where we handle bitcast between vectors and scalars. We need to assume
// that the conversion is scalarized in one way or another.
if (Opcode == Instruction::BitCast)
if (Opcode == Instruction::BitCast) {
// Illegal bitcasts are done by storing and loading from a stack slot.
return (Src->isVectorTy() ? getScalarizationOverhead(Src, false, true)
: 0) +
(Dst->isVectorTy() ? getScalarizationOverhead(Dst, true, false)
: 0);
return (SrcVTy ? getScalarizationOverhead(SrcVTy, false, true) : 0) +
(DstVTy ? getScalarizationOverhead(DstVTy, true, false) : 0);
}
llvm_unreachable("Unhandled cast");
}
......@@ -923,7 +919,8 @@ public:
if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) {
// This is a vector load/store for some illegal type that is scalarized.
// We must account for the cost of building or decomposing the vector.
Cost += getScalarizationOverhead(Src, Opcode != Instruction::Store,
Cost += getScalarizationOverhead(cast<VectorType>(Src),
Opcode != Instruction::Store,
Opcode == Instruction::Store);
}
}
......@@ -1118,7 +1115,8 @@ public:
if (RetVF > 1 || VF > 1) {
ScalarizationCost = 0;
if (!RetTy->isVoidTy())
ScalarizationCost += getScalarizationOverhead(RetTy, true, false);
ScalarizationCost +=
getScalarizationOverhead(cast<VectorType>(RetTy), true, false);
ScalarizationCost += getOperandsScalarizationOverhead(Args, VF);
}
......@@ -1224,21 +1222,19 @@ public:
unsigned ScalarizationCost = ScalarizationCostPassed;
unsigned ScalarCalls = 1;
Type *ScalarRetTy = RetTy;
if (RetTy->isVectorTy()) {
if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
ScalarizationCost = getScalarizationOverhead(RetTy, true, false);
ScalarCalls =
std::max(ScalarCalls, cast<VectorType>(RetTy)->getNumElements());
ScalarizationCost = getScalarizationOverhead(RetVTy, true, false);
ScalarCalls = std::max(ScalarCalls, RetVTy->getNumElements());
ScalarRetTy = RetTy->getScalarType();
}
SmallVector<Type *, 4> ScalarTys;
for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
Type *Ty = Tys[i];
if (Ty->isVectorTy()) {
if (auto *VTy = dyn_cast<VectorType>(Ty)) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
ScalarizationCost += getScalarizationOverhead(Ty, false, true);
ScalarCalls =
std::max(ScalarCalls, cast<VectorType>(Ty)->getNumElements());
ScalarizationCost += getScalarizationOverhead(VTy, false, true);
ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
Ty = Ty->getScalarType();
}
ScalarTys.push_back(Ty);
......@@ -1588,12 +1584,12 @@ public:
// Else, assume that we need to scalarize this intrinsic. For math builtins
// this will emit a costly libcall, adding call overhead and spills. Make it
// very expensive.
if (RetTy->isVectorTy()) {
if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
unsigned ScalarizationCost =
((ScalarizationCostPassed != std::numeric_limits<unsigned>::max())
? ScalarizationCostPassed
: getScalarizationOverhead(RetTy, true, false));
unsigned ScalarCalls = cast<VectorType>(RetTy)->getNumElements();
: getScalarizationOverhead(RetVTy, true, false));
unsigned ScalarCalls = RetVTy->getNumElements();
SmallVector<Type *, 4> ScalarTys;
for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
Type *Ty = Tys[i];
......@@ -1604,14 +1600,12 @@ public:
unsigned ScalarCost = ConcreteTTI->getIntrinsicInstrCost(
IID, RetTy->getScalarType(), ScalarTys, FMF, CostKind);
for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
if (Tys[i]->isVectorTy()) {
if (auto *VTy = dyn_cast<VectorType>(Tys[i])) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
ScalarizationCost += getScalarizationOverhead(Tys[i], false, true);
ScalarCalls =
std::max(ScalarCalls, cast<VectorType>(Tys[i])->getNumElements());
ScalarizationCost += getScalarizationOverhead(VTy, false, true);
ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
}
}
return ScalarCalls * ScalarCost + ScalarizationCost;
}
......
......@@ -370,8 +370,10 @@ bool TargetTransformInfo::useColdCCForColdCall(Function &F) const {
return TTIImpl->useColdCCForColdCall(F);
}
unsigned TargetTransformInfo::getScalarizationOverhead(
Type *Ty, const APInt &DemandedElts, bool Insert, bool Extract) const {
// Public TTI entry point: thin pImpl forwarder to the target's
// implementation of the scalarization-cost estimate.
unsigned
TargetTransformInfo::getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract) const {
return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
}
......
......@@ -807,7 +807,7 @@ int ARMTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
CostKind);
// Return the cost of multiple scalar invocation plus the cost of
// inserting and extracting the values.
return BaseT::getScalarizationOverhead(Ty, Args) + Num * Cost;
return BaseT::getScalarizationOverhead(VTy, Args) + Num * Cost;
}
return BaseCost;
......@@ -899,7 +899,7 @@ unsigned ARMTTIImpl::getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
// The scalarization cost should be a lot higher. We use the number of vector
// elements plus the scalarization overhead.
unsigned ScalarCost =
NumElems * LT.first + BaseT::getScalarizationOverhead(DataTy, {});
NumElems * LT.first + BaseT::getScalarizationOverhead(VTy, {});
if (Alignment < EltSize / 8)
return ScalarCost;
......
......@@ -115,7 +115,7 @@ unsigned HexagonTTIImpl::getMinimumVF(unsigned ElemWidth) const {
return (8 * ST.getVectorLength()) / ElemWidth;
}
unsigned HexagonTTIImpl::getScalarizationOverhead(Type *Ty,
unsigned HexagonTTIImpl::getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract) {
return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
......
......@@ -101,7 +101,7 @@ public:
return true;
}
unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract);
unsigned getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
unsigned VF);
......
......@@ -464,7 +464,8 @@ int SystemZTTIImpl::getArithmeticInstrCost(
return DivInstrCost;
}
else if (ST->hasVector()) {
unsigned VF = cast<VectorType>(Ty)->getNumElements();
auto *VTy = cast<VectorType>(Ty);
unsigned VF = VTy->getNumElements();
unsigned NumVectors = getNumVectorRegs(Ty);
// These vector operations are custom handled, but are still supported
......@@ -477,7 +478,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
if (DivRemConstPow2)
return (NumVectors * (SignedDivRem ? SDivPow2Cost : 1));
if (DivRemConst)
return VF * DivMulSeqCost + getScalarizationOverhead(Ty, Args);
return VF * DivMulSeqCost + getScalarizationOverhead(VTy, Args);
if ((SignedDivRem || UnsignedDivRem) && VF > 4)
// Temporary hack: disable high vectorization factors with integer
// division/remainder, which will get scalarized and handled with
......@@ -500,7 +501,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
// inserting and extracting the values.
unsigned ScalarCost =
getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind);
unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(Ty, Args);
unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(VTy, Args);
// FIXME: VF 2 for these FP operations are currently just as
// expensive as for VF 4.
if (VF == 2)
......@@ -517,7 +518,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
// There is no native support for FRem.
if (Opcode == Instruction::FRem) {
unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(Ty, Args);
unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(VTy, Args);
// FIXME: VF 2 for float is currently just as expensive as for VF 4.
if (VF == 2 && ScalarBits == 32)
Cost *= 2;
......@@ -724,8 +725,9 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
}
}
else if (ST->hasVector()) {
assert (Dst->isVectorTy());
unsigned VF = cast<VectorType>(Src)->getNumElements();
auto *SrcVecTy = cast<VectorType>(Src);
auto *DstVecTy = cast<VectorType>(Dst);
unsigned VF = SrcVecTy->getNumElements();
unsigned NumDstVectors = getNumVectorRegs(Dst);
unsigned NumSrcVectors = getNumVectorRegs(Src);
......@@ -781,8 +783,8 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
(Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
NeedsExtracts = false;
TotCost += getScalarizationOverhead(Src, false, NeedsExtracts);
TotCost += getScalarizationOverhead(Dst, NeedsInserts, false);
TotCost += getScalarizationOverhead(SrcVecTy, false, NeedsExtracts);
TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts, false);
// FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
......@@ -793,7 +795,8 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (Opcode == Instruction::FPTrunc) {
if (SrcScalarBits == 128) // fp128 -> double/float + inserts of elements.
return VF /*ldxbr/lexbr*/ + getScalarizationOverhead(Dst, true, false);
return VF /*ldxbr/lexbr*/ +
getScalarizationOverhead(DstVecTy, true, false);
else // double -> float
return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
}
......@@ -806,7 +809,7 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
return VF * 2;
}
// -> fp128. VF * lxdb/lxeb + extraction of elements.
return VF + getScalarizationOverhead(Src, false, true);
return VF + getScalarizationOverhead(SrcVecTy, false, true);
}
}
......
......@@ -2888,10 +2888,9 @@ int X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index) {
return BaseT::getVectorInstrCost(Opcode, Val, Index) + RegisterFileMoveCost;
}
unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
unsigned X86TTIImpl::getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract) {
auto* VecTy = cast<VectorType>(Ty);
unsigned Cost = 0;
// For insertions, a ISD::BUILD_VECTOR style vector initialization can be much
......@@ -2917,7 +2916,7 @@ unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
// 128-bit vector is free.
// NOTE: This assumes legalization widens vXf32 vectors.
if (MScalarTy == MVT::f32)
for (unsigned i = 0, e = VecTy->getNumElements(); i < e; i += 4)
for (unsigned i = 0, e = Ty->getNumElements(); i < e; i += 4)
if (DemandedElts[i])
Cost--;
}
......@@ -2933,7 +2932,7 @@ unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
// vector elements, which represents the number of unpacks we'll end up
// performing.
unsigned NumElts = LT.second.getVectorNumElements();
unsigned Pow2Elts = PowerOf2Ceil(VecTy->getNumElements());
unsigned Pow2Elts = PowerOf2Ceil(Ty->getNumElements());
Cost += (std::min<unsigned>(NumElts, Pow2Elts) - 1) * LT.first;
}
}
......@@ -2970,7 +2969,7 @@ int X86TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
APInt DemandedElts = APInt::getAllOnesValue(NumElem);
int Cost = BaseT::getMemoryOpCost(Opcode, VTy->getScalarType(), Alignment,
AddressSpace, CostKind);
int SplitCost = getScalarizationOverhead(Src, DemandedElts,
int SplitCost = getScalarizationOverhead(VTy, DemandedElts,
Opcode == Instruction::Load,
Opcode == Instruction::Store);
return NumElem * Cost + SplitCost;
......
......@@ -135,7 +135,7 @@ public:
TTI::TargetCostKind CostKind,
const Instruction *I = nullptr);
int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract);
int getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment,
unsigned AddressSpace,
......
......@@ -5702,9 +5702,9 @@ int LoopVectorizationCostModel::computePredInstDiscount(
// Compute the scalarization overhead of needed insertelement instructions
// and phi nodes.
if (isScalarWithPredication(I) && !I->getType()->isVoidTy()) {
ScalarCost +=
TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF),
APInt::getAllOnesValue(VF), true, false);
ScalarCost += TTI.getScalarizationOverhead(
cast<VectorType>(ToVectorTy(I->getType(), VF)),
APInt::getAllOnesValue(VF), true, false);
ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI);
}
......@@ -5720,8 +5720,8 @@ int LoopVectorizationCostModel::computePredInstDiscount(
Worklist.push_back(J);
else if (needsExtract(J, VF))
ScalarCost += TTI.getScalarizationOverhead(
ToVectorTy(J->getType(), VF), APInt::getAllOnesValue(VF), false,
true);
cast<VectorType>(ToVectorTy(J->getType(), VF)),
APInt::getAllOnesValue(VF), false, true);
}
// Scale the total scalar cost by block probability.
......@@ -6016,8 +6016,8 @@ unsigned LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I,
Type *RetTy = ToVectorTy(I->getType(), VF);
if (!RetTy->isVoidTy() &&
(!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore()))
Cost += TTI.getScalarizationOverhead(RetTy, APInt::getAllOnesValue(VF),
true, false);
Cost += TTI.getScalarizationOverhead(
cast<VectorType>(RetTy), APInt::getAllOnesValue(VF), true, false);
// Some targets keep addresses scalar.
if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing())
......@@ -6222,7 +6222,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
if (ScalarPredicatedBB) {
// Return cost for branches around scalarized and predicated blocks.
Type *Vec_i1Ty =
VectorType *Vec_i1Ty =
VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF);
return (TTI.getScalarizationOverhead(Vec_i1Ty, APInt::getAllOnesValue(VF),
false, true) +
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment