Commit ffc3e800 authored by QingShan Zhang's avatar QingShan Zhang
Browse files

[NFC] [DAGCombine] Correct the result for sqrt even the iteration is zero

For now, we correct the result for sqrt if iteration > 0. This doesn't make
sense as they are not strict relative.

Reviewed By: dmgreen, spatel, RKSimon

Differential Revision: https://reviews.llvm.org/D94480
parent 89a5147e
Loading
Loading
Loading
Loading
+1 −3
Original line number Diff line number Diff line
@@ -4287,9 +4287,7 @@ public:
  /// comparison may check if the operand is NAN, INF, zero, normal, etc. The
  /// result should be used as the condition operand for a select or branch.
  virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
                                   const DenormalMode &Mode) const {
    return SDValue();
  }
                                   const DenormalMode &Mode) const;

  /// Return a target-dependent result if the input operand is not suitable for
  /// use with a square root estimate calculation.
+12 −34
Original line number Diff line number Diff line
@@ -22275,35 +22275,14 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
                          Reciprocal)) {
    AddToWorklist(Est.getNode());
    if (Iterations) {
    if (Iterations)
      Est = UseOneConstNR
            ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
            : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
    if (!Reciprocal) {
      SDLoc DL(Op);
        EVT CCVT = getSetCCResultType(VT);
        SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
        DenormalMode DenormMode = DAG.getDenormalMode(VT);
      // Try the target specific test first.
        SDValue Test = TLI.getSqrtInputTest(Op, DAG, DenormMode);
        if (!Test) {
          // If no test provided by target, testing it with denormal inputs to
          // avoid wrong estimate.
          if (DenormMode.Input == DenormalMode::IEEE) {
            // This is specifically a check for the handling of denormal inputs,
            // not the result.
            // Test = fabs(X) < SmallestNormal
            const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
            APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
            SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
            SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
            Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
          } else
            // Test = X == 0.0
            Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
        }
      SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
      // The estimate is now completely wrong if the input was exactly 0.0 or
      // possibly a denormal. Force the answer to 0.0 or value provided by
@@ -22312,7 +22291,6 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
          Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
          Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
    }
    }
    return Est;
  }
+22 −0
Original line number Diff line number Diff line
@@ -5841,6 +5841,28 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
  return false;
}

SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
                                         const DenormalMode &Mode) const {
  SDLoc DL(Op);
  EVT VT = Op.getValueType();
  EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
  // Testing it with denormal inputs to avoid wrong estimate.
  if (Mode.Input == DenormalMode::IEEE) {
    // This is specifically a check for the handling of denormal inputs,
    // not the result.

    // Test = fabs(X) < SmallestNormal
    const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
    APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
    SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
    SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
    return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
  }
  // Test = X == 0.0
  return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
}

SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
                                             bool LegalOps, bool OptForSize,
                                             NegatibleCost &Cost,
+17 −10
Original line number Diff line number Diff line
@@ -7471,6 +7471,22 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
  return SDValue();
}
SDValue
AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
                                        const DenormalMode &Mode) const {
  SDLoc DL(Op);
  EVT VT = Op.getValueType();
  EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
  return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
}
SDValue
AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op,
                                                   SelectionDAG &DAG) const {
  return Op;
}
SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
                                               SelectionDAG &DAG, int Enabled,
                                               int &ExtraSteps,
@@ -7494,17 +7510,8 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
        Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
        Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
      }
      if (!Reciprocal) {
        EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
                                      VT);
        SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
        SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ);
      if (!Reciprocal)
        Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags);
        // Correct the result if the operand is 0.0.
        Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL,
                               VT, Eq, Operand, Estimate);
      }
      ExtraSteps = 0;
      return Estimate;
+4 −0
Original line number Diff line number Diff line
@@ -961,6 +961,10 @@ private:
                          bool Reciprocal) const override;
  SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                           int &ExtraSteps) const override;
  SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
                           const DenormalMode &Mode) const override;
  SDValue getSqrtResultForDenormInput(SDValue Operand,
                                      SelectionDAG &DAG) const override;
  unsigned combineRepeatedFPDivisors() const override;

  ConstraintType getConstraintType(StringRef Constraint) const override;
Loading