Commit 57d96ab5 authored by David Green's avatar David Green
Browse files

[ARM] Add some VCMP folding and canonicalisation

The VCMP instructions in MVE can accept a register or ZR, but only as
the right-hand operand. Most of the time this will already be correct
because the icmp will have been canonicalised that way already. There
are some cases in the lowering of float conditions that this will not
apply to though. This code should fix up those cases.

Differential Revision: https://reviews.llvm.org/D70822
parent 63aff5cd
Loading
Loading
Loading
Loading
+0 −19
Original line number Diff line number Diff line
@@ -2723,25 +2723,6 @@ static bool isSuitableForMask(MachineInstr *&MI, unsigned SrcReg,
  return false;
}

/// getSwappedCondition - assume the flags are set by MI(a,b), return
/// the condition code if we modify the instructions such that flags are
/// set by MI(b,a).
inline static ARMCC::CondCodes getSwappedCondition(ARMCC::CondCodes CC) {
  // Equality tests are symmetric; each ordering test swaps with its
  // mirror image. Any other condition has no swapped form, so report AL.
  switch (CC) {
  case ARMCC::EQ:
  case ARMCC::NE:
    return CC;
  case ARMCC::HS:
    return ARMCC::LS;
  case ARMCC::LS:
    return ARMCC::HS;
  case ARMCC::LO:
    return ARMCC::HI;
  case ARMCC::HI:
    return ARMCC::LO;
  case ARMCC::GE:
    return ARMCC::LE;
  case ARMCC::LE:
    return ARMCC::GE;
  case ARMCC::LT:
    return ARMCC::GT;
  case ARMCC::GT:
    return ARMCC::LT;
  default:
    return ARMCC::AL;
  }
}

