Commit 37289615 authored by Sam Parker's avatar Sam Parker
Browse files

[NFCI][CostModel] Unify getCmpSelInstrCost

Add cases for icmp, fcmp and select into the switch statement of the
generic getUserCost implementation with getInstructionThroughput then
calling into it. The BasicTTI and backend implementations have be set
to return a default value (1) when a cost other than throughput is
being queried.

Differential Revision: https://reviews.llvm.org/D80550
parent 30dfbf03
......@@ -865,6 +865,17 @@ public:
LI->getPointerAddressSpace(),
CostKind, I);
}
case Instruction::Select: {
Type *CondTy = U->getOperand(0)->getType();
return TargetTTI->getCmpSelInstrCost(Opcode, U->getType(), CondTy,
CostKind, I);
}
case Instruction::ICmp:
case Instruction::FCmp: {
Type *ValTy = U->getOperand(0)->getType();
return TargetTTI->getCmpSelInstrCost(Opcode, ValTy, U->getType(),
CostKind, I);
}
}
// By default, just classify everything as 'basic'.
return TTI::TCC_Basic;
......
......@@ -838,6 +838,10 @@ public:
int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");
// TODO: Handle other cost kinds.
if (CostKind != TTI::TCK_RecipThroughput)
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I);
// Selects on vectors are actually vector selects.
if (ISD == ISD::SELECT) {
assert(CondTy && "CondTy must exist");
......
......@@ -1296,18 +1296,10 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
Op1VK, Op2VK,
Op1VP, Op2VP, Operands, I);
}
case Instruction::Select: {
const SelectInst *SI = cast<SelectInst>(I);
Type *CondTy = SI->getCondition()->getType();
return getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy,
CostKind, I);
}
case Instruction::Select:
case Instruction::ICmp:
case Instruction::FCmp: {
Type *ValTy = I->getOperand(0)->getType();
return getCmpSelInstrCost(I->getOpcode(), ValTy, I->getType(),
CostKind, I);
}
case Instruction::FCmp:
return getUserCost(I, CostKind);
case Instruction::Store:
case Instruction::Load:
return getUserCost(I, CostKind);
......
......@@ -620,6 +620,9 @@ int AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
Type *CondTy,
TTI::TargetCostKind CostKind,
const Instruction *I) {
// TODO: Handle other cost kinds.
if (CostKind != TTI::TCK_RecipThroughput)
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I);
int ISD = TLI->InstructionOpcodeToISD(Opcode);
// We don't lower some vector selects well that are wider than the register
......
......@@ -508,6 +508,10 @@ int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
int ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
TTI::TargetCostKind CostKind,
const Instruction *I) {
// TODO: Handle other cost kinds.
if (CostKind != TTI::TCK_RecipThroughput)
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I);
int ISD = TLI->InstructionOpcodeToISD(Opcode);
// On NEON a vector select gets lowered to vbsl.
if (ST->hasNEON() && ValTy->isVectorTy() && ISD == ISD::SELECT) {
......
......@@ -236,7 +236,7 @@ unsigned HexagonTTIImpl::getInterleavedMemoryOpCost(unsigned Opcode,
unsigned HexagonTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) {
if (ValTy->isVectorTy()) {
if (ValTy->isVectorTy() && CostKind == TTI::TCK_RecipThroughput) {
std::pair<int, MVT> LT = TLI.getTypeLegalizationCost(DL, ValTy);
if (Opcode == Instruction::FCmp)
return LT.first + FloatFactor * getTypeNumElements(ValTy);
......
......@@ -779,6 +779,9 @@ int PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
TTI::TargetCostKind CostKind,
const Instruction *I) {
int Cost = BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I);
// TODO: Handle other cost kinds.
if (CostKind != TTI::TCK_RecipThroughput)
return Cost;
return vectorCostAdjustment(Cost, Opcode, ValTy, nullptr);
}
......
......@@ -837,6 +837,9 @@ int SystemZTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
Type *CondTy,
TTI::TargetCostKind CostKind,
const Instruction *I) {
if (CostKind != TTI::TCK_RecipThroughput)
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind);
if (!ValTy->isVectorTy()) {
switch (Opcode) {
case Instruction::ICmp: {
......
......@@ -2051,6 +2051,10 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
TTI::TargetCostKind CostKind,
const Instruction *I) {
// TODO: Handle other cost kinds.
if (CostKind != TTI::TCK_RecipThroughput)
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I);
// Legalize the type.
std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment