Commit 3318c1d6 authored by Tom Stellard's avatar Tom Stellard
Browse files

Merging r326404:

------------------------------------------------------------------------
r326404 | rnk | 2018-02-28 17:19:18 -0800 (Wed, 28 Feb 2018) | 14 lines

[IPSCCP] do not break musttail invariant (PR36485)

Do not replace results of `musttail` calls with a constant if the
call itself can't be removed.

Do not zap returns of `musttail` callees, if the call site can't be
removed and replaced with a constant.

Do not zap returns of `musttail`-calling blocks, this breaks
invariant too.

Patch by Fedor Indutny

Differential Revision: https://reviews.llvm.org/D43695
------------------------------------------------------------------------

llvm-svn: 329855
parent 8f8ecdf7
Loading
Loading
Loading
Loading
+55 −1
Original line number Diff line number Diff line
@@ -223,6 +223,10 @@ class SCCPSolver : public InstVisitor<SCCPSolver> {
  /// represented here for efficient lookup.
  SmallPtrSet<Function *, 16> MRVFunctionsTracked;

  /// MustTailFunctions - Each function here is a callee of non-removable
  /// musttail call site.
  SmallPtrSet<Function *, 16> MustTailCallees;

  /// TrackingIncomingArguments - This is the set of functions for whose
  /// arguments we make optimistic assumptions about and try to prove as
  /// constants.
@@ -289,6 +293,18 @@ public:
      TrackedRetVals.insert(std::make_pair(F, LatticeVal()));
  }

  /// AddMustTailCallee - If the SCCP solver finds that this function is called
  /// from non-removable musttail call site.
  void AddMustTailCallee(Function *F) {
    MustTailCallees.insert(F);
  }

  /// Returns true if the given function is called from non-removable musttail
  /// call site.
  bool isMustTailCallee(Function *F) {
    return MustTailCallees.count(F);
  }

  void AddArgumentTrackedFunction(Function *F) {
    TrackingIncomingArguments.insert(F);
  }
@@ -358,6 +374,12 @@ public:
    return MRVFunctionsTracked;
  }

  /// getMustTailCallees - Get the set of functions which are called
  /// from non-removable musttail call sites.
  const SmallPtrSet<Function *, 16> getMustTailCallees() {
    return MustTailCallees;
  }

  /// markOverdefined - Mark the specified value overdefined.  This
  /// works with both scalars and structs.
  void markOverdefined(Value *V) {
@@ -1672,6 +1694,23 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {
          IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType());
  }
  assert(Const && "Constant is nullptr here!");

  // Replacing `musttail` instructions with constant breaks `musttail` invariant
  // unless the call itself can be removed
  CallInst *CI = dyn_cast<CallInst>(V);
  if (CI && CI->isMustTailCall() && !isInstructionTriviallyDead(CI)) {
    CallSite CS(CI);
    Function *F = CS.getCalledFunction();

    // Don't zap returns of the callee
    if (F)
      Solver.AddMustTailCallee(F);

    DEBUG(dbgs() << "  Can\'t treat the result of musttail call : " << *CI
                 << " as a constant\n");
    return false;
  }

  DEBUG(dbgs() << "  Constant: " << *Const << " = " << *V << '\n');

  // Replaces all of the uses of a variable with uses of the constant.
@@ -1802,11 +1841,26 @@ static void findReturnsToZap(Function &F,
  if (!Solver.isArgumentTrackedFunction(&F))
    return;

  for (BasicBlock &BB : F)
  // There is a non-removable musttail call site of this function. Zapping
  // returns is not allowed.
  if (Solver.isMustTailCallee(&F)) {
    DEBUG(dbgs() << "Can't zap returns of the function : " << F.getName()
                 << " due to present musttail call of it\n");
    return;
  }

  for (BasicBlock &BB : F) {
    if (CallInst *CI = BB.getTerminatingMustTailCall()) {
      DEBUG(dbgs() << "Can't zap return of the block due to present "
                   << "musttail call : " << *CI << "\n");
      return;
    }

    if (auto *RI = dyn_cast<ReturnInst>(BB.getTerminator()))
      if (!isa<UndefValue>(RI->getOperand(0)))
        ReturnsToZap.push_back(RI);
  }
}

static bool runIPSCCP(Module &M, const DataLayout &DL,
                      const TargetLibraryInfo *TLI) {