Commit 2a3723ef authored by Evgenii Stepanov's avatar Evgenii Stepanov
Browse files

[memtag] Plug in stack safety analysis.

Summary:
Run StackSafetyAnalysis at the end of the IR pipeline and annotate
proven safe allocas with !stack-safe metadata. Do not instrument such
allocas in the AArch64StackTagging pass.

Reviewers: pcc, vitalybuka, ostannard

Reviewed By: vitalybuka

Subscribers: merge_guards_bot, kristof.beyls, hiraditya, cfe-commits, gilang, llvm-commits

Tags: #clang, #llvm

Differential Revision: https://reviews.llvm.org/D73513
parent 78ce1908
Loading
Loading
Loading
Loading
+16 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/StackSafetyAnalysis.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
@@ -345,6 +346,11 @@ static void addDataFlowSanitizerPass(const PassManagerBuilder &Builder,
  PM.add(createDataFlowSanitizerPass(LangOpts.SanitizerBlacklistFiles));
}

static void addMemTagOptimizationPasses(const PassManagerBuilder &Builder,
                                        legacy::PassManagerBase &PM) {
  PM.add(createStackSafetyGlobalInfoWrapperPass(/*SetMetadata=*/true));
}

static TargetLibraryInfoImpl *createTLII(llvm::Triple &TargetTriple,
                                         const CodeGenOptions &CodeGenOpts) {
  TargetLibraryInfoImpl *TLII = new TargetLibraryInfoImpl(TargetTriple);
@@ -696,6 +702,11 @@ void EmitAssemblyHelper::CreatePasses(legacy::PassManager &MPM,
                           addDataFlowSanitizerPass);
  }

  if (LangOpts.Sanitize.has(SanitizerKind::MemTag)) {
    PMBuilder.addExtension(PassManagerBuilder::EP_OptimizerLast,
                           addMemTagOptimizationPasses);
  }

  // Set up the per-function pass manager.
  FPM.add(new TargetLibraryInfoWrapperPass(*TLII));
  if (CodeGenOpts.VerifyModule)
@@ -1300,6 +1311,11 @@ void EmitAssemblyHelper::EmitAssemblyWithNewPassManager(
          /*CompileKernel=*/true, /*Recover=*/true));
    }

    if (CodeGenOpts.OptimizationLevel > 0 &&
        LangOpts.Sanitize.has(SanitizerKind::MemTag)) {
      MPM.addPass(StackSafetyGlobalAnnotatorPass());
    }

    if (CodeGenOpts.OptimizationLevel == 0) {
      addCoroutinePassesAtO0(MPM, LangOpts, CodeGenOpts);
      addSanitizersAtO0(MPM, TargetTriple, LangOpts, CodeGenOpts);
+23 −0
Original line number Diff line number Diff line
// REQUIRES: aarch64-registered-target

// Old pass manager.
// RUN: %clang     -fno-experimental-new-pass-manager -target aarch64-unknown-linux -march=armv8+memtag -fsanitize=memtag %s -S -emit-llvm -o - | FileCheck %s --check-prefix=CHECK-NO-SAFETY
// RUN: %clang -O1 -fno-experimental-new-pass-manager -target aarch64-unknown-linux -march=armv8+memtag -fsanitize=memtag %s -S -emit-llvm -o - | FileCheck %s --check-prefix=CHECK-SAFETY
// RUN: %clang -O2 -fno-experimental-new-pass-manager -target aarch64-unknown-linux -march=armv8+memtag -fsanitize=memtag %s -S -emit-llvm -o - | FileCheck %s --check-prefix=CHECK-SAFETY
// RUN: %clang -O3 -fno-experimental-new-pass-manager -target aarch64-unknown-linux -march=armv8+memtag -fsanitize=memtag %s -S -emit-llvm -o - | FileCheck %s --check-prefix=CHECK-SAFETY

// New pass manager.
// RUN: %clang     -fexperimental-new-pass-manager -target aarch64-unknown-linux -march=armv8+memtag -fsanitize=memtag %s -S -emit-llvm -o - | FileCheck %s --check-prefix=CHECK-NO-SAFETY
// RUN: %clang -O1 -fexperimental-new-pass-manager -target aarch64-unknown-linux -march=armv8+memtag -fsanitize=memtag %s -S -emit-llvm -o - | FileCheck %s --check-prefix=CHECK-SAFETY
// RUN: %clang -O2 -fexperimental-new-pass-manager -target aarch64-unknown-linux -march=armv8+memtag -fsanitize=memtag %s -S -emit-llvm -o - | FileCheck %s --check-prefix=CHECK-SAFETY
// RUN: %clang -O3 -fexperimental-new-pass-manager -target aarch64-unknown-linux -march=armv8+memtag -fsanitize=memtag %s -S -emit-llvm -o - | FileCheck %s --check-prefix=CHECK-SAFETY

int z;
__attribute__((noinline)) void use(int *p) { *p = z; }
int foo() { int x; use(&x); return x; }

// CHECK-NO-SAFETY: define dso_local i32 @foo()
// CHECK-NO-SAFETY: %x = alloca i32, align 4{{$}}

// CHECK-SAFETY: define dso_local i32 @foo()
// CHECK-SAFETY: %x = alloca i32, align 4, !stack-safe
+16 −3
Original line number Diff line number Diff line
@@ -33,6 +33,8 @@ public:
  StackSafetyInfo &operator=(StackSafetyInfo &&);
  ~StackSafetyInfo();

  FunctionInfo *getInfo() const { return Info.get(); }

  // TODO: Add useful for client methods.
  void print(raw_ostream &O) const;
};
@@ -96,17 +98,26 @@ public:
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};

