Commit c9e93c84 authored by Tyker's avatar Tyker
Browse files

Add Query API for llvm.assume holding attributes

Reviewers: jdoerfert, sstefan1, uenoku

Reviewed By: jdoerfert

Subscribers: mgorny, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72885
parent 8ee0e1dc
Loading
Loading
Loading
Loading
+36 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@
#ifndef LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
#define LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H

#include "llvm/IR/Attributes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/PassManager.h"

@@ -30,6 +31,41 @@ inline CallInst *BuildAssumeFromInst(Instruction *I) {
  return BuildAssumeFromInst(I, I->getModule());
}

/// It is possible to have multiple Value for the argument of an attribute in
/// the same llvm.assume on the same llvm::Value. This is rare but need to be
/// dealt with.
enum class AssumeQuery {
  Highest, ///< Take the highest value available.
  Lowest,  ///< Take the lowest value available.
};

/// Query the operand bundle of an llvm.assume to find a single attribute of
/// the specified kind applied on a specified Value.
///
/// This has a non-constant complexity. It should only be used when a single
/// attribute is going to be queried.
///
/// Return true iff the queried attribute was found.
/// If ArgVal is set. the argument will be stored to ArgVal.
bool hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn, StringRef AttrName,
                          uint64_t *ArgVal = nullptr,
                          AssumeQuery AQR = AssumeQuery::Highest);
inline bool hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn,
                                 Attribute::AttrKind Kind,
                                 uint64_t *ArgVal = nullptr,
                                 AssumeQuery AQR = AssumeQuery::Highest) {
  return hasAttributeInAssume(
      AssumeCI, IsOn, Attribute::getNameFromAttrKind(Kind), ArgVal, AQR);
}

/// TODO: Add an function to create/fill a map from the bundle when users intend
/// to make many different queries on the same bundles. to be used for example
/// in the Attributor.

//===----------------------------------------------------------------------===//
// Utilities for testing
//===----------------------------------------------------------------------===//

/// This pass will try to build an llvm.assume for every instruction in the
/// function. Its main purpose is testing.
struct AssumeBuilderPass : public PassInfoMixin<AssumeBuilderPass> {
+101 −8
Original line number Diff line number Diff line
@@ -15,13 +15,13 @@

using namespace llvm;

namespace {

cl::opt<bool> ShouldPreserveAllAttributes(
    "assume-preserve-all", cl::init(false), cl::Hidden,
    cl::desc("enable preservation of all attrbitues. even those that are "
             "unlikely to be usefull"));

namespace {

struct AssumedKnowledge {
  const char *Name;
  Value *Argument;
@@ -59,22 +59,33 @@ template <> struct DenseMapInfo<AssumedKnowledge> {

namespace {

/// Index of elements in the operand bundle.
/// If the element exist it is guaranteed to be what is specified in this enum
/// but it may not exist.
enum BundleOpInfoElem {
  BOIE_WasOn = 0,
  BOIE_Argument = 1,
};

/// Deterministically compare OperandBundleDef.
/// The ordering is:
/// - by the name of the attribute, (doesn't change)
/// - then by the Value of the argument, (doesn't change)
/// - by the attribute's name aka operand bundle tag, (doesn't change)
/// - then by the numeric Value of the argument, (doesn't change)
/// - lastly by the Name of the current Value it WasOn. (may change)
/// This order is deterministic and allows looking for the right kind of
/// attribute with binary search. However finding the right WasOn needs to be
/// done via linear search because values can get remplaced.
/// done via linear search because values can get replaced.
bool isLowerOpBundle(const OperandBundleDef &LHS, const OperandBundleDef &RHS) {
  auto getTuple = [](const OperandBundleDef &Op) {
    return std::make_tuple(
        Op.getTag(),
        Op.input_size() < 2
        Op.input_size() <= BOIE_Argument
            ? 0
            : cast<ConstantInt>(*std::next(Op.input_begin()))->getZExtValue(),
        Op.input_size() < 1 ? StringRef("") : (*Op.input_begin())->getName());
            : cast<ConstantInt>(*(Op.input_begin() + BOIE_Argument))
                  ->getZExtValue(),
         Op.input_size() <= BOIE_WasOn
            ? StringRef("")
            : (*(Op.input_begin() + BOIE_WasOn))->getName());
  };
  return getTuple(LHS) < getTuple(RHS);
}
@@ -160,6 +171,88 @@ CallInst *llvm::BuildAssumeFromInst(const Instruction *I, Module *M) {
  return Builder.build();
}

#ifndef NDEBUG

static bool isExistingAttribute(StringRef Name) {
  return StringSwitch<bool>(Name)
#define GET_ATTR_NAMES
#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) .Case(#DISPLAY_NAME, true)
#include "llvm/IR/Attributes.inc"
      .Default(false);
}

#endif

bool llvm::hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn,
                                StringRef AttrName, uint64_t *ArgVal,
                                AssumeQuery AQR) {
  IntrinsicInst &Assume = cast<IntrinsicInst>(AssumeCI);
  assert(Assume.getIntrinsicID() == Intrinsic::assume &&
         "this function is intended to be used on llvm.assume");
  assert(isExistingAttribute(AttrName) && "this attribute doesn't exist");
  assert((ArgVal == nullptr || Attribute::doesAttrKindHaveArgument(
                                   Attribute::getAttrKindFromName(AttrName))) &&
         "requested value for an attribute that has no argument");
  if (Assume.bundle_op_infos().empty())
    return false;

  CallInst::bundle_op_iterator Lookup;

  /// The right attribute can be found by binary search. After this finding the
  /// right WasOn needs to be done via linear search.
  /// Element have been ordered by argument value so the first we find is the
  /// one we need.
  if (AQR == AssumeQuery::Lowest)
    Lookup =
        llvm::lower_bound(Assume.bundle_op_infos(), AttrName,
                          [](const CallBase::BundleOpInfo &BOI, StringRef RHS) {
                            assert(isExistingAttribute(BOI.Tag->getKey()) &&
                                   "this attribute doesn't exist");
                            return BOI.Tag->getKey() < RHS;
                          });
  else
    Lookup = std::prev(
        llvm::upper_bound(Assume.bundle_op_infos(), AttrName,
                          [](StringRef LHS, const CallBase::BundleOpInfo &BOI) {
                            assert(isExistingAttribute(BOI.Tag->getKey()) &&
                                   "this attribute doesn't exist");
                            return LHS < BOI.Tag->getKey();
                          }));

  auto getValueFromBundleOpInfo = [&Assume](const CallBase::BundleOpInfo &BOI,
                                            unsigned Idx) {
    assert(BOI.End - BOI.Begin > Idx && "index out of range");
    return (Assume.op_begin() + BOI.Begin + Idx)->get();
  };

  if (Lookup == Assume.bundle_op_info_end() ||
      Lookup->Tag->getKey() != AttrName)
    return false;
  if (IsOn) {
    if (Lookup->End - Lookup->Begin < BOIE_WasOn)
      return false;
    while (true) {
      if (Lookup == Assume.bundle_op_info_end() ||
          Lookup->Tag->getKey() != AttrName)
        return false;
      if (getValueFromBundleOpInfo(*Lookup, BOIE_WasOn) == IsOn)
        break;
      if (AQR == AssumeQuery::Highest &&
          Lookup == Assume.bundle_op_info_begin())
        return false;
      Lookup = Lookup + (AQR == AssumeQuery::Lowest ? 1 : -1);
    }
  }

  if (Lookup->End - Lookup->Begin < BOIE_Argument)
    return true;
  if (ArgVal)
    *ArgVal =
        cast<ConstantInt>(getValueFromBundleOpInfo(*Lookup, BOIE_Argument))
            ->getZExtValue();
  return true;
}

PreservedAnalyses AssumeBuilderPass::run(Function &F,
                                         FunctionAnalysisManager &AM) {
  for (Instruction &I : instructions(F))
+1 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ add_llvm_unittest(UtilsTests
  CodeMoverUtilsTest.cpp
  FunctionComparatorTest.cpp
  IntegerDivisionTest.cpp
  KnowledgeRetentionTest.cpp
  LocalTest.cpp
  LoopRotationUtilsTest.cpp
  LoopUtilsTest.cpp
+215 −0
Original line number Diff line number Diff line
//===- KnowledgeRetention.h - utilities to preserve informations *- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/KnowledgeRetention.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/CommandLine.h"
#include "gtest/gtest.h"

using namespace llvm;

extern cl::opt<bool> ShouldPreserveAllAttributes;

static void RunTest(
    StringRef Head, StringRef Tail,
    std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
        &Tests) {
  std::string IR;
  IR.append(Head.begin(), Head.end());
  for (auto &Elem : Tests)
    IR.append(Elem.first.begin(), Elem.first.end());
  IR.append(Tail.begin(), Tail.end());
  LLVMContext C;
  SMDiagnostic Err;
  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
  if (!Mod)
    Err.print("AssumeQueryAPI", errs());
  unsigned Idx = 0;
  for (Instruction &I : (*Mod->getFunction("test")->begin())) {
    if (Idx < Tests.size())
      Tests[Idx].second(&I);
    Idx++;
  }
}

void AssertMatchesExactlyAttributes(CallInst *Assume, Value *WasOn,
                                    StringRef AttrToMatch) {
  Regex Reg(AttrToMatch);
  SmallVector<StringRef, 1> Matches;
  for (StringRef Attr : {
#define GET_ATTR_NAMES
#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
#include "llvm/IR/Attributes.inc"
       }) {
    bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
    if (ShouldHaveAttr != hasAttributeInAssume(*Assume, WasOn, Attr))
      ASSERT_TRUE(false);
  }
}

void AssertHasTheRightValue(CallInst *Assume, Value *WasOn,
                            Attribute::AttrKind Kind, unsigned Value, bool Both,
                            AssumeQuery AQ = AssumeQuery::Highest) {
  if (!Both) {
    uint64_t ArgVal = 0;
    ASSERT_TRUE(hasAttributeInAssume(*Assume, WasOn, Kind, &ArgVal, AQ));
    ASSERT_EQ(ArgVal, Value);
    return;
  }
  uint64_t ArgValLow = 0;
  uint64_t ArgValHigh = 0;
  bool ResultLow = hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValLow,
                                        AssumeQuery::Lowest);
  bool ResultHigh = hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValHigh,
                                         AssumeQuery::Highest);
  if (ResultLow != ResultHigh)
    ASSERT_TRUE(false);
  if (ArgValLow != Value || ArgValLow != ArgValHigh)
    ASSERT_EQ(ArgValLow, Value);
}

TEST(AssumeQueryAPI, Basic) {
  StringRef Head =
      "declare void @llvm.assume(i1)\n"
      "declare void @func(i32*, i32*)\n"
      "declare void @func1(i32*, i32*, i32*, i32*)\n"
      "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
      "\"less-precise-fpmad\" willreturn norecurse\n"
      "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
  StringRef Tail = "ret void\n"
                   "}";
  std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
      Tests;
  Tests.push_back(std::make_pair(
      "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
      "8 noalias %P1)\n",
      [](Instruction *I) {
        CallInst *Assume = BuildAssumeFromInst(I);
        Assume->insertBefore(I);
        AssertMatchesExactlyAttributes(Assume, I->getOperand(0),
                                       "(nonnull|align|dereferenceable)");
        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
                                       "(noalias|align)");
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Dereferenceable, 16, true);
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Alignment, 4, true);
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Alignment, 4, true);
      }));
  Tests.push_back(std::make_pair(
      "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
      "nonnull "
      "align 8 dereferenceable(28) %P, i32* nonnull align 64 "
      "dereferenceable(4) "
      "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
      [](Instruction *I) {
        CallInst *Assume = BuildAssumeFromInst(I);
        Assume->insertBefore(I);
        AssertMatchesExactlyAttributes(Assume, I->getOperand(0),
                                       "(nonnull|align|dereferenceable)");
        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
                                       "(nonnull|align|dereferenceable)");
        AssertMatchesExactlyAttributes(Assume, I->getOperand(2),
                                       "(nonnull|align|dereferenceable)");
        AssertMatchesExactlyAttributes(Assume, I->getOperand(3),
                                       "(nonnull|align|dereferenceable)");
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Dereferenceable, 48, false,
                               AssumeQuery::Highest);
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Alignment, 64, false,
                               AssumeQuery::Highest);
        AssertHasTheRightValue(Assume, I->getOperand(1),
                               Attribute::AttrKind::Alignment, 64, false,
                               AssumeQuery::Highest);
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Dereferenceable, 4, false,
                               AssumeQuery::Lowest);
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Alignment, 8, false,
                               AssumeQuery::Lowest);
        AssertHasTheRightValue(Assume, I->getOperand(1),
                               Attribute::AttrKind::Alignment, 8, false,
                               AssumeQuery::Lowest);
      }));
  Tests.push_back(std::make_pair(
      "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) {
        ShouldPreserveAllAttributes.setValue(true);
        CallInst *Assume = BuildAssumeFromInst(I);
        Assume->insertBefore(I);
        AssertMatchesExactlyAttributes(
            Assume, nullptr,
            "(align|no-jump-tables|less-precise-fpmad|"
            "nounwind|norecurse|willreturn|cold)");
        ShouldPreserveAllAttributes.setValue(false);
      }));
  Tests.push_back(
      std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
        CallInst *Assume = cast<CallInst>(I);
        AssertMatchesExactlyAttributes(Assume, nullptr, "");
      }));
  Tests.push_back(std::make_pair(
      "call void @func1(i32* readnone align 32 "
      "dereferenceable(48) noalias %P, i32* "
      "align 8 dereferenceable(28) %P1, i32* align 64 "
      "dereferenceable(4) "
      "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
      [](Instruction *I) {
        CallInst *Assume = BuildAssumeFromInst(I);
        Assume->insertBefore(I);
        AssertMatchesExactlyAttributes(
            Assume, I->getOperand(0),
            "(readnone|align|dereferenceable|noalias)");
        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
                                       "(align|dereferenceable)");
        AssertMatchesExactlyAttributes(Assume, I->getOperand(2),
                                       "(align|dereferenceable)");
        AssertMatchesExactlyAttributes(Assume, I->getOperand(3),
                                       "(nonnull|align|dereferenceable)");
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Alignment, 32, true);
        AssertHasTheRightValue(Assume, I->getOperand(0),
                               Attribute::AttrKind::Dereferenceable, 48, true);
        AssertHasTheRightValue(Assume, I->getOperand(1),
                               Attribute::AttrKind::Dereferenceable, 28, true);
        AssertHasTheRightValue(Assume, I->getOperand(1),
                               Attribute::AttrKind::Alignment, 8, true);
        AssertHasTheRightValue(Assume, I->getOperand(2),
                               Attribute::AttrKind::Alignment, 64, true);
        AssertHasTheRightValue(Assume, I->getOperand(2),
                               Attribute::AttrKind::Dereferenceable, 4, true);
        AssertHasTheRightValue(Assume, I->getOperand(3),
                               Attribute::AttrKind::Alignment, 16, true);
        AssertHasTheRightValue(Assume, I->getOperand(3),
                               Attribute::AttrKind::Dereferenceable, 12, true);
      }));

  /// Keep this test last as it modifies the function.
  Tests.push_back(std::make_pair(
      "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
      "8 noalias %P1)\n",
      [](Instruction *I) {
        CallInst *Assume = BuildAssumeFromInst(I);
        Assume->insertBefore(I);
        Value *New = I->getFunction()->getArg(3);
        Value *Old = I->getOperand(0);
        AssertMatchesExactlyAttributes(Assume, New, "");
        AssertMatchesExactlyAttributes(Assume, Old,
                                       "(nonnull|align|dereferenceable)");
        Old->replaceAllUsesWith(New);
        AssertMatchesExactlyAttributes(Assume, New,
                                       "(nonnull|align|dereferenceable)");
        AssertMatchesExactlyAttributes(Assume, Old, "");
      }));
  RunTest(Head, Tail, Tests);
}