/// getCmpToAddCondition - assume the flags are set by CMP(a,b), return
/// the condition code if we modify the instructions such that flags are
/// set by ADD(a,b,X).
+43 −8
Original line number Diff line number Diff line
@@ -8993,6 +8993,12 @@ static SDValue LowerPredicateStore(SDValue Op, SelectionDAG &DAG) {
      ST->getMemOperand());
}
/// Return true if N is a vector of all-zero lanes: either a build_vector
/// whose elements are all zero, or an ARMISD::VMOVIMM with immediate 0.
static bool isZeroVector(SDValue N) {
  if (ISD::isBuildVectorAllZeros(N.getNode()))
    return true;
  return N->getOpcode() == ARMISD::VMOVIMM && isNullConstant(N->getOperand(0));
}
static SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) {
  MaskedLoadSDNode *N = cast<MaskedLoadSDNode>(Op.getNode());
  MVT VT = Op.getSimpleValueType();
@@ -9000,13 +9006,7 @@ static SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) {
  SDValue PassThru = N->getPassThru();
  SDLoc dl(Op);
  auto IsZero = [](SDValue PassThru) {
    return (ISD::isBuildVectorAllZeros(PassThru.getNode()) ||
      (PassThru->getOpcode() == ARMISD::VMOVIMM &&
       isNullConstant(PassThru->getOperand(0))));
  };
  if (IsZero(PassThru))
  if (isZeroVector(PassThru))
    return Op;
  // MVE Masked loads use zero as the passthru value. Here we convert undef to
@@ -9020,7 +9020,7 @@ static SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) {
  SDValue Combo = NewLoad;
  if (!PassThru.isUndef() &&
      (PassThru.getOpcode() != ISD::BITCAST ||
       !IsZero(PassThru->getOperand(0))))
       !isZeroVector(PassThru->getOperand(0))))
    Combo = DAG.getNode(ISD::VSELECT, dl, VT, Mask, NewLoad, PassThru);
  return DAG.getMergeValues({Combo, NewLoad.getValue(1)}, dl);
}
@@ -12743,6 +12743,39 @@ PerformPREDICATE_CASTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  return SDValue();
}
/// PerformVCMPCombine - Fold and canonicalise MVE VCMP nodes:
///   - vcmp X, 0, cc      -> vcmpz X, cc          (matches the "vcmp ..., zr"
///                                                 instruction form)
///   - vcmp 0, X, cc      -> vcmpz X, swapped(cc)
///   - vcmp vdup(Y), X, cc -> vcmp X, vdup(Y), swapped(cc)
/// MVE VCMP instructions only accept a splatted scalar or ZR as the
/// right-hand operand, so move such operands to the right where needed.
static SDValue PerformVCMPCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const ARMSubtarget *Subtarget) {
  if (!Subtarget->hasMVEIntegerOps())
    return SDValue();

  SelectionDAG &DAG = DCI.DAG;
  EVT VT = N->getValueType(0);
  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  ARMCC::CondCodes Cond = (ARMCC::CondCodes)N->getConstantOperandVal(2);
  SDLoc dl(N);

  // vcmp X, 0, cc -> vcmpz X, cc
  if (isZeroVector(Op1))
    return DAG.getNode(ARMISD::VCMPZ, dl, VT, Op0, N->getOperand(2));

  // The remaining folds swap the operands, so they are only valid when the
  // swapped condition exists for this (int or float) comparison kind.
  unsigned SwappedCond = getSwappedCondition(Cond);
  if (isValidMVECond(SwappedCond, VT.isFloatingPoint())) {
    // vcmp 0, X, cc -> vcmpz X, reversed(cc)
    if (isZeroVector(Op0))
      return DAG.getNode(ARMISD::VCMPZ, dl, VT, Op1,
                         DAG.getConstant(SwappedCond, dl, MVT::i32));
    // vcmp vdup(Y), X, cc -> vcmp X, vdup(Y), reversed(cc)
    // (only if the other side is not also a vdup, to avoid ping-ponging)
    if (Op0->getOpcode() == ARMISD::VDUP && Op1->getOpcode() != ARMISD::VDUP)
      return DAG.getNode(ARMISD::VCMP, dl, VT, Op1, Op0,
                         DAG.getConstant(SwappedCond, dl, MVT::i32));
  }
  return SDValue();
}
/// PerformInsertEltCombine - Target-specific dag combine xforms for
/// ISD::INSERT_VECTOR_ELT.
static SDValue PerformInsertEltCombine(SDNode *N,
@@ -14423,6 +14456,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
    return PerformARMBUILD_VECTORCombine(N, DCI);
  case ARMISD::PREDICATE_CAST:
    return PerformPREDICATE_CASTCombine(N, DCI);
  case ARMISD::VCMP:
    return PerformVCMPCombine(N, DCI, Subtarget);
  case ARMISD::SMULWB: {
    unsigned BitWidth = N->getValueType(0).getSizeInBits();
    APInt DemandedMask = APInt::getLowBitsSet(BitWidth, 16);
+19 −0
Original line number Diff line number Diff line
@@ -64,6 +64,25 @@ inline static CondCodes getOppositeCondition(CondCodes CC) {
  case LE: return GT;
  }
}

/// getSwappedCondition - assume the flags are set by MI(a,b), return
/// the condition code if we modify the instructions such that flags are
/// set by MI(b,a).
inline static ARMCC::CondCodes getSwappedCondition(ARMCC::CondCodes CC) {
  switch (CC) {
  default: return ARMCC::AL;
  case ARMCC::EQ: return ARMCC::EQ;
  case ARMCC::NE: return ARMCC::NE;
  case ARMCC::HS: return ARMCC::LS;
  case ARMCC::LO: return ARMCC::HI;
  case ARMCC::HI: return ARMCC::LO;
  case ARMCC::LS: return ARMCC::HS;
  case ARMCC::GE: return ARMCC::LE;
  case ARMCC::LT: return ARMCC::GT;
  case ARMCC::GT: return ARMCC::LT;
  case ARMCC::LE: return ARMCC::GE;
  }
}
} // end namespace ARMCC

