Commit 49429783 authored by Craig Topper's avatar Craig Topper
Browse files

[RISCV] Add lowering for scalar fmaximum/fminimum.

Unlike fmaxnum and fminnum, these operations propagate nan and
consider -0.0 to be less than +0.0.

Without Zfa, we don't have a single instruction for this. The
lowering I've used forces the other input to nan if one input
is a nan. If both inputs are nan, they get swapped. Then use
the fmax or fmin instruction.

New ISD nodes are needed because fmaxnum/fminnum do not define
the order of -0.0 and +0.0.

This lowering ensures the snans are quieted (though that is probably not
required in the default environment). It also ensures non-canonical nans
are canonicalized, though I'm also not sure that's needed.

Another option could be to use fmax/fmin and then overwrite the
result based on the inputs being nan, but I'm not sure we can do
that with any less code.

Future work will handle the nonans FMF and the case where
we can prove the input isn't nan.

This does fix the crash in #64022, but we need to do more work
to avoid scalarization.

Reviewed By: fakepaper56

Differential Revision: https://reviews.llvm.org/D156069
parent 6c48f57c
Loading
Loading
Loading
Loading
+40 −0
Original line number Diff line number Diff line
@@ -423,6 +423,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
    // We need to custom promote this.
    if (Subtarget.is64Bit())
      setOperationAction(ISD::FPOWI, MVT::i32, Custom);
    if (!Subtarget.hasStdExtZfa())
      setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f16, Custom);
  }
  if (Subtarget.hasStdExtFOrZfinx()) {
@@ -445,6 +448,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
    if (Subtarget.hasStdExtZfa())
      setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
    else
      setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f32, Custom);
  }
  if (Subtarget.hasStdExtFOrZfinx() && Subtarget.is64Bit())
@@ -461,6 +466,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
    } else {
      if (Subtarget.is64Bit())
        setOperationAction(FPRndMode, MVT::f64, Custom);
      setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f64, Custom);
    }
    setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Legal);
