Commit 216f546c authored by Paul Walker
Browse files

[SVE] Refactor lowering for fixed length MGATHER/MSCATTER.

Lower fixed-length MGATHER/MSCATTER operations to scalable vector
equivalents, which are then lowered to SVE-specific nodes. This
two-stage process is in preparation for making scalable vector
MGATHER/MSCATTER operations legal.

Differential Revision: https://reviews.llvm.org/D125192
parent 86fd1c13
Loading
Loading
Loading
Loading
+98 −69
Original line number Diff line number Diff line
@@ -4696,6 +4696,56 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
                               MGT->getMemOperand(), IndexType, ExtType);
  }
  // Lower fixed length gather to a scalable equivalent.
  if (VT.isFixedLengthVector()) {
    assert(Subtarget->useSVEForFixedLengthVectors() &&
           "Cannot lower when not using SVE for fixed vectors!");
    // NOTE: Handle floating-point as if integer then bitcast the result.
    EVT DataVT = VT.changeVectorElementTypeToInteger();
    MemVT = MemVT.changeVectorElementTypeToInteger();
    // Find the smallest integer fixed length vector we can use for the gather.
    EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
    if (DataVT.getVectorElementType() == MVT::i64 ||
        Index.getValueType().getVectorElementType() == MVT::i64 ||
        Mask.getValueType().getVectorElementType() == MVT::i64)
      PromotedVT = VT.changeVectorElementType(MVT::i64);
    // Promote vector operands except for passthrough, which we know is either
    // undef or zero, and thus best constructed directly.
    unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
    Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
    Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
    // A promoted result type forces the need for an extending load.
    if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD)
      ExtType = ISD::EXTLOAD;
    EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
    // Convert fixed length vector operands to scalable.
    MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
    Index = convertToScalableVector(DAG, ContainerVT, Index);
    Mask = convertFixedMaskToScalableVector(Mask, DAG);
    PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT)
                                   : DAG.getConstant(0, DL, ContainerVT);
    // Emit equivalent scalable vector gather.
    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
    SDValue Load =
        DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL,
                            Ops, MGT->getMemOperand(), IndexType, ExtType);
    // Extract fixed length data then convert to the required result type.
    SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load);
    Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result);
    if (VT.isFloatingPoint())
      Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
    return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
  }
  bool IdxNeedsExtend =
      getGatherScatterIndexIsExtended(Index) ||
      Index.getSimpleValueType().getVectorElementType() == MVT::i32;
@@ -4703,26 +4753,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
  EVT IndexVT = Index.getSimpleValueType();
  SDValue InputVT = DAG.getValueType(MemVT);
  bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
  if (IsFixedLength) {
    assert(Subtarget->useSVEForFixedLengthVectors() &&
           "Cannot lower when not using SVE for fixed vectors");
    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
    } else {
      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
      IndexVT = MemVT.changeTypeToInteger();
    }
    InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
    Mask = DAG.getNode(
        ISD::SIGN_EXTEND, DL,
        VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
  }
  // Handle FP data by using an integer gather and casting the result.
  if (VT.isFloatingPoint() && !IsFixedLength)
  if (VT.isFloatingPoint())
    InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
  SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
@@ -4737,25 +4769,11 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
  if (ExtType == ISD::SEXTLOAD)
    Opcode = getSignExtendedGatherOpcode(Opcode);
  if (IsFixedLength) {
    if (Index.getSimpleValueType().isFixedLengthVector())
      Index = convertToScalableVector(DAG, IndexVT, Index);
    if (BasePtr.getSimpleValueType().isFixedLengthVector())
      BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
    Mask = convertFixedMaskToScalableVector(Mask, DAG);
  }
  SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
  SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
  Chain = Result.getValue(1);
  if (IsFixedLength) {
    Result = convertFromScalableVector(
        DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()),
        Result);
    Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result);
    Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
  } else if (VT.isFloatingPoint())
  if (VT.isFloatingPoint())
    Result = getSVESafeBitCast(VT, Result, DAG);
  return DAG.getMergeValues({Result, Chain}, DL);
@@ -4775,6 +4793,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
  EVT VT = StoreVal.getValueType();
  EVT MemVT = MSC->getMemoryVT();
  ISD::MemIndexType IndexType = MSC->getIndexType();
  bool Truncating = MSC->isTruncatingStore();
  bool IsScaled = MSC->isIndexScaled();
  bool IsSigned = MSC->isIndexSigned();
@@ -4791,42 +4810,60 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
                                MSC->getMemOperand(), IndexType,
                                MSC->isTruncatingStore());
                                MSC->getMemOperand(), IndexType, Truncating);
  }
  // Lower fixed length scatter to a scalable equivalent.
  if (VT.isFixedLengthVector()) {
    assert(Subtarget->useSVEForFixedLengthVectors() &&
           "Cannot lower when not using SVE for fixed vectors!");
    // Once bitcast we treat floating-point scatters as if integer.
    if (VT.isFloatingPoint()) {
      VT = VT.changeVectorElementTypeToInteger();
      MemVT = MemVT.changeVectorElementTypeToInteger();
      StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal);
    }
    // Find the smallest integer fixed length vector we can use for the scatter.
    EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
    if (VT.getVectorElementType() == MVT::i64 ||
        Index.getValueType().getVectorElementType() == MVT::i64 ||
        Mask.getValueType().getVectorElementType() == MVT::i64)
      PromotedVT = VT.changeVectorElementType(MVT::i64);
    // Promote vector operands.
    unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
    Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
    Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
    StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal);
    // A promoted value type forces the need for a truncating store.
    if (PromotedVT != VT)
      Truncating = true;
    EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
    // Convert fixed length vector operands to scalable.
    MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
    Index = convertToScalableVector(DAG, ContainerVT, Index);
    Mask = convertFixedMaskToScalableVector(Mask, DAG);
    StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal);
    // Emit equivalent scalable vector scatter.
    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
                                MSC->getMemOperand(), IndexType, Truncating);
  }
  bool NeedsExtend =
      getGatherScatterIndexIsExtended(Index) ||
      Index.getSimpleValueType().getVectorElementType() == MVT::i32;
  EVT IndexVT = Index.getSimpleValueType();
  SDVTList VTs = DAG.getVTList(MVT::Other);
  SDValue InputVT = DAG.getValueType(MemVT);
  bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
  if (IsFixedLength) {
    assert(Subtarget->useSVEForFixedLengthVectors() &&
           "Cannot lower when not using SVE for fixed vectors");
    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
    } else {
      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
      IndexVT = MemVT.changeTypeToInteger();
    }
    InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
    StoreVal =
        DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal);
    StoreVal = DAG.getNode(
        ISD::ANY_EXTEND, DL,
        VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal);
    StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal);
    Mask = DAG.getNode(
        ISD::SIGN_EXTEND, DL,
        VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
  } else if (VT.isFloatingPoint()) {
  if (VT.isFloatingPoint()) {
    // Handle FP data by casting the data so an integer scatter can be used.
    EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
    StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
@@ -4840,14 +4877,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                              /*isGather=*/false, DAG);
  if (IsFixedLength) {
    if (Index.getSimpleValueType().isFixedLengthVector())
      Index = convertToScalableVector(DAG, IndexVT, Index);
    if (BasePtr.getSimpleValueType().isFixedLengthVector())
      BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
    Mask = convertFixedMaskToScalableVector(Mask, DAG);
  }
  SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
  return DAG.getNode(Opcode, DL, VTs, Ops);
}