Commit 346f6b54 authored by Anna Welker's avatar Anna Welker
Browse files

[ARM][MVE] Enable masked gathers from vector of pointers

Adds a pass to the ARM backend that takes a v4i32
gather and transforms it into a call to MVE's
masked gather intrinsics.

Differential Revision: https://reviews.llvm.org/D71743
parent 55a51e1c
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -53,6 +53,7 @@ FunctionPass *createThumb2SizeReductionPass(
InstructionSelector *
createARMInstructionSelector(const ARMBaseTargetMachine &TM, const ARMSubtarget &STI,
                             const ARMRegisterBankInfo &RBI);
Pass *createMVEGatherScatterLoweringPass();

void LowerARMMachineInstrToMCInst(const MachineInstr *MI, MCInst &OutMI,
                                  ARMAsmPrinter &AP);
@@ -67,6 +68,7 @@ void initializeThumb2ITBlockPass(PassRegistry &);
void initializeMVEVPTBlockPass(PassRegistry &);
void initializeARMLowOverheadLoopsPass(PassRegistry &);
void initializeMVETailPredicationPass(PassRegistry &);
void initializeMVEGatherScatterLoweringPass(PassRegistry &);

} // end namespace llvm

+3 −0
Original line number Diff line number Diff line
@@ -98,6 +98,7 @@ extern "C" void LLVMInitializeARMTarget() {
  initializeMVEVPTBlockPass(Registry);
  initializeMVETailPredicationPass(Registry);
  initializeARMLowOverheadLoopsPass(Registry);
  initializeMVEGatherScatterLoweringPass(Registry);
}

static std::unique_ptr<TargetLoweringObjectFile> createTLOF(const Triple &TT) {
@@ -404,6 +405,8 @@ void ARMPassConfig::addIRPasses() {
          return ST.hasAnyDataBarrier() && !ST.isThumb1Only();
        }));

  addPass(createMVEGatherScatterLoweringPass());

  TargetPassConfig::addIRPasses();

  // Run the parallel DSP pass.
+24 −0
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/MC/SubtargetFeature.h"
#include "llvm/Support/Casting.h"
@@ -46,6 +47,8 @@ static cl::opt<bool> DisableLowOverheadLoops(

extern cl::opt<bool> DisableTailPredication;

extern cl::opt<bool> EnableMaskedGatherScatters;

bool ARMTTIImpl::areInlineCompatible(const Function *Caller,
                                     const Function *Callee) const {
  const TargetMachine &TM = getTLI()->getTargetMachine();
@@ -514,6 +517,27 @@ bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, MaybeAlign Alignment) {
         (EltWidth == 8);
}

bool ARMTTIImpl::isLegalMaskedGather(Type *Ty, MaybeAlign Alignment) {
  if (!EnableMaskedGatherScatters || !ST->hasMVEIntegerOps())
    return false;

  // This method is called in 2 places:
  //  - from the vectorizer with a scalar type, in which case we need to get
  //  this as good as we can with the limited info we have (and rely on the cost
  //  model for the rest).
  //  - from the masked intrinsic lowering pass with the actual vector type.
  // For MVE, we have a custom lowering pass that will already have custom
  // legalised any gathers that we can to MVE intrinsics, and want to expand all
  // the rest. The pass runs before the masked intrinsic lowering pass, so if we
  // are here, we know we want to expand.
  if (isa<VectorType>(Ty))
    return false;

  unsigned EltWidth = Ty->getScalarSizeInBits();
  return ((EltWidth == 32 && (!Alignment || Alignment >= 4)) ||
          (EltWidth == 16 && (!Alignment || Alignment >= 2)) || EltWidth == 8);
}

int ARMTTIImpl::getMemcpyCost(const Instruction *I) {
  const MemCpyInst *MI = dyn_cast<MemCpyInst>(I);
  assert(MI && "MemcpyInst expected");
+1 −1
Original line number Diff line number Diff line
@@ -159,7 +159,7 @@ public:
    return isLegalMaskedLoad(DataTy, Alignment);
  }

  bool isLegalMaskedGather(Type *Ty, MaybeAlign Alignment) { return false; }
  bool isLegalMaskedGather(Type *Ty, MaybeAlign Alignment);

  bool isLegalMaskedScatter(Type *Ty, MaybeAlign Alignment) { return false; }

+1 −0
Original line number Diff line number Diff line
@@ -51,6 +51,7 @@ add_llvm_target(ARMCodeGen
  ARMTargetObjectFile.cpp
  ARMTargetTransformInfo.cpp
  MLxExpansionPass.cpp
  MVEGatherScatterLowering.cpp
  MVETailPredication.cpp
  MVEVPTBlockPass.cpp
  Thumb1FrameLowering.cpp
Loading