Commit 8af492ad authored by Liu, Chen3's avatar Liu, Chen3
Browse files

add strict float for round operation

Differential Revision: https://reviews.llvm.org/D72026
parent d2bb8c16
Loading
Loading
Loading
Loading
+28 −5
Original line number Diff line number Diff line
@@ -897,26 +897,49 @@ void X86DAGToDAGISel::PreprocessISelDAG() {
      continue;
    }
    case ISD::FCEIL:
    case ISD::STRICT_FCEIL:
    case ISD::FFLOOR:
    case ISD::STRICT_FFLOOR:
    case ISD::FTRUNC:
    case ISD::STRICT_FTRUNC:
    case ISD::FNEARBYINT:
    case ISD::FRINT: {
    case ISD::STRICT_FNEARBYINT:
    case ISD::FRINT:
    case ISD::STRICT_FRINT: {
      // Replace fp rounding with their X86 specific equivalent so we don't
      // need 2 sets of patterns.
      unsigned Imm;
      switch (N->getOpcode()) {
      default: llvm_unreachable("Unexpected opcode!");
      case ISD::STRICT_FCEIL:
      case ISD::FCEIL:      Imm = 0xA; break;
      case ISD::STRICT_FFLOOR:
      case ISD::FFLOOR:     Imm = 0x9; break;
      case ISD::STRICT_FTRUNC:
      case ISD::FTRUNC:     Imm = 0xB; break;
      case ISD::STRICT_FNEARBYINT:
      case ISD::FNEARBYINT: Imm = 0xC; break;
      case ISD::STRICT_FRINT:
      case ISD::FRINT:      Imm = 0x4; break;
      }
      SDLoc dl(N);
      SDValue Res = CurDAG->getNode(
          X86ISD::VRNDSCALE, dl, N->getValueType(0), N->getOperand(0),
      bool IsStrict = N->isStrictFPOpcode();
      SDValue Res;
      if (IsStrict)
        Res = CurDAG->getNode(X86ISD::STRICT_VRNDSCALE, dl,
                              {N->getValueType(0), MVT::Other},
                              {N->getOperand(0), N->getOperand(1),
                               CurDAG->getTargetConstant(Imm, dl, MVT::i8)});
      else
        Res = CurDAG->getNode(X86ISD::VRNDSCALE, dl, N->getValueType(0),
                              N->getOperand(0),
                              CurDAG->getTargetConstant(Imm, dl, MVT::i8));
      --I;
      if (IsStrict) {
        SDValue From[] = {SDValue(N, 0), SDValue(N, 1)};
        SDValue To[] = {Res.getValue(0), Res.getValue(1)};
        CurDAG->ReplaceAllUsesOfValuesWith(From, To, 2);
      } else
        CurDAG->ReplaceAllUsesOfValueWith(SDValue(N, 0), Res);
      ++I;
      CurDAG->DeleteNode(N);
+34 −18
Original line number Diff line number Diff line
@@ -1069,10 +1069,15 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
  if (!Subtarget.useSoftFloat() && Subtarget.hasSSE41()) {
    for (MVT RoundedTy : {MVT::f32, MVT::f64, MVT::v4f32, MVT::v2f64}) {
      setOperationAction(ISD::FFLOOR,            RoundedTy,  Legal);
      setOperationAction(ISD::STRICT_FFLOOR,     RoundedTy,  Legal);
      setOperationAction(ISD::FCEIL,             RoundedTy,  Legal);
      setOperationAction(ISD::STRICT_FCEIL,      RoundedTy,  Legal);
      setOperationAction(ISD::FTRUNC,            RoundedTy,  Legal);
      setOperationAction(ISD::STRICT_FTRUNC,     RoundedTy,  Legal);
      setOperationAction(ISD::FRINT,             RoundedTy,  Legal);
      setOperationAction(ISD::STRICT_FRINT,      RoundedTy,  Legal);
      setOperationAction(ISD::FNEARBYINT,        RoundedTy,  Legal);
      setOperationAction(ISD::STRICT_FNEARBYINT, RoundedTy,  Legal);
    }
    setOperationAction(ISD::SMAX,               MVT::v16i8, Legal);
@@ -1145,10 +1150,15 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
    for (auto VT : { MVT::v8f32, MVT::v4f64 }) {
      setOperationAction(ISD::FFLOOR,            VT, Legal);
      setOperationAction(ISD::STRICT_FFLOOR,     VT, Legal);
      setOperationAction(ISD::FCEIL,             VT, Legal);
      setOperationAction(ISD::STRICT_FCEIL,      VT, Legal);
      setOperationAction(ISD::FTRUNC,            VT, Legal);
      setOperationAction(ISD::STRICT_FTRUNC,     VT, Legal);
      setOperationAction(ISD::FRINT,             VT, Legal);
      setOperationAction(ISD::STRICT_FRINT,      VT, Legal);
      setOperationAction(ISD::FNEARBYINT,        VT, Legal);
      setOperationAction(ISD::STRICT_FNEARBYINT, VT, Legal);
      setOperationAction(ISD::FNEG,              VT, Custom);
      setOperationAction(ISD::FABS,              VT, Custom);
      setOperationAction(ISD::FCOPYSIGN,         VT, Custom);
@@ -1504,10 +1514,15 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
    for (auto VT : { MVT::v16f32, MVT::v8f64 }) {
      setOperationAction(ISD::FFLOOR,            VT, Legal);
      setOperationAction(ISD::STRICT_FFLOOR,     VT, Legal);
      setOperationAction(ISD::FCEIL,             VT, Legal);
      setOperationAction(ISD::STRICT_FCEIL,      VT, Legal);
      setOperationAction(ISD::FTRUNC,            VT, Legal);
      setOperationAction(ISD::STRICT_FTRUNC,     VT, Legal);
      setOperationAction(ISD::FRINT,             VT, Legal);
      setOperationAction(ISD::STRICT_FRINT,      VT, Legal);
      setOperationAction(ISD::FNEARBYINT,        VT, Legal);
      setOperationAction(ISD::STRICT_FNEARBYINT, VT, Legal);
      setOperationAction(ISD::SELECT,           VT, Custom);
    }
@@ -29650,6 +29665,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
  case X86ISD::VPMADD52H:          return "X86ISD::VPMADD52H";
  case X86ISD::VPMADD52L:          return "X86ISD::VPMADD52L";
  case X86ISD::VRNDSCALE:          return "X86ISD::VRNDSCALE";
  case X86ISD::STRICT_VRNDSCALE:   return "X86ISD::STRICT_VRNDSCALE";
  case X86ISD::VRNDSCALE_SAE:      return "X86ISD::VRNDSCALE_SAE";
  case X86ISD::VRNDSCALES:         return "X86ISD::VRNDSCALES";
  case X86ISD::VRNDSCALES_SAE:     return "X86ISD::VRNDSCALES_SAE";
+1 −1
Original line number Diff line number Diff line
@@ -424,7 +424,7 @@ namespace llvm {
      // RndScale - Round FP Values To Include A Given Number Of Fraction Bits.
      // Also used by the legacy (V)ROUND intrinsics where we mask out the
      // scaling part of the immediate.
      VRNDSCALE, VRNDSCALE_SAE, VRNDSCALES, VRNDSCALES_SAE,
      VRNDSCALE, VRNDSCALE_SAE, VRNDSCALES, VRNDSCALES_SAE, STRICT_VRNDSCALE,
      // Tests Types Of a FP Values for packed types.
      VFPCLASS,
      // Tests Types Of a FP Values for scalar types.
+3 −3
Original line number Diff line number Diff line
@@ -9019,13 +9019,13 @@ multiclass avx512_rndscale_scalar<bits<8> opc, string OpcodeStr,
  }
  let Predicates = [HasAVX512] in {
    def : Pat<(X86VRndScale _.FRC:$src1, timm:$src2),
    def : Pat<(X86any_VRndScale _.FRC:$src1, timm:$src2),
              (_.EltVT (!cast<Instruction>(NAME##r) (_.EltVT (IMPLICIT_DEF)),
               _.FRC:$src1, timm:$src2))>;
  }
  let Predicates = [HasAVX512, OptForSize] in {
    def : Pat<(X86VRndScale (_.ScalarLdFrag addr:$src1), timm:$src2),
    def : Pat<(X86any_VRndScale (_.ScalarLdFrag addr:$src1), timm:$src2),
              (_.EltVT (!cast<Instruction>(NAME##m) (_.EltVT (IMPLICIT_DEF)),
               addr:$src1, timm:$src2))>;
  }
@@ -10290,7 +10290,7 @@ defm VREDUCE : avx512_common_unary_fp_sae_packed_imm_all<"vreduce", 0x56, 0x56
                              X86VReduce, X86VReduceSAE, SchedWriteFRnd, HasDQI>,
                              AVX512AIi8Base, EVEX;
defm VRNDSCALE : avx512_common_unary_fp_sae_packed_imm_all<"vrndscale", 0x08, 0x09,
                              X86VRndScale, X86VRndScaleSAE, SchedWriteFRnd, HasAVX512>,
                              X86any_VRndScale, X86VRndScaleSAE, SchedWriteFRnd, HasAVX512>,
                              AVX512AIi8Base, EVEX;
defm VGETMANT : avx512_common_unary_fp_sae_packed_imm_all<"vgetmant", 0x26, 0x26,
                              X86VGetMant, X86VGetMantSAE, SchedWriteFRnd, HasAVX512>,
+6 −0
Original line number Diff line number Diff line
@@ -466,6 +466,12 @@ def X86VRangeSAE : SDNode<"X86ISD::VRANGE_SAE", SDTFPBinOpImm>;
def X86VReduce     : SDNode<"X86ISD::VREDUCE",       SDTFPUnaryOpImm>;
def X86VReduceSAE  : SDNode<"X86ISD::VREDUCE_SAE",   SDTFPUnaryOpImm>;
def X86VRndScale   : SDNode<"X86ISD::VRNDSCALE",     SDTFPUnaryOpImm>;
def X86strict_VRndScale : SDNode<"X86ISD::STRICT_VRNDSCALE", SDTFPUnaryOpImm,
                                  [SDNPHasChain]>;
def X86any_VRndScale    : PatFrags<(ops node:$src1, node:$src2),
                                    [(X86strict_VRndScale node:$src1, node:$src2),
                                    (X86VRndScale node:$src1, node:$src2)]>;

def X86VRndScaleSAE: SDNode<"X86ISD::VRNDSCALE_SAE", SDTFPUnaryOpImm>;
def X86VGetMant    : SDNode<"X86ISD::VGETMANT",      SDTFPUnaryOpImm>;
def X86VGetMantSAE : SDNode<"X86ISD::VGETMANT_SAE",  SDTFPUnaryOpImm>;
Loading