namespace ARMVCC {
+32 −48
Original line number Diff line number Diff line
@@ -107,9 +107,8 @@ define arm_aapcs_vfpcc <4 x float> @vcmp_one_v4f32(<4 x float> %src, <4 x float>
;
; CHECK-MVEFP-LABEL: vcmp_one_v4f32:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q3, q0
; CHECK-MVEFP-NEXT:    vcmpt.f32 le, q0, q3
; CHECK-MVEFP-NEXT:    vpt.f32 ge, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f32 le, q0, zr
; CHECK-MVEFP-NEXT:    vpnot
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
@@ -380,9 +379,8 @@ define arm_aapcs_vfpcc <4 x float> @vcmp_ueq_v4f32(<4 x float> %src, <4 x float>
;
; CHECK-MVEFP-LABEL: vcmp_ueq_v4f32:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q3, q0
; CHECK-MVEFP-NEXT:    vcmpt.f32 le, q0, q3
; CHECK-MVEFP-NEXT:    vpt.f32 ge, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f32 le, q0, zr
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
entry:
@@ -698,9 +696,8 @@ define arm_aapcs_vfpcc <4 x float> @vcmp_ord_v4f32(<4 x float> %src, <4 x float>
;
; CHECK-MVEFP-LABEL: vcmp_ord_v4f32:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q3, q0
; CHECK-MVEFP-NEXT:    vcmpt.f32 lt, q0, q3
; CHECK-MVEFP-NEXT:    vpt.f32 ge, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f32 lt, q0, zr
; CHECK-MVEFP-NEXT:    vpnot
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
@@ -753,9 +750,8 @@ define arm_aapcs_vfpcc <4 x float> @vcmp_uno_v4f32(<4 x float> %src, <4 x float>
;
; CHECK-MVEFP-LABEL: vcmp_uno_v4f32:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q3, q0
; CHECK-MVEFP-NEXT:    vcmpt.f32 lt, q0, q3
; CHECK-MVEFP-NEXT:    vpt.f32 ge, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f32 lt, q0, zr
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
entry:
@@ -1013,9 +1009,8 @@ define arm_aapcs_vfpcc <8 x half> @vcmp_one_v8f16(<8 x half> %src, <8 x half> %a
;
; CHECK-MVEFP-LABEL: vcmp_one_v8f16:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q3, q0
; CHECK-MVEFP-NEXT:    vcmpt.f16 le, q0, q3
; CHECK-MVEFP-NEXT:    vpt.f16 ge, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f16 le, q0, zr
; CHECK-MVEFP-NEXT:    vpnot
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
@@ -1632,9 +1627,8 @@ define arm_aapcs_vfpcc <8 x half> @vcmp_ueq_v8f16(<8 x half> %src, <8 x half> %a
;
; CHECK-MVEFP-LABEL: vcmp_ueq_v8f16:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q3, q0
; CHECK-MVEFP-NEXT:    vcmpt.f16 le, q0, q3
; CHECK-MVEFP-NEXT:    vpt.f16 ge, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f16 le, q0, zr
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
entry:
@@ -2358,9 +2352,8 @@ define arm_aapcs_vfpcc <8 x half> @vcmp_ord_v8f16(<8 x half> %src, <8 x half> %a
;
; CHECK-MVEFP-LABEL: vcmp_ord_v8f16:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q3, q0
; CHECK-MVEFP-NEXT:    vcmpt.f16 lt, q0, q3
; CHECK-MVEFP-NEXT:    vpt.f16 ge, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f16 lt, q0, zr
; CHECK-MVEFP-NEXT:    vpnot
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
@@ -2481,9 +2474,8 @@ define arm_aapcs_vfpcc <8 x half> @vcmp_uno_v8f16(<8 x half> %src, <8 x half> %a
;
; CHECK-MVEFP-LABEL: vcmp_uno_v8f16:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q3, q0
; CHECK-MVEFP-NEXT:    vcmpt.f16 lt, q0, q3
; CHECK-MVEFP-NEXT:    vpt.f16 ge, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f16 lt, q0, zr
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
entry:
@@ -2600,9 +2592,8 @@ define arm_aapcs_vfpcc <4 x float> @vcmp_r_one_v4f32(<4 x float> %src, <4 x floa
;
; CHECK-MVEFP-LABEL: vcmp_r_one_v4f32:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q0, q3
; CHECK-MVEFP-NEXT:    vcmpt.f32 le, q3, q0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f32 ge, q0, zr
; CHECK-MVEFP-NEXT:    vpnot
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
@@ -2873,9 +2864,8 @@ define arm_aapcs_vfpcc <4 x float> @vcmp_r_ueq_v4f32(<4 x float> %src, <4 x floa
;
; CHECK-MVEFP-LABEL: vcmp_r_ueq_v4f32:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q0, q3
; CHECK-MVEFP-NEXT:    vcmpt.f32 le, q3, q0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f32 ge, q0, zr
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
entry:
@@ -3191,9 +3181,8 @@ define arm_aapcs_vfpcc <4 x float> @vcmp_r_ord_v4f32(<4 x float> %src, <4 x floa
;
; CHECK-MVEFP-LABEL: vcmp_r_ord_v4f32:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q0, q3
; CHECK-MVEFP-NEXT:    vcmpt.f32 lt, q3, q0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f32 gt, q0, zr
; CHECK-MVEFP-NEXT:    vpnot
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
@@ -3246,9 +3235,8 @@ define arm_aapcs_vfpcc <4 x float> @vcmp_r_uno_v4f32(<4 x float> %src, <4 x floa
;
; CHECK-MVEFP-LABEL: vcmp_r_uno_v4f32:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q0, q3
; CHECK-MVEFP-NEXT:    vcmpt.f32 lt, q3, q0
; CHECK-MVEFP-NEXT:    vpt.f32 le, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f32 gt, q0, zr
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
entry:
@@ -3506,9 +3494,8 @@ define arm_aapcs_vfpcc <8 x half> @vcmp_r_one_v8f16(<8 x half> %src, <8 x half>
;
; CHECK-MVEFP-LABEL: vcmp_r_one_v8f16:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q0, q3
; CHECK-MVEFP-NEXT:    vcmpt.f16 le, q3, q0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f16 ge, q0, zr
; CHECK-MVEFP-NEXT:    vpnot
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
@@ -4125,9 +4112,8 @@ define arm_aapcs_vfpcc <8 x half> @vcmp_r_ueq_v8f16(<8 x half> %src, <8 x half>
;
; CHECK-MVEFP-LABEL: vcmp_r_ueq_v8f16:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q0, q3
; CHECK-MVEFP-NEXT:    vcmpt.f16 le, q3, q0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f16 ge, q0, zr
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
entry:
@@ -4851,9 +4837,8 @@ define arm_aapcs_vfpcc <8 x half> @vcmp_r_ord_v8f16(<8 x half> %src, <8 x half>
;
; CHECK-MVEFP-LABEL: vcmp_r_ord_v8f16:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q0, q3
; CHECK-MVEFP-NEXT:    vcmpt.f16 lt, q3, q0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f16 gt, q0, zr
; CHECK-MVEFP-NEXT:    vpnot
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
@@ -4974,9 +4959,8 @@ define arm_aapcs_vfpcc <8 x half> @vcmp_r_uno_v8f16(<8 x half> %src, <8 x half>
;
; CHECK-MVEFP-LABEL: vcmp_r_uno_v8f16:
; CHECK-MVEFP:       @ %bb.0: @ %entry
; CHECK-MVEFP-NEXT:    vmov.i32 q3, #0x0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q0, q3
; CHECK-MVEFP-NEXT:    vcmpt.f16 lt, q3, q0
; CHECK-MVEFP-NEXT:    vpt.f16 le, q0, zr
; CHECK-MVEFP-NEXT:    vcmpt.f16 gt, q0, zr
; CHECK-MVEFP-NEXT:    vpsel q0, q1, q2
; CHECK-MVEFP-NEXT:    bx lr
entry: