Commit 6ce72f28 authored by Florian Hahn's avatar Florian Hahn
Browse files

Backport of rL326666 and rL326668 for PR36607 and PR36608.

[CallSiteSplitting] properly split musttail calls.

The original author was Fedor Indutny <fedor@indutny.com>.

`musttail` calls can't be naively split. The split blocks must
include not only the call instruction itself, but also (optional)
`bitcast` and `return` instructions that follow it.

Clone `bitcast` and `ret`, place them into the split blocks, and
remove the tail block when done.

Reviewers: junbuml, mcrosier, davidxl, davide, fhahn

Reviewed By: fhahn

Subscribers: JDevlieghere, llvm-commits

Differential Revision: https://reviews.llvm.org/D43729

llvm-svn: 329793
parent 34d881d7
Loading
Loading
Loading
Loading
+75 −2
Original line number Diff line number Diff line
@@ -201,6 +201,46 @@ static bool canSplitCallSite(CallSite CS) {
  return CallSiteBB->canSplitPredecessors();
}

/// Make a copy of `I` that carries the same name and place it immediately
/// before `Before`. If `V` is non-null, operand 0 of the copy is rewritten to
/// `V`; otherwise the copy keeps the operands of `I`. Returns the new copy.
static Instruction *cloneInstForMustTail(Instruction *I, Instruction *Before,
                                         Value *V) {
  Instruction *const Clone = I->clone();
  Clone->setName(I->getName());
  Clone->insertBefore(Before);

  // A null `V` means "leave the operands untouched" (used for `ret void`).
  if (V != nullptr)
    Clone->setOperand(0, V);

  return Clone;
}

/// Duplicate the return sequence that trails the original `musttail` call `CI`
/// and rewire the duplicate so that it consumes `NewCI` instead:
///
///   * an optional `bitcast` of the call result, followed by
///   * the `ret` of either that bitcast or the call result itself.
///
/// The duplicated instructions are inserted just ahead of `SplitBB`'s current
/// terminator. That terminator is left in place here; `splitCallSite` below is
/// responsible for cleaning it up afterwards.
static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI,
                               Instruction *NewCI) {
  // Walk the instructions after the call: an optional bitcast, then the ret.
  auto NextIt = std::next(CI->getIterator());
  auto *Cast = dyn_cast<BitCastInst>(&*NextIt);
  if (Cast)
    ++NextIt;

  auto *Ret = dyn_cast<ReturnInst>(&*NextIt);
  assert(Ret && "`musttail` call must be followed by `ret` instruction");

  TerminatorInst *Term = SplitBB->getTerminator();

  // Chain the clones: NewCI -> (bitcast clone) -> ret clone.
  Value *RetVal = NewCI;
  if (Cast)
    RetVal = cloneInstForMustTail(Cast, Term, RetVal);

  // A void function's `ret` has no operand to rewrite, so pass null for it.
  if (SplitBB->getParent()->getReturnType()->isVoidTy())
    RetVal = nullptr;
  cloneInstForMustTail(Ret, Term, RetVal);

  // FIXME: remove the terminator here, `DuplicateInstructionsInSplitBetween`
  // has a bug that prevents doing this now.
}