@@ -4624,6 +4631,34 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op,
                      ISD::CondCode::SETNE);
}
// Custom lowering for ISD::FMAXIMUM and ISD::FMINIMUM. Those operations must
// propagate nans, which our fmax and fmin instructions do not do by
// themselves, so the operands are adjusted before a RISCVISD::FMAX or
// RISCVISD::FMIN node is emitted.
//
// Op is the FMAXIMUM/FMINIMUM node being lowered, DAG is used to build the
// replacement nodes, and Subtarget supplies XLenVT for the compare results.
// Returns the RISCVISD::FMAX/FMIN node that replaces Op.
static SDValue lowerFMAXIMUM_FMINIMUM(SDValue Op, SelectionDAG &DAG,
                                      const RISCVSubtarget &Subtarget) {
  SDLoc DL(Op);
  EVT VT = Op.getValueType();

  SDValue LHS = Op.getOperand(0);
  SDValue RHS = Op.getOperand(1);

  MVT CondVT = Subtarget.getXLenVT();

  // Each operand is compared against itself (ordered-equal) to detect a nan.
  // When LHS is a nan it is substituted for RHS, and when RHS is a nan it is
  // substituted for LHS. A nan on either side therefore ends up on both sides
  // and propagates through the final node. With two nan inputs the operands
  // are merely exchanged, which does no harm.
  // FIXME: Handle nonans FMF and use isKnownNeverNaN.
  SDValue LHSIsNonNan = DAG.getSetCC(DL, CondVT, LHS, LHS, ISD::SETOEQ);
  SDValue AdjustedRHS = DAG.getSelect(DL, VT, LHSIsNonNan, RHS, LHS);

  SDValue RHSIsNonNan = DAG.getSetCC(DL, CondVT, RHS, RHS, ISD::SETOEQ);
  SDValue AdjustedLHS = DAG.getSelect(DL, VT, RHSIsNonNan, LHS, RHS);

  // FMAXIMUM maps to RISCVISD::FMAX; the only other opcode routed here,
  // FMINIMUM, maps to RISCVISD::FMIN.
  unsigned TargetOpc = RISCVISD::FMIN;
  if (Op.getOpcode() == ISD::FMAXIMUM)
    TargetOpc = RISCVISD::FMAX;

  return DAG.getNode(TargetOpc, DL, VT, AdjustedLHS, AdjustedRHS);
}
/// Get a RISCV target specified VL op for a given SDNode.
static unsigned getRISCVVLOp(SDValue Op) {
#define OP_CASE(NODE)                                                          \
@@ -4948,6 +4983,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
    }
    return SDValue();
  }
  case ISD::FMAXIMUM:
  case ISD::FMINIMUM:
    return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
  case ISD::FP_EXTEND: {
    SDLoc DL(Op);
    EVT VT = Op.getValueType();
@@ -16054,6 +16092,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
  NODE_NAME_CASE(FP_EXTEND_BF16)
  NODE_NAME_CASE(FROUND)
  NODE_NAME_CASE(FPCLASS)
  NODE_NAME_CASE(FMAX)
  NODE_NAME_CASE(FMIN)
  NODE_NAME_CASE(READ_CYCLE_WIDE)
  NODE_NAME_CASE(BREV8)
  NODE_NAME_CASE(ORC_B)
+4 −0
Original line number Diff line number Diff line
@@ -122,6 +122,10 @@ enum NodeType : unsigned {
  FROUND,

  FPCLASS,

  // Floating point fmax and fmin matching the RISC-V instruction semantics.
  FMAX, FMIN,

  // READ_CYCLE_WIDE - A read of the 64-bit cycle CSR on a 32-bit target
  // (returns (Lo, Hi)). It takes a chain operand.
  READ_CYCLE_WIDE,
+2 −0
Original line number Diff line number Diff line
@@ -386,6 +386,8 @@ def : Pat<(fneg (any_fma_nsz FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3))
// Min/max selection patterns, instantiated once per entry in DExts.
foreach Ext = DExts in {
  // Generic fminnum/fmaxnum nodes select FMIN_D/FMAX_D.
  defm : PatFprFpr_m<fminnum, FMIN_D, Ext>;
  defm : PatFprFpr_m<fmaxnum, FMAX_D, Ext>;
  // The target-specific riscv_fmin/riscv_fmax nodes select the same
  // FMIN_D/FMAX_D instructions.
  defm : PatFprFpr_m<riscv_fmin, FMIN_D, Ext>;
  defm : PatFprFpr_m<riscv_fmax, FMAX_D, Ext>;
}

/// Setcc
+5 −0
Original line number Diff line number Diff line
@@ -51,6 +51,9 @@ def riscv_fcvt_x
def riscv_fcvt_xu
    : SDNode<"RISCVISD::FCVT_XU", SDT_RISCVFCVT_X>;

// SelectionDAG nodes bound to RISCVISD::FMIN and RISCVISD::FMAX, both using
// the SDTFPBinOp type profile.
def riscv_fmin : SDNode<"RISCVISD::FMIN", SDTFPBinOp>;
def riscv_fmax : SDNode<"RISCVISD::FMAX", SDTFPBinOp>;

def riscv_strict_fcvt_w_rv64
    : SDNode<"RISCVISD::STRICT_FCVT_W_RV64", SDT_RISCVFCVT_W_RV64,
             [SDNPHasChain]>;
@@ -555,6 +558,8 @@ def : Pat<(fneg (any_fma_nsz FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3)),
// Min/max selection patterns, instantiated once per entry in FExts.
foreach Ext = FExts in {
  // Generic fminnum/fmaxnum nodes select FMIN_S/FMAX_S.
  defm : PatFprFpr_m<fminnum, FMIN_S, Ext>;
  defm : PatFprFpr_m<fmaxnum, FMAX_S, Ext>;
  // The target-specific riscv_fmin/riscv_fmax nodes select the same
  // FMIN_S/FMAX_S instructions.
  defm : PatFprFpr_m<riscv_fmin, FMIN_S, Ext>;
  defm : PatFprFpr_m<riscv_fmax, FMAX_S, Ext>;
}

/// Setcc
+2 −0
Original line number Diff line number Diff line
@@ -348,6 +348,8 @@ def : Pat<(fneg (any_fma_nsz FPR16INX:$rs1, FPR16INX:$rs2, FPR16INX:$rs3)),
// Min/max selection patterns, instantiated once per entry in ZfhExts.
foreach Ext = ZfhExts in {
  // Generic fminnum/fmaxnum nodes select FMIN_H/FMAX_H.
  defm : PatFprFpr_m<fminnum, FMIN_H, Ext>;
  defm : PatFprFpr_m<fmaxnum, FMAX_H, Ext>;
  // The target-specific riscv_fmin/riscv_fmax nodes select the same
  // FMIN_H/FMAX_H instructions.
  defm : PatFprFpr_m<riscv_fmin, FMIN_H, Ext>;
  defm : PatFprFpr_m<riscv_fmax, FMAX_H, Ext>;
}

/// Setcc
Loading