Commit 46b9f14d authored by Andrzej Warzynski

[AArch64][SVE] Add intrinsics for non-temporal scatters/gathers

Summary:
This patch adds the following intrinsics for non-temporal gather loads
and scatter stores:
  * aarch64_sve_ldnt1_gather_index
  * aarch64_sve_stnt1_scatter_index
These intrinsics implement the "scalar + vector of indices" addressing
mode.

Unlike regular and first-faulting gathers/scatters, there is no
non-temporal instruction that takes indices and scales them itself.
Instead, the indices for non-temporal gathers/scatters are scaled
before the intrinsics are lowered to `ldnt1`/`stnt1` instructions.
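
To see why scaling works, note that an index vector becomes an
equivalent byte-offset vector once each element is shifted left by
log2 of the element size in bytes. A scalar model of this equivalence
(an illustrative sketch only; the loop and names are not from this
patch):

  #include <cstdint>

  // Gathering via indices is the same as gathering via byte offsets
  // pre-scaled by the element size (8 bytes for int64_t).
  void gather_i64(const int64_t *base, const int64_t *indices,
                  int64_t *out, int n) {
    for (int i = 0; i < n; ++i) {
      int64_t byte_offset = indices[i] << 3; // log2(sizeof(int64_t)) == 3
      const char *addr = reinterpret_cast<const char *>(base) + byte_offset;
      out[i] = *reinterpret_cast<const int64_t *>(addr); // == base[indices[i]]
    }
  }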

The new ISD nodes, GLDNT1_INDEX and SSTNT1_INDEX, are only used as
placeholders so that we can easily identify the cases implemented in
this patch in performGatherLoadCombine and performScatterStoreCombine.
Once encountered, they are replaced with:
  * GLDNT1_INDEX -> SPLAT_VECTOR + SHL + GLDNT1
  * SSTNT1_INDEX -> SPLAT_VECTOR + SHL + SSTNT1
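
In SelectionDAG terms the scaling is a splat of the shift amount
followed by a vector shift; the following is condensed from the
getScaledOffsetForBitWidth helper added in the diff below (a sketch,
not a verbatim excerpt):

  // Offset is an nxv2i64 vector of indices; BitWidth is the element
  // size of the loaded/stored data in bits.
  SDValue Shift = DAG.getConstant(Log2_32(BitWidth / 8), DL, MVT::i64);
  SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Shift);
  SDValue ByteOffsets =
      DAG.getNode(ISD::SHL, DL, MVT::nxv2i64, Offset, Splat);
  // The node then carries on as a plain GLDNT1/SSTNT1 with ByteOffsets.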

The patterns for lowering ISD::SHL for scalable vectors (required by
this patch) were missing, so these are added too.

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D75601
parent a66dc755
+7 −0
@@ -1782,6 +1782,9 @@ def int_aarch64_sve_ldff1_gather_scalar_offset : AdvSIMD_GatherLoad_VS_Intrinsic
// 64 bit unscaled offsets
def int_aarch64_sve_ldnt1_gather : AdvSIMD_GatherLoad_SV_64b_Offsets_Intrinsic;

// 64 bit indices
def int_aarch64_sve_ldnt1_gather_index : AdvSIMD_GatherLoad_SV_64b_Offsets_Intrinsic;

// 32 bit unscaled offsets, zero (zxtw) extended to 64 bits
def int_aarch64_sve_ldnt1_gather_uxtw : AdvSIMD_GatherLoad_SV_32b_Offsets_Intrinsic;

@@ -1829,6 +1832,10 @@ def int_aarch64_sve_st1_scatter_scalar_offset : AdvSIMD_ScatterStore_VS_Intrinsi
// 64 bit unscaled offsets
def int_aarch64_sve_stnt1_scatter : AdvSIMD_ScatterStore_SV_64b_Offsets_Intrinsic;

// 64 bit indices
def int_aarch64_sve_stnt1_scatter_index
    : AdvSIMD_ScatterStore_SV_64b_Offsets_Intrinsic;

// 32 bit unscaled offsets, zero (zxtw) extended to 64 bits
def int_aarch64_sve_stnt1_scatter_uxtw : AdvSIMD_ScatterStore_SV_32b_Offsets_Intrinsic;

+2 −1
@@ -5262,7 +5262,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
    // amounts.  This catches things like trying to shift an i1024 value by an
    // i8, which is easy to fall into in generic code that uses
    // TLI.getShiftAmount().
-    assert(N2.getValueSizeInBits() >= Log2_32_Ceil(N1.getValueSizeInBits()) &&
+    assert(N2.getValueType().getScalarSizeInBits().getFixedSize() >=
+               Log2_32_Ceil(VT.getScalarSizeInBits().getFixedSize()) &&
            "Invalid use of small shift amount with oversized value!");

    // Always fold shifts of i1 values so the code generator doesn't need to
+25 −0
@@ -190,6 +190,11 @@ public:
    return SelectSVELogicalImm(N, VT, Imm);
  }

  template <unsigned Low, unsigned High>
  bool SelectSVEShiftImm64(SDValue N, SDValue &Imm) {
    return SelectSVEShiftImm64(N, Low, High, Imm);
  }

  // Returns a suitable CNT/INC/DEC/RDVL multiplier to calculate VSCALE*N.
  template<signed Min, signed Max, signed Scale, bool Shift>
  bool SelectCntImm(SDValue N, SDValue &Imm) {
@@ -307,6 +312,8 @@ private:
  bool SelectSVELogicalImm(SDValue N, MVT VT, SDValue &Imm);

  bool SelectSVESignedArithImm(SDValue N, SDValue &Imm);
  bool SelectSVEShiftImm64(SDValue N, uint64_t Low, uint64_t High,
                           SDValue &Imm);

  bool SelectSVEArithImm(SDValue N, SDValue &Imm);
  bool SelectSVERegRegAddrMode(SDValue N, unsigned Scale, SDValue &Base,
@@ -3072,6 +3079,24 @@ bool AArch64DAGToDAGISel::SelectSVELogicalImm(SDValue N, MVT VT, SDValue &Imm) {
  return false;
}

// This method is only needed to "cast" i64s into i32s when the value is a
// valid shift amount which has been splatted into a vector with i64 elements.
// Every other type is fine in TableGen.
bool AArch64DAGToDAGISel::SelectSVEShiftImm64(SDValue N, uint64_t Low,
                                              uint64_t High, SDValue &Imm) {
  if (auto *CN = dyn_cast<ConstantSDNode>(N)) {
    uint64_t ImmVal = CN->getZExtValue();
    SDLoc DL(N);

    if (ImmVal >= Low && ImmVal <= High) {
      Imm = CurDAG->getTargetConstant(ImmVal, DL, MVT::i32);
      return true;
    }
  }

  return false;
}

bool AArch64DAGToDAGISel::trySelectStackSlotTagP(SDNode *N) {
  // tagp(FrameIndex, IRGstack, tag_offset):
  // since the offset between FrameIndex and IRGstack is a compile-time
+37 −0
@@ -1440,6 +1440,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
  case AArch64ISD::GLDFF1S_IMM:       return "AArch64ISD::GLDFF1S_IMM";
  case AArch64ISD::GLDNT1:            return "AArch64ISD::GLDNT1";
  case AArch64ISD::GLDNT1_INDEX:      return "AArch64ISD::GLDNT1_INDEX";
  case AArch64ISD::GLDNT1S:           return "AArch64ISD::GLDNT1S";
  case AArch64ISD::SST1:              return "AArch64ISD::SST1";
@@ -1451,6 +1452,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
  case AArch64ISD::SST1_IMM:          return "AArch64ISD::SST1_IMM";
  case AArch64ISD::SSTNT1:            return "AArch64ISD::SSTNT1";
  case AArch64ISD::SSTNT1_INDEX:      return "AArch64ISD::SSTNT1_INDEX";
  case AArch64ISD::LDP:               return "AArch64ISD::LDP";
  case AArch64ISD::STP:               return "AArch64ISD::STP";
@@ -12628,6 +12630,19 @@ static SDValue performGlobalAddressCombine(SDNode *N, SelectionDAG &DAG,
                     DAG.getConstant(MinOffset, DL, MVT::i64));
}
// Turns the vector of indices into a vector of byte offsets by scaling Offset
// by (BitWidth / 8).
static SDValue getScaledOffsetForBitWidth(SelectionDAG &DAG, SDValue Offset,
                                          SDLoc DL, unsigned BitWidth) {
  assert(Offset.getValueType().isScalableVector() &&
         "This method is only for scalable vectors of offsets");
  SDValue Shift = DAG.getConstant(Log2_32(BitWidth / 8), DL, MVT::i64);
  SDValue SplatShift = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Shift);
  return DAG.getNode(ISD::SHL, DL, MVT::nxv2i64, Offset, SplatShift);
}
static SDValue performScatterStoreCombine(SDNode *N, SelectionDAG &DAG,
                                          unsigned Opcode,
                                          bool OnlyPackedOffsets = true) {
@@ -12655,6 +12670,15 @@ static SDValue performScatterStoreCombine(SDNode *N, SelectionDAG &DAG,
  // vector of offsets  (that fits into one register)
  SDValue Offset = N->getOperand(5);
  // For "scalar + vector of indices", just scale the indices. This only
  // applies to non-temporal scatters because there's no instruction that takes
  // indicies.
  if (Opcode == AArch64ISD::SSTNT1_INDEX) {
    Offset =
        getScaledOffsetForBitWidth(DAG, Offset, DL, SrcElVT.getSizeInBits());
    Opcode = AArch64ISD::SSTNT1;
  }
  // In the case of non-temporal scatter stores there's only one SVE
  // instruction per data-size: "scalar + vector", i.e.
  //    * stnt1{b|h|w|d} { z0.s }, p0, [z0.s, x0]
@@ -12749,6 +12773,15 @@ static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG,
  // vector of offsets  (that fits into one register)
  SDValue Offset = N->getOperand(4);
  // For "scalar + vector of indices", just scale the indices. This only
  // applies to non-temporal gathers because there's no instruction that takes
  // indicies.
  if (Opcode == AArch64ISD::GLDNT1_INDEX) {
    Offset =
        getScaledOffsetForBitWidth(DAG, Offset, DL, RetElVT.getSizeInBits());
    Opcode = AArch64ISD::GLDNT1;
  }
  // In the case of non-temporal gather loads there's only one SVE instruction
  // per data-size: "scalar + vector", i.e.
  //    * ldnt1{b|h|w|d} { z0.s }, p0/z, [z0.s, x0]
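
(Worked example of the scaling performed above: for a gather of 64-bit
elements, RetElVT.getSizeInBits() == 64, so getScaledOffsetForBitWidth
shifts each index left by Log2_32(64 / 8) == 3, i.e. multiplies it by
the 8-byte element size. The shift amounts are 0, 1, 2 and 3 for 8-,
16-, 32- and 64-bit elements respectively.)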
@@ -13006,6 +13039,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
      return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1);
    case Intrinsic::aarch64_sve_ldnt1_gather:
      return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1);
    case Intrinsic::aarch64_sve_ldnt1_gather_index:
      return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_INDEX);
    case Intrinsic::aarch64_sve_ldnt1_gather_uxtw:
      return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1);
    case Intrinsic::aarch64_sve_ldnf1:
@@ -13020,6 +13055,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
      return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1);
    case Intrinsic::aarch64_sve_stnt1_scatter:
      return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1);
    case Intrinsic::aarch64_sve_stnt1_scatter_index:
      return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_INDEX);
    case Intrinsic::aarch64_sve_ld1_gather:
      return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1);
    case Intrinsic::aarch64_sve_ld1_gather_index:
+2 −0
@@ -263,6 +263,7 @@ enum NodeType : unsigned {

  // Non-temporal gather loads
  GLDNT1,
  GLDNT1_INDEX,
  GLDNT1S,

  // Scatter store
@@ -276,6 +277,7 @@ enum NodeType : unsigned {

  // Non-temporal scatter store
  SSTNT1,
  SSTNT1_INDEX,

  // Strict (exception-raising) floating point comparison
  STRICT_FCMP = ISD::FIRST_TARGET_STRICTFP_OPCODE,