Commit 889f6606 authored by Christopher Tetreault's avatar Christopher Tetreault
Browse files

Clean up usages of asserting vector getters in Type

Summary:
Remove usages of asserting vector getters in Type in preparation for the
VectorType refactor. The existence of these functions complicates the
refactor while adding little value.

Reviewers: stoklund, sdesmalen, efriedma

Reviewed By: sdesmalen

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77272
parent 33f76e23
......@@ -82,16 +82,16 @@ private:
/// Estimate a cost of Broadcast as an extract and sequence of insert
/// operations.
unsigned getBroadcastShuffleOverhead(Type *Ty) {
assert(Ty->isVectorTy() && "Can only shuffle vectors");
auto *VTy = cast<VectorType>(Ty);
unsigned Cost = 0;
// Broadcast cost is equal to the cost of extracting the zero'th element
// plus the cost of inserting it into every element of the result vector.
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::ExtractElement, Ty, 0);
Instruction::ExtractElement, VTy, 0);
for (int i = 0, e = Ty->getVectorNumElements(); i < e; ++i) {
for (int i = 0, e = VTy->getNumElements(); i < e; ++i) {
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::InsertElement, Ty, i);
Instruction::InsertElement, VTy, i);
}
return Cost;
}
......@@ -99,7 +99,7 @@ private:
/// Estimate a cost of shuffle as a sequence of extract and insert
/// operations.
unsigned getPermuteShuffleOverhead(Type *Ty) {
assert(Ty->isVectorTy() && "Can only shuffle vectors");
auto *VTy = cast<VectorType>(Ty);
unsigned Cost = 0;
// Shuffle cost is equal to the cost of extracting element from its argument
// plus the cost of inserting them onto the result vector.
......@@ -108,11 +108,11 @@ private:
// index 0 of first vector, index 1 of second vector,index 2 of first
// vector and finally index 3 of second vector and insert them at index
// <0,1,2,3> of result vector.
for (int i = 0, e = Ty->getVectorNumElements(); i < e; ++i) {
Cost += static_cast<T *>(this)
->getVectorInstrCost(Instruction::InsertElement, Ty, i);
Cost += static_cast<T *>(this)
->getVectorInstrCost(Instruction::ExtractElement, Ty, i);
for (int i = 0, e = VTy->getNumElements(); i < e; ++i) {
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::InsertElement, VTy, i);
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::ExtractElement, VTy, i);
}
return Cost;
}
......@@ -122,8 +122,10 @@ private:
unsigned getExtractSubvectorOverhead(Type *Ty, int Index, Type *SubTy) {
assert(Ty && Ty->isVectorTy() && SubTy && SubTy->isVectorTy() &&
"Can only extract subvectors from vectors");
int NumSubElts = SubTy->getVectorNumElements();
assert((Index + NumSubElts) <= (int)Ty->getVectorNumElements() &&
auto *VTy = cast<VectorType>(Ty);
auto *SubVTy = cast<VectorType>(SubTy);
int NumSubElts = SubVTy->getNumElements();
assert((Index + NumSubElts) <= (int)VTy->getNumElements() &&
"SK_ExtractSubvector index out of range");
unsigned Cost = 0;
......@@ -132,9 +134,9 @@ private:
// type.
for (int i = 0; i != NumSubElts; ++i) {
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::ExtractElement, Ty, i + Index);
Instruction::ExtractElement, VTy, i + Index);
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::InsertElement, SubTy, i);
Instruction::InsertElement, SubVTy, i);
}
return Cost;
}
......@@ -144,8 +146,10 @@ private:
unsigned getInsertSubvectorOverhead(Type *Ty, int Index, Type *SubTy) {
assert(Ty && Ty->isVectorTy() && SubTy && SubTy->isVectorTy() &&
"Can only insert subvectors into vectors");
int NumSubElts = SubTy->getVectorNumElements();
assert((Index + NumSubElts) <= (int)Ty->getVectorNumElements() &&
auto *VTy = cast<VectorType>(Ty);
auto *SubVTy = cast<VectorType>(SubTy);
int NumSubElts = SubVTy->getNumElements();
assert((Index + NumSubElts) <= (int)VTy->getNumElements() &&
"SK_InsertSubvector index out of range");
unsigned Cost = 0;
......@@ -154,9 +158,9 @@ private:
// type.
for (int i = 0; i != NumSubElts; ++i) {
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::ExtractElement, SubTy, i);
Instruction::ExtractElement, SubVTy, i);
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::InsertElement, Ty, i + Index);
Instruction::InsertElement, VTy, i + Index);
}
return Cost;
}
......@@ -577,16 +581,16 @@ public:
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the result needs to be inserted and/or extracted from vectors.
unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) {
assert(Ty->isVectorTy() && "Can only scalarize vectors");
auto *VTy = cast<VectorType>(Ty);
unsigned Cost = 0;
for (int i = 0, e = Ty->getVectorNumElements(); i < e; ++i) {
for (int i = 0, e = VTy->getNumElements(); i < e; ++i) {
if (Insert)
Cost += static_cast<T *>(this)
->getVectorInstrCost(Instruction::InsertElement, Ty, i);
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::InsertElement, VTy, i);
if (Extract)
Cost += static_cast<T *>(this)
->getVectorInstrCost(Instruction::ExtractElement, Ty, i);
Cost += static_cast<T *>(this)->getVectorInstrCost(
Instruction::ExtractElement, VTy, i);
}
return Cost;
......@@ -605,7 +609,7 @@ public:
if (A->getType()->isVectorTy()) {
VecTy = A->getType();
// If A is a vector operand, VF should be 1 or correspond to A.
assert((VF == 1 || VF == VecTy->getVectorNumElements()) &&
assert((VF == 1 || VF == cast<VectorType>(VecTy)->getNumElements()) &&
"Vector argument does not match VF");
}
else
......@@ -619,18 +623,16 @@ public:
}
unsigned getScalarizationOverhead(Type *VecTy, ArrayRef<const Value *> Args) {
assert(VecTy->isVectorTy());
unsigned Cost = 0;
auto *VecVTy = cast<VectorType>(VecTy);
Cost += getScalarizationOverhead(VecTy, true, false);
Cost += getScalarizationOverhead(VecVTy, true, false);
if (!Args.empty())
Cost += getOperandsScalarizationOverhead(Args,
VecTy->getVectorNumElements());
Cost += getOperandsScalarizationOverhead(Args, VecVTy->getNumElements());
else
// When no information on arguments is provided, we add the cost
// associated with one argument as a heuristic.
Cost += getScalarizationOverhead(VecTy, false, true);
Cost += getScalarizationOverhead(VecVTy, false, true);
return Cost;
}
......@@ -672,13 +674,13 @@ public:
// Else, assume that we need to scalarize this op.
// TODO: If one of the types get legalized by splitting, handle this
// similarly to what getCastInstrCost() does.
if (Ty->isVectorTy()) {
unsigned Num = Ty->getVectorNumElements();
unsigned Cost = static_cast<T *>(this)
->getArithmeticInstrCost(Opcode, Ty->getScalarType());
if (auto *VTy = dyn_cast<VectorType>(Ty)) {
unsigned Num = VTy->getNumElements();
unsigned Cost = static_cast<T *>(this)->getArithmeticInstrCost(
Opcode, VTy->getScalarType());
// Return the cost of multiple scalar invocation plus the cost of
// inserting and extracting the values.
return getScalarizationOverhead(Ty, Args) + Num * Cost;
return getScalarizationOverhead(VTy, Args) + Num * Cost;
}
// We don't know anything about this scalar instruction.
......@@ -773,6 +775,8 @@ public:
// Check vector-to-vector casts.
if (Dst->isVectorTy() && Src->isVectorTy()) {
auto *SrcVTy = cast<VectorType>(Src);
auto *DstVTy = cast<VectorType>(Dst);
// If the cast is between same-sized registers, then the check is simple.
if (SrcLT.first == DstLT.first &&
SrcLT.second.getSizeInBits() == DstLT.second.getSizeInBits()) {
......@@ -800,11 +804,11 @@ public:
TargetLowering::TypeSplitVector ||
TLI->getTypeAction(Dst->getContext(), TLI->getValueType(DL, Dst)) ==
TargetLowering::TypeSplitVector) &&
Src->getVectorNumElements() > 1 && Dst->getVectorNumElements() > 1) {
Type *SplitDst = VectorType::get(Dst->getVectorElementType(),
Dst->getVectorNumElements() / 2);
Type *SplitSrc = VectorType::get(Src->getVectorElementType(),
Src->getVectorNumElements() / 2);
SrcVTy->getNumElements() > 1 && DstVTy->getNumElements() > 1) {
Type *SplitDst = VectorType::get(DstVTy->getElementType(),
DstVTy->getNumElements() / 2);
Type *SplitSrc = VectorType::get(SrcVTy->getElementType(),
SrcVTy->getNumElements() / 2);
T *TTI = static_cast<T *>(this);
return TTI->getVectorSplitCost() +
(2 * TTI->getCastInstrCost(Opcode, SplitDst, SplitSrc, I));
......@@ -812,7 +816,7 @@ public:
// In other cases where the source or destination are illegal, assume
// the operation will get scalarized.
unsigned Num = Dst->getVectorNumElements();
unsigned Num = DstVTy->getNumElements();
unsigned Cost = static_cast<T *>(this)->getCastInstrCost(
Opcode, Dst->getScalarType(), Src->getScalarType(), I);
......@@ -872,16 +876,16 @@ public:
// Otherwise, assume that the cast is scalarized.
// TODO: If one of the types get legalized by splitting, handle this
// similarly to what getCastInstrCost() does.
if (ValTy->isVectorTy()) {
unsigned Num = ValTy->getVectorNumElements();
if (auto *ValVTy = dyn_cast<VectorType>(ValTy)) {
unsigned Num = ValVTy->getNumElements();
if (CondTy)
CondTy = CondTy->getScalarType();
unsigned Cost = static_cast<T *>(this)->getCmpSelInstrCost(
Opcode, ValTy->getScalarType(), CondTy, I);
Opcode, ValVTy->getScalarType(), CondTy, I);
// Return the cost of multiple scalar invocation plus the cost of
// inserting and extracting the values.
return getScalarizationOverhead(ValTy, true, false) + Num * Cost;
return getScalarizationOverhead(ValVTy, true, false) + Num * Cost;
}
// Unknown scalar opcode.
......@@ -933,8 +937,7 @@ public:
unsigned Alignment, unsigned AddressSpace,
bool UseMaskForCond = false,
bool UseMaskForGaps = false) {
VectorType *VT = dyn_cast<VectorType>(VecTy);
assert(VT && "Expect a vector type for interleaved memory op");
auto *VT = cast<VectorType>(VecTy);
unsigned NumElts = VT->getNumElements();
assert(Factor > 1 && NumElts % Factor == 0 && "Invalid interleave factor");
......@@ -1087,7 +1090,8 @@ public:
ArrayRef<Value *> Args, FastMathFlags FMF,
unsigned VF = 1,
const Instruction *I = nullptr) {
unsigned RetVF = (RetTy->isVectorTy() ? RetTy->getVectorNumElements() : 1);
unsigned RetVF =
(RetTy->isVectorTy() ? cast<VectorType>(RetTy)->getNumElements() : 1);
assert((RetVF == 1 || VF == 1) && "VF > 1 and RetVF is a vector type");
auto *ConcreteTTI = static_cast<T *>(this);
......@@ -1210,7 +1214,8 @@ public:
if (RetTy->isVectorTy()) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
ScalarizationCost = getScalarizationOverhead(RetTy, true, false);
ScalarCalls = std::max(ScalarCalls, RetTy->getVectorNumElements());
ScalarCalls = std::max(
ScalarCalls, (unsigned)cast<VectorType>(RetTy)->getNumElements());
ScalarRetTy = RetTy->getScalarType();
}
SmallVector<Type *, 4> ScalarTys;
......@@ -1219,7 +1224,8 @@ public:
if (Ty->isVectorTy()) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
ScalarizationCost += getScalarizationOverhead(Ty, false, true);
ScalarCalls = std::max(ScalarCalls, Ty->getVectorNumElements());
ScalarCalls = std::max(
ScalarCalls, (unsigned)cast<VectorType>(Ty)->getNumElements());
Ty = Ty->getScalarType();
}
ScalarTys.push_back(Ty);
......@@ -1551,7 +1557,7 @@ public:
((ScalarizationCostPassed != std::numeric_limits<unsigned>::max())
? ScalarizationCostPassed
: getScalarizationOverhead(RetTy, true, false));
unsigned ScalarCalls = RetTy->getVectorNumElements();
unsigned ScalarCalls = cast<VectorType>(RetTy)->getNumElements();
SmallVector<Type *, 4> ScalarTys;
for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
Type *Ty = Tys[i];
......@@ -1565,7 +1571,9 @@ public:
if (Tys[i]->isVectorTy()) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
ScalarizationCost += getScalarizationOverhead(Tys[i], false, true);
ScalarCalls = std::max(ScalarCalls, Tys[i]->getVectorNumElements());
ScalarCalls =
std::max(ScalarCalls,
(unsigned)cast<VectorType>(Tys[i])->getNumElements());
}
}
......@@ -1639,8 +1647,8 @@ public:
unsigned getArithmeticReductionCost(unsigned Opcode, Type *Ty,
bool IsPairwise) {
assert(Ty->isVectorTy() && "Expect a vector type");
Type *ScalarTy = Ty->getVectorElementType();
unsigned NumVecElts = Ty->getVectorNumElements();
Type *ScalarTy = cast<VectorType>(Ty)->getElementType();
unsigned NumVecElts = cast<VectorType>(Ty)->getNumElements();
unsigned NumReduxLevels = Log2_32(NumVecElts);
unsigned ArithCost = 0;
unsigned ShuffleCost = 0;
......@@ -1689,9 +1697,9 @@ public:
unsigned getMinMaxReductionCost(Type *Ty, Type *CondTy, bool IsPairwise,
bool) {
assert(Ty->isVectorTy() && "Expect a vector type");
Type *ScalarTy = Ty->getVectorElementType();
Type *ScalarCondTy = CondTy->getVectorElementType();
unsigned NumVecElts = Ty->getVectorNumElements();
Type *ScalarTy = cast<VectorType>(Ty)->getElementType();
Type *ScalarCondTy = cast<VectorType>(CondTy)->getElementType();
unsigned NumVecElts = cast<VectorType>(Ty)->getNumElements();
unsigned NumReduxLevels = Log2_32(NumVecElts);
unsigned CmpOpcode;
if (Ty->isFPOrFPVectorTy()) {
......
......@@ -6577,7 +6577,7 @@ class VectorPromoteHelper {
UseSplat = true;
}
ElementCount EC = getTransitionType()->getVectorElementCount();
ElementCount EC = cast<VectorType>(getTransitionType())->getElementCount();
if (UseSplat)
return ConstantVector::getSplat(EC, Val);
......@@ -6840,7 +6840,7 @@ static bool splitMergedValStore(StoreInst &SI, const DataLayout &DL,
// whereas scalable vectors would have to be shifted by
// <2log(vscale) + number of bits> in order to store the
// low/high parts. Bailing out for now.
if (StoreType->isVectorTy() && StoreType->getVectorIsScalable())
if (StoreType->isVectorTy() && cast<VectorType>(StoreType)->isScalable())
return false;
if (!DL.typeSizeEqualsStoreSize(StoreType) ||
......
......@@ -125,7 +125,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
if (!FMF.allowReassoc())
Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), MRK);
else {
if (!isPowerOf2_32(Vec->getType()->getVectorNumElements()))
if (!isPowerOf2_32(cast<VectorType>(Vec->getType())->getNumElements()))
continue;
Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), MRK);
......@@ -146,7 +146,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
case Intrinsic::experimental_vector_reduce_fmax:
case Intrinsic::experimental_vector_reduce_fmin: {
Value *Vec = II->getArgOperand(0);
if (!isPowerOf2_32(Vec->getType()->getVectorNumElements()))
if (!isPowerOf2_32(cast<VectorType>(Vec->getType())->getNumElements()))
continue;
Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), MRK);
......
......@@ -1909,7 +1909,7 @@ bool IRTranslator::translateInsertElement(const User &U,
MachineIRBuilder &MIRBuilder) {
// If it is a <1 x Ty> vector, use the scalar as it is
// not a legal vector type in LLT.
if (U.getType()->getVectorNumElements() == 1) {
if (cast<VectorType>(U.getType())->getNumElements() == 1) {
Register Elt = getOrCreateVReg(*U.getOperand(1));
auto &Regs = *VMap.getVRegs(U);
if (Regs.empty()) {
......@@ -1933,7 +1933,7 @@ bool IRTranslator::translateExtractElement(const User &U,
MachineIRBuilder &MIRBuilder) {
// If it is a <1 x Ty> vector, use the scalar as it is
// not a legal vector type in LLT.
if (U.getOperand(0)->getType()->getVectorNumElements() == 1) {
if (cast<VectorType>(U.getOperand(0)->getType())->getNumElements() == 1) {
Register Elt = getOrCreateVReg(*U.getOperand(0));
auto &Regs = *VMap.getVRegs(U);
if (Regs.empty()) {
......
......@@ -308,7 +308,7 @@ bool InterleavedAccess::lowerInterleavedLoad(
unsigned Factor, Index;
unsigned NumLoadElements = LI->getType()->getVectorNumElements();
unsigned NumLoadElements = cast<VectorType>(LI->getType())->getNumElements();
// Check if the first shufflevector is DE-interleave shuffle.
if (!isDeInterleaveMask(Shuffles[0]->getShuffleMask(), Factor, Index,
MaxFactor, NumLoadElements))
......@@ -426,7 +426,8 @@ bool InterleavedAccess::lowerInterleavedStore(
// Check if the shufflevector is RE-interleave shuffle.
unsigned Factor;
unsigned OpNumElts = SVI->getOperand(0)->getType()->getVectorNumElements();
unsigned OpNumElts =
cast<VectorType>(SVI->getOperand(0)->getType())->getNumElements();
if (!isReInterleaveMask(SVI->getShuffleMask(), Factor, MaxFactor, OpNumElts))
return false;
......
......@@ -82,7 +82,7 @@ static bool isConstantIntVector(Value *Mask) {
if (!C)
return false;
unsigned NumElts = Mask->getType()->getVectorNumElements();
unsigned NumElts = cast<VectorType>(Mask->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *CElt = C->getAggregateElement(i);
if (!CElt || !isa<ConstantInt>(CElt))
......@@ -521,9 +521,10 @@ static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
assert(isa<VectorType>(Src->getType()) &&
"Unexpected data type in masked scatter intrinsic");
assert(isa<VectorType>(Ptrs->getType()) &&
isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
"Vector of pointers is expected in masked scatter intrinsic");
assert(
isa<VectorType>(Ptrs->getType()) &&
isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
"Vector of pointers is expected in masked scatter intrinsic");
IRBuilder<> Builder(CI->getContext());
Instruction *InsertPt = CI;
......@@ -532,7 +533,7 @@ static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
MaybeAlign AlignVal(cast<ConstantInt>(Alignment)->getZExtValue());
unsigned VectorWidth = Src->getType()->getVectorNumElements();
unsigned VectorWidth = cast<VectorType>(Src->getType())->getNumElements();
// Shorten the way if the mask is a vector of constants.
if (isConstantIntVector(Mask)) {
......@@ -725,7 +726,7 @@ static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
Builder.SetInsertPoint(InsertPt);
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
Type *EltTy = VecType->getVectorElementType();
Type *EltTy = VecType->getElementType();
unsigned VectorWidth = VecType->getNumElements();
......
......@@ -161,7 +161,7 @@ void FunctionLoweringInfo::set(const Function &fn, MachineFunction &mf,
// Scalable vectors may need a special StackID to distinguish
// them from other (fixed size) stack objects.
if (Ty->isVectorTy() && Ty->getVectorIsScalable())
if (Ty->isVectorTy() && cast<VectorType>(Ty)->isScalable())
MF->getFrameInfo().setStackID(FrameIndex,
TFI->getStackIDForScalableVectors());
......
......@@ -3751,8 +3751,9 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
// Normalize Vector GEP - all scalar operands should be converted to the
// splat vector.
bool IsVectorGEP = I.getType()->isVectorTy();
ElementCount VectorElementCount = IsVectorGEP ?
I.getType()->getVectorElementCount() : ElementCount(0, false);
ElementCount VectorElementCount =
IsVectorGEP ? cast<VectorType>(I.getType())->getElementCount()
: ElementCount(0, false);
if (IsVectorGEP && !N.getValueType().isVector()) {
LLVMContext &Context = *DAG.getContext();
......@@ -4312,7 +4313,7 @@ static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index,
IndexType = ISD::SIGNED_SCALED;
if (STy || !Index.getValueType().isVector()) {
unsigned GEPWidth = GEP->getType()->getVectorNumElements();
unsigned GEPWidth = cast<VectorType>(GEP->getType())->getNumElements();
EVT VT = EVT::getVectorVT(Context, Index.getValueType(), GEPWidth);
Index = DAG.getSplatBuildVector(VT, SDLoc(Index), Index);
}
......
......@@ -1681,8 +1681,8 @@ static std::string scalarConstantToHexString(const Constant *C) {
return APIntToHexString(CI->getValue());
} else {
unsigned NumElements;
if (isa<VectorType>(Ty))
NumElements = Ty->getVectorNumElements();
if (auto *VTy = dyn_cast<VectorType>(Ty))
NumElements = VTy->getNumElements();
else
NumElements = Ty->getArrayNumElements();
std::string HexString;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment