Commit 0676c6d9 authored by Yeting Kuo's avatar Yeting Kuo
Browse files

[RISCV] Support vector type strict_fma.

Like D145900, this patch also supports fixed-length vector strict_fma nodes in
RISC-V by custom-lowering them to riscv_strict_vfmadd_vl nodes.
riscv_strict_vfmadd_vl is created to avoid some riscv_vfmadd_vl optimizations
being applied to the original strict_fma nodes. The patch also adds combine
patterns that fold negation operands into the riscv_strict_vfmadd_vl family of
nodes (vfmadd/vfmsub/vfnmadd/vfnmsub).

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D146939
parent 9382bbad
Loading
Loading
Loading
Loading
+31 −8
Original line number Diff line number Diff line
@@ -806,7 +806,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
      setOperationAction(ISD::STRICT_FP_EXTEND, VT, Custom);
      setOperationAction({ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
                          ISD::STRICT_FDIV, ISD::STRICT_FSQRT},
                          ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA},
                         VT, Legal);
    };
@@ -1024,7 +1024,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
        setOperationAction(ISD::STRICT_FP_EXTEND, VT, Custom);
        setOperationAction({ISD::STRICT_FADD, ISD::STRICT_FSUB,
                            ISD::STRICT_FMUL, ISD::STRICT_FDIV,
                            ISD::STRICT_FSQRT},
                            ISD::STRICT_FSQRT, ISD::STRICT_FMA},
                           VT, Custom);
      }
@@ -4506,6 +4506,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
                             /*HasMergeOp*/ true);
  case ISD::STRICT_FSQRT:
    return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FSQRT_VL);
  case ISD::STRICT_FMA:
    return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_VFMADD_VL);
  case ISD::MGATHER:
  case ISD::VP_GATHER:
    return lowerMaskedGather(Op, DAG);
@@ -10386,6 +10388,10 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
    case RISCVISD::VFNMSUB_VL: Opcode = RISCVISD::VFMADD_VL;  break;
    case RISCVISD::VFNMADD_VL: Opcode = RISCVISD::VFMSUB_VL;  break;
    case RISCVISD::VFMSUB_VL:  Opcode = RISCVISD::VFNMADD_VL; break;
    case RISCVISD::STRICT_VFMADD_VL:  Opcode = RISCVISD::STRICT_VFNMSUB_VL; break;
    case RISCVISD::STRICT_VFNMSUB_VL: Opcode = RISCVISD::STRICT_VFMADD_VL;  break;
    case RISCVISD::STRICT_VFNMADD_VL: Opcode = RISCVISD::STRICT_VFMSUB_VL;  break;
    case RISCVISD::STRICT_VFMSUB_VL:  Opcode = RISCVISD::STRICT_VFNMADD_VL; break;
    }
    // clang-format on
  }
@@ -10399,6 +10405,10 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
    case RISCVISD::VFMSUB_VL:  Opcode = RISCVISD::VFMADD_VL;  break;
    case RISCVISD::VFNMADD_VL: Opcode = RISCVISD::VFNMSUB_VL; break;
    case RISCVISD::VFNMSUB_VL: Opcode = RISCVISD::VFNMADD_VL; break;
    case RISCVISD::STRICT_VFMADD_VL:  Opcode = RISCVISD::STRICT_VFMSUB_VL;  break;
    case RISCVISD::STRICT_VFMSUB_VL:  Opcode = RISCVISD::STRICT_VFMADD_VL;  break;
    case RISCVISD::STRICT_VFNMADD_VL: Opcode = RISCVISD::STRICT_VFNMSUB_VL; break;
    case RISCVISD::STRICT_VFNMSUB_VL: Opcode = RISCVISD::STRICT_VFNMADD_VL; break;
    }
    // clang-format on
  }
@@ -11146,13 +11156,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
  case RISCVISD::VFMADD_VL:
  case RISCVISD::VFNMADD_VL:
  case RISCVISD::VFMSUB_VL:
  case RISCVISD::VFNMSUB_VL: {
  case RISCVISD::VFNMSUB_VL:
  case RISCVISD::STRICT_VFMADD_VL:
  case RISCVISD::STRICT_VFNMADD_VL:
  case RISCVISD::STRICT_VFMSUB_VL:
  case RISCVISD::STRICT_VFNMSUB_VL: {
    // Fold FNEG_VL into FMA opcodes.
    SDValue A = N->getOperand(0);
    SDValue B = N->getOperand(1);
    SDValue C = N->getOperand(2);
    SDValue Mask = N->getOperand(3);
    SDValue VL = N->getOperand(4);
    // The first operand of strict-fp is chain.
    unsigned Offset = N->isTargetStrictFPOpcode();
    SDValue A = N->getOperand(0 + Offset);
    SDValue B = N->getOperand(1 + Offset);
    SDValue C = N->getOperand(2 + Offset);
    SDValue Mask = N->getOperand(3 + Offset);
    SDValue VL = N->getOperand(4 + Offset);
    auto invertIfNegative = [&Mask, &VL](SDValue &V) {
      if (V.getOpcode() == RISCVISD::FNEG_VL && V.getOperand(1) == Mask &&
@@ -11174,6 +11190,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
      return SDValue();
    unsigned NewOpcode = negateFMAOpcode(N->getOpcode(), NegA != NegB, NegC);
    if (Offset > 0)
      return DAG.getNode(NewOpcode, SDLoc(N), N->getVTList(),
                         {N->getOperand(0), A, B, C, Mask, VL});
    return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), A, B, C, Mask,
                       VL);
  }
@@ -14102,6 +14121,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
  NODE_NAME_CASE(STRICT_FMUL_VL)
  NODE_NAME_CASE(STRICT_FDIV_VL)
  NODE_NAME_CASE(STRICT_FSQRT_VL)
  NODE_NAME_CASE(STRICT_VFMADD_VL)
  NODE_NAME_CASE(STRICT_VFNMADD_VL)
  NODE_NAME_CASE(STRICT_VFMSUB_VL)
  NODE_NAME_CASE(STRICT_VFNMSUB_VL)
  NODE_NAME_CASE(STRICT_FP_EXTEND_VL)
  NODE_NAME_CASE(VWMUL_VL)
  NODE_NAME_CASE(VWMULU_VL)
+4 −0
Original line number Diff line number Diff line
@@ -335,6 +335,10 @@ enum NodeType : unsigned {
  STRICT_FMUL_VL,
  STRICT_FDIV_VL,
  STRICT_FSQRT_VL,
  STRICT_VFMADD_VL,
  STRICT_VFNMADD_VL,
  STRICT_VFMSUB_VL,
  STRICT_VFNMSUB_VL,
  STRICT_FP_EXTEND_VL,

  // WARNING: Do not add anything in the end unless you want the node to
+20 −20
Original line number Diff line number Diff line
@@ -939,22 +939,22 @@ foreach fvti = AllFloatVectors in {
  // NOTE: We choose VFMADD because it has the most commuting freedom. So it
  // works best with how TwoAddressInstructionPass tries commuting.
  defvar suffix = fvti.LMul.MX;
  def : Pat<(fvti.Vector (fma fvti.RegClass:$rs1, fvti.RegClass:$rd,
  def : Pat<(fvti.Vector (any_fma fvti.RegClass:$rs1, fvti.RegClass:$rd,
                                  fvti.RegClass:$rs2)),
            (!cast<Instruction>("PseudoVFMADD_VV_"# suffix)
                 fvti.RegClass:$rd, fvti.RegClass:$rs1, fvti.RegClass:$rs2,
                 fvti.AVL, fvti.Log2SEW, TAIL_AGNOSTIC)>;
  def : Pat<(fvti.Vector (fma fvti.RegClass:$rs1, fvti.RegClass:$rd,
  def : Pat<(fvti.Vector (any_fma fvti.RegClass:$rs1, fvti.RegClass:$rd,
                                  (fneg fvti.RegClass:$rs2))),
            (!cast<Instruction>("PseudoVFMSUB_VV_"# suffix)
                 fvti.RegClass:$rd, fvti.RegClass:$rs1, fvti.RegClass:$rs2,
                 fvti.AVL, fvti.Log2SEW, TAIL_AGNOSTIC)>;
  def : Pat<(fvti.Vector (fma (fneg fvti.RegClass:$rs1), fvti.RegClass:$rd,
  def : Pat<(fvti.Vector (any_fma (fneg fvti.RegClass:$rs1), fvti.RegClass:$rd,
                                  (fneg fvti.RegClass:$rs2))),
            (!cast<Instruction>("PseudoVFNMADD_VV_"# suffix)
                 fvti.RegClass:$rd, fvti.RegClass:$rs1, fvti.RegClass:$rs2,
                 fvti.AVL, fvti.Log2SEW, TAIL_AGNOSTIC)>;
  def : Pat<(fvti.Vector (fma (fneg fvti.RegClass:$rs1), fvti.RegClass:$rd,
  def : Pat<(fvti.Vector (any_fma (fneg fvti.RegClass:$rs1), fvti.RegClass:$rd,
                                  fvti.RegClass:$rs2)),
            (!cast<Instruction>("PseudoVFNMSUB_VV_"# suffix)
                 fvti.RegClass:$rd, fvti.RegClass:$rs1, fvti.RegClass:$rs2,
@@ -962,35 +962,35 @@ foreach fvti = AllFloatVectors in {

  // The choice of VFMADD here is arbitrary, vfmadd.vf and vfmacc.vf are equally
  // commutable.
  def : Pat<(fvti.Vector (fma (SplatFPOp fvti.ScalarRegClass:$rs1),
  def : Pat<(fvti.Vector (any_fma (SplatFPOp fvti.ScalarRegClass:$rs1),
                                  fvti.RegClass:$rd, fvti.RegClass:$rs2)),
            (!cast<Instruction>("PseudoVFMADD_V" # fvti.ScalarSuffix # "_" # suffix)
                 fvti.RegClass:$rd, fvti.ScalarRegClass:$rs1, fvti.RegClass:$rs2,
                 fvti.AVL, fvti.Log2SEW, TAIL_AGNOSTIC)>;
  def : Pat<(fvti.Vector (fma (SplatFPOp fvti.ScalarRegClass:$rs1),
  def : Pat<(fvti.Vector (any_fma (SplatFPOp fvti.ScalarRegClass:$rs1),
                                  fvti.RegClass:$rd, (fneg fvti.RegClass:$rs2))),
            (!cast<Instruction>("PseudoVFMSUB_V" # fvti.ScalarSuffix # "_" # suffix)
                 fvti.RegClass:$rd, fvti.ScalarRegClass:$rs1, fvti.RegClass:$rs2,
                 fvti.AVL, fvti.Log2SEW, TAIL_AGNOSTIC)>;

  def : Pat<(fvti.Vector (fma (SplatFPOp fvti.ScalarRegClass:$rs1),
  def : Pat<(fvti.Vector (any_fma (SplatFPOp fvti.ScalarRegClass:$rs1),
                                  (fneg fvti.RegClass:$rd), (fneg fvti.RegClass:$rs2))),
            (!cast<Instruction>("PseudoVFNMADD_V" # fvti.ScalarSuffix # "_" # suffix)
                 fvti.RegClass:$rd, fvti.ScalarRegClass:$rs1, fvti.RegClass:$rs2,
                 fvti.AVL, fvti.Log2SEW, TAIL_AGNOSTIC)>;
  def : Pat<(fvti.Vector (fma (SplatFPOp fvti.ScalarRegClass:$rs1),
  def : Pat<(fvti.Vector (any_fma (SplatFPOp fvti.ScalarRegClass:$rs1),
                                  (fneg fvti.RegClass:$rd), fvti.RegClass:$rs2)),
            (!cast<Instruction>("PseudoVFNMSUB_V" # fvti.ScalarSuffix # "_" # suffix)
                 fvti.RegClass:$rd, fvti.ScalarRegClass:$rs1, fvti.RegClass:$rs2,
                 fvti.AVL, fvti.Log2SEW, TAIL_AGNOSTIC)>;

  // The splat might be negated.
  def : Pat<(fvti.Vector (fma (fneg (SplatFPOp fvti.ScalarRegClass:$rs1)),
  def : Pat<(fvti.Vector (any_fma (fneg (SplatFPOp fvti.ScalarRegClass:$rs1)),
                                  fvti.RegClass:$rd, (fneg fvti.RegClass:$rs2))),
            (!cast<Instruction>("PseudoVFNMADD_V" # fvti.ScalarSuffix # "_" # suffix)
                 fvti.RegClass:$rd, fvti.ScalarRegClass:$rs1, fvti.RegClass:$rs2,
                 fvti.AVL, fvti.Log2SEW, TAIL_AGNOSTIC)>;
  def : Pat<(fvti.Vector (fma (fneg (SplatFPOp fvti.ScalarRegClass:$rs1)),
  def : Pat<(fvti.Vector (any_fma (fneg (SplatFPOp fvti.ScalarRegClass:$rs1)),
                                  fvti.RegClass:$rd, fvti.RegClass:$rs2)),
            (!cast<Instruction>("PseudoVFNMSUB_V" # fvti.ScalarSuffix # "_" # suffix)
                 fvti.RegClass:$rd, fvti.ScalarRegClass:$rs1, fvti.RegClass:$rs2,
+23 −5
Original line number Diff line number Diff line
@@ -141,6 +141,24 @@ def riscv_vfnmadd_vl : SDNode<"RISCVISD::VFNMADD_VL", SDT_RISCVVecFMA_VL, [SDNPC
def riscv_vfmsub_vl : SDNode<"RISCVISD::VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
def riscv_vfnmsub_vl : SDNode<"RISCVISD::VFNMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;

def riscv_strict_vfmadd_vl : SDNode<"RISCVISD::STRICT_VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
def riscv_strict_vfnmadd_vl : SDNode<"RISCVISD::STRICT_VFNMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
def riscv_strict_vfmsub_vl : SDNode<"RISCVISD::STRICT_VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
def riscv_strict_vfnmsub_vl : SDNode<"RISCVISD::STRICT_VFNMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;

def any_riscv_vfmadd_vl : PatFrags<(ops node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl),
                        [(riscv_vfmadd_vl node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl),
                         (riscv_strict_vfmadd_vl node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl)]>;
def any_riscv_vfnmadd_vl : PatFrags<(ops node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl),
                        [(riscv_vfnmadd_vl node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl),
                         (riscv_strict_vfnmadd_vl node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl)]>;
def any_riscv_vfmsub_vl : PatFrags<(ops node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl),
                        [(riscv_vfmsub_vl node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl),
                         (riscv_strict_vfmsub_vl node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl)]>;
def any_riscv_vfnmsub_vl : PatFrags<(ops node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl),
                        [(riscv_vfnmsub_vl node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl),
                         (riscv_strict_vfnmsub_vl node:$rs1, node:$rs2, node:$rs3, node:$mask, node:$vl)]>;

def SDT_RISCVFPRoundOp_VL  : SDTypeProfile<1, 3, [
  SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<0, 1>, SDTCisSameNumEltsAs<0, 1>,
  SDTCVecEltisVT<2, i1>, SDTCisSameNumEltsAs<1, 2>, SDTCisVT<3, XLenVT>
@@ -1395,7 +1413,7 @@ multiclass VPatNarrowShiftSplat_WX_WI<SDNode op, string instruction_name> {
  }
}

multiclass VPatFPMulAddVL_VV_VF<SDNode vop, string instruction_name> {
multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name> {
  foreach vti = AllFloatVectors in {
  defvar suffix = vti.LMul.MX;
  def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
@@ -1783,10 +1801,10 @@ defm : VPatBinaryFPVL_R_VF_E<any_riscv_fdiv_vl, "PseudoVFRDIV">;
defm : VPatWidenBinaryFPVL_VV_VF<riscv_fmul_vl, riscv_fpextend_vl_oneuse, "PseudoVFWMUL">;

// 13.6 Vector Single-Width Floating-Point Fused Multiply-Add Instructions.
defm : VPatFPMulAddVL_VV_VF<riscv_vfmadd_vl,  "PseudoVFMADD">;
defm : VPatFPMulAddVL_VV_VF<riscv_vfmsub_vl,  "PseudoVFMSUB">;
defm : VPatFPMulAddVL_VV_VF<riscv_vfnmadd_vl, "PseudoVFNMADD">;
defm : VPatFPMulAddVL_VV_VF<riscv_vfnmsub_vl, "PseudoVFNMSUB">;
defm : VPatFPMulAddVL_VV_VF<any_riscv_vfmadd_vl,  "PseudoVFMADD">;
defm : VPatFPMulAddVL_VV_VF<any_riscv_vfmsub_vl,  "PseudoVFMSUB">;
defm : VPatFPMulAddVL_VV_VF<any_riscv_vfnmadd_vl, "PseudoVFNMADD">;
defm : VPatFPMulAddVL_VV_VF<any_riscv_vfnmsub_vl, "PseudoVFNMSUB">;
defm : VPatFPMulAccVL_VV_VF<riscv_vfmadd_vl_oneuse,  "PseudoVFMACC">;
defm : VPatFPMulAccVL_VV_VF<riscv_vfmsub_vl_oneuse,  "PseudoVFMSAC">;
defm : VPatFPMulAccVL_VV_VF<riscv_vfnmadd_vl_oneuse, "PseudoVFNMACC">;
+298 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading