Commit 43c8307c authored by Jun Ma's avatar Jun Ma
Browse files

[Coroutines] CoroElide enhancement

Fix regression of CoreElide pass when current function is
coroutine.

Differential Revision: https://reviews.llvm.org/D71663
parent 2b5a8976
Loading
Loading
Loading
Loading
+60 −18
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@

#include "llvm/Transforms/Coroutines/CoroElide.h"
#include "CoroInternal.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/Dominators.h"
@@ -27,8 +28,9 @@ struct Lowerer : coro::LowererBase {
  SmallVector<CoroBeginInst *, 1> CoroBegins;
  SmallVector<CoroAllocInst *, 1> CoroAllocs;
  SmallVector<CoroSubFnInst *, 4> ResumeAddr;
  SmallVector<CoroSubFnInst *, 4> DestroyAddr;
  DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr;
  SmallVector<CoroFreeInst *, 1> CoroFrees;
  CoroSuspendInst *CoroFinalSuspend;

  Lowerer(Module &M) : LowererBase(M) {}

@@ -146,33 +148,62 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
  if (CoroAllocs.empty())
    return false;

  // Check that for every coro.begin there is a coro.destroy directly
  // referencing the SSA value of that coro.begin along a non-exceptional path.
  // Check that for every coro.begin there is at least one coro.destroy directly
  // referencing the SSA value of that coro.begin along each
  // non-exceptional path.
  // If the value escaped, then coro.destroy would have been referencing a
  // memory location storing that value and not the virtual register.

  // First gather all of the non-exceptional terminators for the function.
  SmallPtrSet<Instruction *, 8> Terminators;
  bool HasMultiPred = false;
  // First gather all of the non-exceptional terminators for the function.
  // Consider the final coro.suspend as the real terminator when the current
  // function is a coroutine.
  if (CoroFinalSuspend) {
    // If block of final coro.suspend has more than one predecessor,
    // then there is one resume path and the others are exceptional paths,
    // consider these predecessors as terminators.
    BasicBlock *FinalBB = CoroFinalSuspend->getParent();
    if (FinalBB->hasNPredecessorsOrMore(2)) {
      HasMultiPred = true;
      for (auto *B : predecessors(FinalBB))
        Terminators.insert(B->getTerminator());
    } else
      Terminators.insert(CoroFinalSuspend);
  } else {
    for (BasicBlock &B : *F) {
      auto *TI = B.getTerminator();
      if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
          !isa<UnreachableInst>(TI))
        Terminators.insert(TI);
    }
  }

  // Filter out the coro.destroy that lie along exceptional paths.
  SmallPtrSet<CoroSubFnInst *, 4> DAs;
  for (CoroSubFnInst *DA : DestroyAddr) {
  SmallPtrSet<Instruction *, 2> TIs;
  SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
  for (auto &It : DestroyAddr) {
    for (CoroSubFnInst *DA : It.second) {
      for (Instruction *TI : Terminators) {
        if (DT.dominates(DA, TI)) {
          if (HasMultiPred)
            TIs.insert(TI);
          else
            DAs.insert(DA);
          break;
        }
      }
    }
    // If all the predecessors dominate coro.destroys that reference same
    // coro.begin, record the coro.begin
    if (TIs.size() == Terminators.size()) {
      ReferencedCoroBegins.insert(It.first);
      TIs.clear();
    }
  }

  // Find all the coro.begin referenced by coro.destroy along happy paths.
  SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
  for (CoroSubFnInst *DA : DAs) {
    if (auto *CB = dyn_cast<CoroBeginInst>(DA->getFrame()))
      ReferencedCoroBegins.insert(CB);
@@ -188,12 +219,22 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {

void Lowerer::collectPostSplitCoroIds(Function *F) {
  CoroIds.clear();
  for (auto &I : instructions(F))
  CoroFinalSuspend = nullptr;
  for (auto &I : instructions(F)) {
    if (auto *CII = dyn_cast<CoroIdInst>(&I))
      if (CII->getInfo().isPostSplit())
        // If it is the coroutine itself, don't touch it.
        if (CII->getCoroutine() != CII->getFunction())
          CoroIds.push_back(CII);

    if (auto *CSI = dyn_cast<CoroSuspendInst>(&I))
      if (CSI->isFinal()) {
        if (!CoroFinalSuspend)
          CoroFinalSuspend = CSI;
        else
          report_fatal_error("Only one suspend point can be marked as final");
      }
  }
}

bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
@@ -226,7 +267,7 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
          ResumeAddr.push_back(II);
          break;
        case CoroSubFnInst::DestroyIndex:
          DestroyAddr.push_back(II);
          DestroyAddr[CB].push_back(II);
          break;
        default:
          llvm_unreachable("unexpected coro.subfn.addr constant");
@@ -249,7 +290,8 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
      Resumers,
      ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);

  replaceWithConstant(DestroyAddrConstant, DestroyAddr);
  for (auto &It : DestroyAddr)
    replaceWithConstant(DestroyAddrConstant, It.second);

  if (ShouldElide) {
    auto *FrameTy = getFrameType(cast<Function>(ResumeAddrConstant));
+116 −0
Original line number Diff line number Diff line
@@ -84,6 +84,120 @@ entry:
  ret void
}

; CHECK-LABEL: @callResume_with_coro_suspend_1(
define void @callResume_with_coro_suspend_1() {
entry:
; CHECK: alloca %f.frame
; CHECK-NOT: coro.begin
; CHECK-NOT: CustomAlloc
; CHECK: call void @may_throw()
  %hdl = call i8* @f()

; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* %vFrame)
  %0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
  %1 = bitcast i8* %0 to void (i8*)*
  call fastcc void %1(i8* %hdl)
  %2 = call token @llvm.coro.save(i8* %hdl)
  %3 = call i8 @llvm.coro.suspend(token %2, i1 false)
  switch i8 %3, label  %coro.ret [
    i8 0, label %final.suspend
    i8 1, label %cleanups
  ]

; CHECK-LABEL: final.suspend:
final.suspend:
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
  %4 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
  %5 = bitcast i8* %4 to void (i8*)*
  call fastcc void %5(i8* %hdl)
  %6 = call token @llvm.coro.save(i8* %hdl)
  %7 = call i8 @llvm.coro.suspend(token %6, i1 true)
  switch i8 %7, label  %coro.ret [
    i8 0, label %coro.ret
    i8 1, label %cleanups
  ]

; CHECK-LABEL: cleanups:
cleanups:
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
  %8 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
  %9 = bitcast i8* %8 to void (i8*)*
  call fastcc void %9(i8* %hdl)
  br label %coro.ret

; CHECK-LABEL: coro.ret:
coro.ret:
; CHECK-NEXT: ret void
  ret void
}

; CHECK-LABEL: @callResume_with_coro_suspend_2(
define void @callResume_with_coro_suspend_2() personality i8* null {
entry:
; CHECK: alloca %f.frame
; CHECK-NOT: coro.begin
; CHECK-NOT: CustomAlloc
; CHECK: call void @may_throw()
  %hdl = call i8* @f()

  %0 = call token @llvm.coro.save(i8* %hdl)
; CHECK: invoke fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* %vFrame)
  %1 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
  %2 = bitcast i8* %1 to void (i8*)*
  invoke fastcc void %2(i8* %hdl)
    to label %invoke.cont1 unwind label %lpad

; CHECK-LABEL: invoke.cont1:
invoke.cont1:
  %3 = call i8 @llvm.coro.suspend(token %0, i1 false)
  switch i8 %3, label  %coro.ret [
    i8 0, label %final.ready
    i8 1, label %cleanups
  ]

; CHECK-LABEL: lpad:
lpad:
  %4 = landingpad { i8*, i32 }
          catch i8* null
; CHECK: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
  %5 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
  %6 = bitcast i8* %5 to void (i8*)*
  call fastcc void %6(i8* %hdl)
  br label %final.suspend

; CHECK-LABEL: final.ready:
final.ready:
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
  %7 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
  %8 = bitcast i8* %7 to void (i8*)*
  call fastcc void %8(i8* %hdl)
  br label %final.suspend

; CHECK-LABEL: final.suspend:
final.suspend:
  %9 = call token @llvm.coro.save(i8* %hdl)
  %10 = call i8 @llvm.coro.suspend(token %9, i1 true)
  switch i8 %10, label  %coro.ret [
    i8 0, label %coro.ret
    i8 1, label %cleanups
  ]

; CHECK-LABEL: cleanups:
cleanups:
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
  %11 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
  %12 = bitcast i8* %11 to void (i8*)*
  call fastcc void %12(i8* %hdl)
  br label %coro.ret

; CHECK-LABEL: coro.ret:
coro.ret:
; CHECK-NEXT: ret void
  ret void
}



; CHECK-LABEL: @callResume_PR34897_no_elision(
define void @callResume_PR34897_no_elision(i1 %cond) {
; CHECK-LABEL: entry:
@@ -161,3 +275,5 @@ declare i8* @llvm.coro.free(token, i8*)
declare i8* @llvm.coro.begin(token, i8*)
declare i8* @llvm.coro.frame(token)
declare i8* @llvm.coro.subfn.addr(i8*, i8)
declare i8 @llvm.coro.suspend(token, i1)
declare token @llvm.coro.save(i8*)