/// Return true if the CS is split into its new predecessors which are directly
/// hooked to each of its original predecessors pointed by PredBB1 and PredBB2.
/// CallInst1 and CallInst2 will be the new call-sites placed in the new
@@ -245,6 +285,7 @@ static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
                          Instruction *CallInst1, Instruction *CallInst2) {
  Instruction *Instr = CS.getInstruction();
  BasicBlock *TailBB = Instr->getParent();
  bool IsMustTailCall = CS.isMustTailCall();
  assert(Instr == (TailBB->getFirstNonPHIOrDbg()) && "Unexpected call-site");

  BasicBlock *SplitBlock1 =
@@ -276,9 +317,14 @@ static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
      ++ArgNo;
    }
  }
  // Clone and place bitcast and return instructions before `TI`
  if (IsMustTailCall) {
    copyMustTailReturn(SplitBlock1, CS.getInstruction(), CallInst1);
    copyMustTailReturn(SplitBlock2, CS.getInstruction(), CallInst2);
  }

  // Replace users of the original call with a PHI merging call-sites split.
  if (Instr->getNumUses()) {
  if (!IsMustTailCall && Instr->getNumUses()) {
    PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call",
                                  TailBB->getFirstNonPHI());
    PN->addIncoming(CallInst1, SplitBlock1);
@@ -290,8 +336,25 @@ static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
               << "\n");
  DEBUG(dbgs() << "    " << *CallInst2 << " in " << SplitBlock2->getName()
               << "\n");
  Instr->eraseFromParent();

  NumCallSiteSplit++;

  // FIXME: remove TI in `copyMustTailReturn`
  if (IsMustTailCall) {
    // Remove superfluous `br` terminators from the end of the Split blocks
    // NOTE: Removing terminator removes the SplitBlock from the TailBB's
    // predecessors. Therefore we must get complete list of Splits before
    // attempting removal.
    SmallVector<BasicBlock *, 2> Splits(predecessors((TailBB)));
    assert(Splits.size() == 2 && "Expected exactly 2 splits!");
    for (unsigned i = 0; i < Splits.size(); i++)
      Splits[i]->getTerminator()->eraseFromParent();

    // Erase the tail block once done with musttail patching
    TailBB->eraseFromParent();
    return;
  }
  Instr->eraseFromParent();
}

// Return true if the call-site has an argument which is a PHI with only
@@ -369,7 +432,17 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) {
      Function *Callee = CS.getCalledFunction();
      if (!Callee || Callee->isDeclaration())
        continue;

      // Successful musttail call-site splits result in erased CI and erased BB.
      // Check if such path is possible before attempting the splitting.
      bool IsMustTail = CS.isMustTailCall();

      Changed |= tryToSplitCallSite(CS);

      // There're no interesting instructions after this. The call site
      // itself might have been erased on splitting.
      if (IsMustTail)
        break;
    }
  }
  return Changed;
+109 −0
Original line number Diff line number Diff line
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -callsite-splitting -S | FileCheck %s

