Commit 05e7d8d6 authored by Matt Arsenault's avatar Matt Arsenault Committed by Matt Arsenault
Browse files

TTI: Add addrspace parameters to memcpy lowering functions

parent 7d382dcd
Loading
Loading
Loading
Loading
+15 −2
Original line number Diff line number Diff line
@@ -1069,6 +1069,7 @@ public:

  /// \returns The type to use in a loop expansion of a memcpy call.
  Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                  unsigned SrcAddrSpace, unsigned DestAddrSpace,
                                  unsigned SrcAlign, unsigned DestAlign) const;

  /// \param[out] OpsOut The operand types to copy RemainingBytes of memory.
@@ -1080,6 +1081,8 @@ public:
  void getMemcpyLoopResidualLoweringType(SmallVectorImpl<Type *> &OpsOut,
                                         LLVMContext &Context,
                                         unsigned RemainingBytes,
                                         unsigned SrcAddrSpace,
                                         unsigned DestAddrSpace,
                                         unsigned SrcAlign,
                                         unsigned DestAlign) const;

@@ -1382,11 +1385,15 @@ public:
  virtual Value *getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
                                                   Type *ExpectedType) = 0;
  virtual Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                          unsigned SrcAddrSpace,
                                          unsigned DestAddrSpace,
                                          unsigned SrcAlign,
                                          unsigned DestAlign) const = 0;
  virtual void getMemcpyLoopResidualLoweringType(
      SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
      unsigned RemainingBytes, unsigned SrcAlign, unsigned DestAlign) const = 0;
      unsigned RemainingBytes,
      unsigned SrcAddrSpace, unsigned DestAddrSpace,
      unsigned SrcAlign, unsigned DestAlign) const = 0;
  virtual bool areInlineCompatible(const Function *Caller,
                                   const Function *Callee) const = 0;
  virtual bool
@@ -1830,16 +1837,22 @@ public:
    return Impl.getOrCreateResultFromMemIntrinsic(Inst, ExpectedType);
  }
  Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                  unsigned SrcAddrSpace, unsigned DestAddrSpace,
                                  unsigned SrcAlign,
                                  unsigned DestAlign) const override {
    return Impl.getMemcpyLoopLoweringType(Context, Length, SrcAlign, DestAlign);
    return Impl.getMemcpyLoopLoweringType(Context, Length,
                                          SrcAddrSpace, DestAddrSpace,
                                          SrcAlign, DestAlign);
  }
  void getMemcpyLoopResidualLoweringType(SmallVectorImpl<Type *> &OpsOut,
                                         LLVMContext &Context,
                                         unsigned RemainingBytes,
                                         unsigned SrcAddrSpace,
                                         unsigned DestAddrSpace,
                                         unsigned SrcAlign,
                                         unsigned DestAlign) const override {
    Impl.getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes,
                                           SrcAddrSpace, DestAddrSpace,
                                           SrcAlign, DestAlign);
  }
  bool areInlineCompatible(const Function *Caller,
+3 −0
Original line number Diff line number Diff line
@@ -543,6 +543,7 @@ public:
  }

  Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                  unsigned SrcAddrSpace, unsigned DestAddrSpace,
                                  unsigned SrcAlign, unsigned DestAlign) const {
    return Type::getInt8Ty(Context);
  }
@@ -550,6 +551,8 @@ public:
  void getMemcpyLoopResidualLoweringType(SmallVectorImpl<Type *> &OpsOut,
                                         LLVMContext &Context,
                                         unsigned RemainingBytes,
                                         unsigned SrcAddrSpace,
                                         unsigned DestAddrSpace,
                                         unsigned SrcAlign,
                                         unsigned DestAlign) const {
    for (unsigned i = 0; i != RemainingBytes; ++i)
+9 −2
Original line number Diff line number Diff line
@@ -776,16 +776,23 @@ Value *TargetTransformInfo::getOrCreateResultFromMemIntrinsic(

Type *TargetTransformInfo::getMemcpyLoopLoweringType(LLVMContext &Context,
                                                     Value *Length,
                                                     unsigned SrcAddrSpace,
                                                     unsigned DestAddrSpace,
                                                     unsigned SrcAlign,
                                                     unsigned DestAlign) const {
  return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAlign,
  return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAddrSpace,
                                            DestAddrSpace, SrcAlign,
                                            DestAlign);
}

void TargetTransformInfo::getMemcpyLoopResidualLoweringType(
    SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
    unsigned RemainingBytes, unsigned SrcAlign, unsigned DestAlign) const {
    unsigned RemainingBytes,
    unsigned SrcAddrSpace,
    unsigned DestAddrSpace,
    unsigned SrcAlign, unsigned DestAlign) const {
  TTIImpl->getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes,
                                             SrcAddrSpace, DestAddrSpace,
                                             SrcAlign, DestAlign);
}

+8 −7
Original line number Diff line number Diff line
@@ -36,16 +36,16 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
  Function *ParentFunc = PreLoopBB->getParent();
  LLVMContext &Ctx = PreLoopBB->getContext();

  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();

  Type *TypeOfCopyLen = CopyLen->getType();
  Type *LoopOpType =
      TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAlign, DestAlign);
    TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAS, DstAS, SrcAlign, DestAlign);

  unsigned LoopOpSize = getLoopOperandSizeInBytes(LoopOpType);
  uint64_t LoopEndCount = CopyLen->getZExtValue() / LoopOpSize;

  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();

  if (LoopEndCount != 0) {
    // Split
    PostLoopBB = PreLoopBB->splitBasicBlock(InsertBefore, "memcpy-split");
@@ -99,6 +99,7 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,

    SmallVector<Type *, 5> RemainingOps;
    TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes,
                                          SrcAS, DstAS,
                                          SrcAlign, DestAlign);

    for (auto OpTy : RemainingOps) {
@@ -144,15 +145,15 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,

  Function *ParentFunc = PreLoopBB->getParent();
  LLVMContext &Ctx = PreLoopBB->getContext();
  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();

  Type *LoopOpType =
      TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAlign, DestAlign);
    TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAS, DstAS, SrcAlign, DestAlign);
  unsigned LoopOpSize = getLoopOperandSizeInBytes(LoopOpType);

  IRBuilder<> PLBuilder(PreLoopBB->getTerminator());

  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
  PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS);
  PointerType *DstOpType = PointerType::get(LoopOpType, DstAS);
  if (SrcAddr->getType() != SrcOpType) {