class StackSafetyGlobalAnnotatorPass
    : public PassInfoMixin<StackSafetyGlobalAnnotatorPass> {

public:
  explicit StackSafetyGlobalAnnotatorPass() {}
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};

/// This pass performs the global (interprocedural) stack safety analysis
/// (legacy pass manager).
class StackSafetyGlobalInfoWrapperPass : public ModulePass {
  StackSafetyGlobalInfo SSI;
  StackSafetyGlobalInfo SSGI;
  bool SetMetadata;

public:
  static char ID;

  StackSafetyGlobalInfoWrapperPass();
  StackSafetyGlobalInfoWrapperPass(bool SetMetadata = false);

  const StackSafetyGlobalInfo &getResult() const { return SSI; }
  const StackSafetyGlobalInfo &getResult() const { return SSGI; }

  void print(raw_ostream &O, const Module *M) const override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;
@@ -114,6 +125,8 @@ public:
  bool runOnModule(Module &M) override;
};

ModulePass *createStackSafetyGlobalInfoWrapperPass(bool SetMetadata);

} // end namespace llvm

#endif // LLVM_ANALYSIS_STACKSAFETYANALYSIS_H
+43 −9
Original line number Diff line number Diff line
@@ -99,11 +99,11 @@ raw_ostream &operator<<(raw_ostream &OS, const UseInfo &U) {
}

struct AllocaInfo {
  const AllocaInst *AI = nullptr;
  AllocaInst *AI = nullptr;
  uint64_t Size = 0;
  UseInfo Use;

  AllocaInfo(unsigned PointerSize, const AllocaInst *AI, uint64_t Size)
  AllocaInfo(unsigned PointerSize, AllocaInst *AI, uint64_t Size)
      : AI(AI), Size(Size), Use(PointerSize) {}

  StringRef getName() const { return AI->getName(); }
@@ -205,7 +205,7 @@ StackSafetyInfo::FunctionInfo::FunctionInfo(const GlobalAlias *A) : GV(A) {
namespace {

class StackSafetyLocalAnalysis {
  const Function &F;
  Function &F;
  const DataLayout &DL;
  ScalarEvolution &SE;
  unsigned PointerSize = 0;
@@ -227,7 +227,7 @@ class StackSafetyLocalAnalysis {
  }

public:
  StackSafetyLocalAnalysis(const Function &F, ScalarEvolution &SE)
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}
@@ -653,17 +653,47 @@ PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
  return PreservedAnalyses::all();
}

static bool SetStackSafetyMetadata(Module &M,
                                   const StackSafetyGlobalInfo &SSGI) {
  bool Changed = false;
  unsigned Width = M.getDataLayout().getPointerSizeInBits();
  for (auto &F : M.functions()) {
    if (F.isDeclaration() || F.hasOptNone())
      continue;
    auto Iter = SSGI.find(&F);
    if (Iter == SSGI.end())
      continue;
    StackSafetyInfo::FunctionInfo *Summary = Iter->second.getInfo();
    for (auto &AS : Summary->Allocas) {
      ConstantRange AllocaRange{APInt(Width, 0), APInt(Width, AS.Size)};
      if (AllocaRange.contains(AS.Use.Range)) {
        AS.AI->setMetadata(M.getMDKindID("stack-safe"),
                           MDNode::get(M.getContext(), None));
        Changed = true;
      }
    }
  }
  return Changed;
}

PreservedAnalyses
StackSafetyGlobalAnnotatorPass::run(Module &M, ModuleAnalysisManager &AM) {
  auto &SSGI = AM.getResult<StackSafetyGlobalAnalysis>(M);
  (void)SetStackSafetyMetadata(M, SSGI);
  return PreservedAnalyses::all();
}

char StackSafetyGlobalInfoWrapperPass::ID = 0;

StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
    : ModulePass(ID) {
StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass(bool SetMetadata)
    : ModulePass(ID), SetMetadata(SetMetadata) {
  initializeStackSafetyGlobalInfoWrapperPassPass(
      *PassRegistry::getPassRegistry());
}

void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
                                             const Module *M) const {
  ::print(SSI, O, *M);
  ::print(SSGI, O, *M);
}

void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
@@ -676,8 +706,12 @@ bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
      M, [this](Function &F) -> const StackSafetyInfo & {
        return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
      });
  SSI = SSDFA.run();
  return false;
  SSGI = SSDFA.run();
  return SetMetadata ? SetStackSafetyMetadata(M, SSGI) : false;
}

ModulePass *llvm::createStackSafetyGlobalInfoWrapperPass(bool SetMetadata) {
  return new StackSafetyGlobalInfoWrapperPass(SetMetadata);
}

static const char LocalPassArg[] = "stack-safety-local";
+1 −0
Original line number Diff line number Diff line
@@ -92,6 +92,7 @@ MODULE_PASS("tsan-module", ThreadSanitizerPass())
MODULE_PASS("kasan-module", ModuleAddressSanitizerPass(/*CompileKernel=*/true, false, true, false))
MODULE_PASS("sancov-module", ModuleSanitizerCoveragePass())
MODULE_PASS("poison-checking", PoisonCheckingPass())
MODULE_PASS("stack-safety-annotator", StackSafetyGlobalAnnotatorPass())
#undef MODULE_PASS

#ifndef CGSCC_ANALYSIS
Loading