; musttail call followed by a bitcast and a ret. Block Tail is entered from Top
; (when %a is null) and from TBB (when %b is null). The expectations below
; require the call, bitcast and ret to be duplicated into each of the two split
; blocks, with the known-null / nonnull argument facts applied, and no Tail
; block left behind.
define i8* @caller(i8* %a, i8* %b) {
; CHECK-LABEL: @caller(
; CHECK-NEXT:  Top:
; CHECK-NEXT:    [[C:%.*]] = icmp eq i8* [[A:%.*]], null
; CHECK-NEXT:    br i1 [[C]], label [[TAIL_PREDBB1_SPLIT:%.*]], label [[TBB:%.*]]
; CHECK:       TBB:
; CHECK-NEXT:    [[C2:%.*]] = icmp eq i8* [[B:%.*]], null
; CHECK-NEXT:    br i1 [[C2]], label [[TAIL_PREDBB2_SPLIT:%.*]], label [[END:%.*]]
; CHECK:       Tail.predBB1.split:
; CHECK-NEXT:    [[TMP0:%.*]] = musttail call i8* @callee(i8* null, i8* [[B]])
; CHECK-NEXT:    [[CB1:%.*]] = bitcast i8* [[TMP0]] to i8*
; CHECK-NEXT:    ret i8* [[CB1]]
; CHECK:       Tail.predBB2.split:
; CHECK-NEXT:    [[TMP1:%.*]] = musttail call i8* @callee(i8* nonnull [[A]], i8* null)
; CHECK-NEXT:    [[CB2:%.*]] = bitcast i8* [[TMP1]] to i8*
; CHECK-NEXT:    ret i8* [[CB2]]
; CHECK:       End:
; CHECK-NEXT:    ret i8* null
;
Top:
  %c = icmp eq i8* %a, null
  br i1 %c, label %Tail, label %TBB
TBB:
  %c2 = icmp eq i8* %b, null
  br i1 %c2, label %Tail, label %End
Tail:
  %ca = musttail call i8* @callee(i8* %a, i8* %b)
  %cb = bitcast i8* %ca to i8*
  ret i8* %cb
End:
  ret i8* null
}

; Callee for the i8*-returning tests above and below: a defined, noinline
; function that simply returns its first argument.
define i8* @callee(i8* %a, i8* %b) noinline {
; CHECK-LABEL: define i8* @callee(
; CHECK-NEXT:    ret i8* [[A:%.*]]
;
  ret i8* %a
}

; Same control flow as @caller, but the musttail call result is returned
; directly with no intervening bitcast. The expectations below require only the
; call and the ret to appear in each split block.
define i8* @no_cast_caller(i8* %a, i8* %b) {
; CHECK-LABEL: @no_cast_caller(
; CHECK-NEXT:  Top:
; CHECK-NEXT:    [[C:%.*]] = icmp eq i8* [[A:%.*]], null
; CHECK-NEXT:    br i1 [[C]], label [[TAIL_PREDBB1_SPLIT:%.*]], label [[TBB:%.*]]
; CHECK:       TBB:
; CHECK-NEXT:    [[C2:%.*]] = icmp eq i8* [[B:%.*]], null
; CHECK-NEXT:    br i1 [[C2]], label [[TAIL_PREDBB2_SPLIT:%.*]], label [[END:%.*]]
; CHECK:       Tail.predBB1.split:
; CHECK-NEXT:    [[TMP0:%.*]] = musttail call i8* @callee(i8* null, i8* [[B]])
; CHECK-NEXT:    ret i8* [[TMP0]]
; CHECK:       Tail.predBB2.split:
; CHECK-NEXT:    [[TMP1:%.*]] = musttail call i8* @callee(i8* nonnull [[A]], i8* null)
; CHECK-NEXT:    ret i8* [[TMP1]]
; CHECK:       End:
; CHECK-NEXT:    ret i8* null
;
Top:
  %c = icmp eq i8* %a, null
  br i1 %c, label %Tail, label %TBB
TBB:
  %c2 = icmp eq i8* %b, null
  br i1 %c2, label %Tail, label %End
Tail:
  %ca = musttail call i8* @callee(i8* %a, i8* %b)
  ret i8* %ca
End:
  ret i8* null
}

; Void-returning variant: the musttail call produces no value and is followed
; by `ret void`. The expectations below require the call and a `ret void` in
; each split block.
define void @void_caller(i8* %a, i8* %b) {
; CHECK-LABEL: @void_caller(
; CHECK-NEXT:  Top:
; CHECK-NEXT:    [[C:%.*]] = icmp eq i8* [[A:%.*]], null
; CHECK-NEXT:    br i1 [[C]], label [[TAIL_PREDBB1_SPLIT:%.*]], label [[TBB:%.*]]
; CHECK:       TBB:
; CHECK-NEXT:    [[C2:%.*]] = icmp eq i8* [[B:%.*]], null
; CHECK-NEXT:    br i1 [[C2]], label [[TAIL_PREDBB2_SPLIT:%.*]], label [[END:%.*]]
; CHECK:       Tail.predBB1.split:
; CHECK-NEXT:    musttail call void @void_callee(i8* null, i8* [[B]])
; CHECK-NEXT:    ret void
; CHECK:       Tail.predBB2.split:
; CHECK-NEXT:    musttail call void @void_callee(i8* nonnull [[A]], i8* null)
; CHECK-NEXT:    ret void
; CHECK:       End:
; CHECK-NEXT:    ret void
;
Top:
  %c = icmp eq i8* %a, null
  br i1 %c, label %Tail, label %TBB
TBB:
  %c2 = icmp eq i8* %b, null
  br i1 %c2, label %Tail, label %End
Tail:
  musttail call void @void_callee(i8* %a, i8* %b)
  ret void
End:
  ret void
}

; Callee for @void_caller: a defined, noinline function with an empty body.
define void @void_callee(i8* %a, i8* %b) noinline {
; CHECK-LABEL: define void @void_callee(
; CHECK-NEXT:    ret void
;
  ret void
}