Commit bb310b3f authored by Florian Hahn
Browse files

Recommit "[SCCP] Remove forcedconstant, go to overdefined instead"

This version includes a fix for a set of crashes caused by marking
values depending on a yet unknown & tracked call as overdefined.

In some cases, we would later discover that the call has a constant
result and try to mark a user of it as constant, although it was already
marked as overdefined. Most instruction handlers bail out early if the
instruction is already overdefined. But that is not necessary for
CastInsts for example. By skipping values that depend on skipped
calls, we resolve the crashes and also improve the precision in some
cases (see resolvedundefsin-tracked-fn.ll).

Note that we may not skip PHI nodes that may depend on a skipped call,
but they can be safely marked as overdefined, as we bail out early if
the PHI node is overdefined.

This reverts the revert commit
a74b31a3e9cd844c7ce2087978568e3f5ec8519.
parent 5bb49540
Loading
Loading
Loading
Loading
+28 −235
Original line number Diff line number Diff line
@@ -85,19 +85,13 @@ class LatticeVal {
    /// constant - This LLVM Value has a specific constant value.
    constant,

    /// forcedconstant - This LLVM Value was thought to be undef until
    /// ResolvedUndefsIn.  This is treated just like 'constant', but if merged
    /// with another (different) constant, it goes to overdefined, instead of
    /// asserting.
    forcedconstant,

    /// overdefined - This instruction is not known to be constant, and we know
    /// it has a value.
    overdefined
  };

  /// Val: This stores the current lattice value along with the Constant* for
  /// the constant if this is a 'constant' or 'forcedconstant' value.
  /// the constant if this is a 'constant' value.
  PointerIntPair<Constant *, 2, LatticeValueTy> Val;

  LatticeValueTy getLatticeValue() const {
@@ -109,9 +103,7 @@ public:

  bool isUnknown() const { return getLatticeValue() == unknown; }

  bool isConstant() const {
    return getLatticeValue() == constant || getLatticeValue() == forcedconstant;
  }
  bool isConstant() const { return getLatticeValue() == constant; }

  bool isOverdefined() const { return getLatticeValue() == overdefined; }

@@ -131,26 +123,15 @@ public:

  /// markConstant - Return true if this is a change in status.
  bool markConstant(Constant *V) {
    if (getLatticeValue() == constant) { // Constant but not forcedconstant.
    if (getLatticeValue() == constant) { // Constant
      assert(getConstant() == V && "Marking constant with different value");
      return false;
    }

    if (isUnknown()) {
    assert(isUnknown());
    Val.setInt(constant);
    assert(V && "Marking constant with NULL");
    Val.setPointer(V);
    } else {
      assert(getLatticeValue() == forcedconstant &&
             "Cannot move from overdefined to constant!");
      // Stay at forcedconstant if the constant is the same.
      if (V == getConstant()) return false;

      // Otherwise, we go to overdefined.  Assumptions made based on the
      // forced value are possibly wrong.  Assuming this is another constant
      // could expose a contradiction.
      Val.setInt(overdefined);
    }
    return true;
  }

@@ -170,12 +151,6 @@ public:
    return nullptr;
  }

  void markForcedConstant(Constant *V) {
    assert(isUnknown() && "Can't force a defined value!");
    Val.setInt(forcedconstant);
    Val.setPointer(V);
  }

  ValueLatticeElement toValueLattice() const {
    if (isOverdefined())
      return ValueLatticeElement::getOverdefined();
@@ -421,7 +396,7 @@ public:
  }

private:
  // pushToWorkList - Helper for markConstant/markForcedConstant/markOverdefined
  // pushToWorkList - Helper for markConstant/markOverdefined
  void pushToWorkList(LatticeVal &IV, Value *V) {
    if (IV.isOverdefined())
      return OverdefinedInstWorkList.push_back(V);
@@ -443,14 +418,6 @@ private:
    return markConstant(ValueState[V], V, C);
  }

  void markForcedConstant(Value *V, Constant *C) {
    assert(!V->getType()->isStructTy() && "structs should use mergeInValue");
    LatticeVal &IV = ValueState[V];
    IV.markForcedConstant(C);
    LLVM_DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n');
    pushToWorkList(IV, V);
  }

  // markOverdefined - Make a value be marked as "overdefined". If the
  // value is not already overdefined, add it to the overdefined instruction
  // work list so that the users of the instruction are updated later.
@@ -996,8 +963,6 @@ void SCCPSolver::visitUnaryOperator(Instruction &I) {
  LatticeVal V0State = getValueState(I.getOperand(0));

  LatticeVal &IV = ValueState[&I];
  if (IV.isOverdefined()) return;

  if (V0State.isConstant()) {
    Constant *C = ConstantExpr::get(I.getOpcode(), V0State.getConstant());

@@ -1032,8 +997,10 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) {
  }

  // If something is undef, wait for it to resolve.
  if (!V1State.isOverdefined() && !V2State.isOverdefined())
  if (!V1State.isOverdefined() && !V2State.isOverdefined()) {

    return;
  }

  // Otherwise, one of our operands is overdefined.  Try to produce something
  // better than overdefined with some tricks.
@@ -1054,7 +1021,6 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) {
      NonOverdefVal = &V1State;
    else if (!V2State.isOverdefined())
      NonOverdefVal = &V2State;

    if (NonOverdefVal) {
      if (NonOverdefVal->isUnknown())
        return;
@@ -1174,7 +1140,6 @@ void SCCPSolver::visitLoadInst(LoadInst &I) {
  if (PtrVal.isUnknown()) return;   // The pointer is not resolved yet!

  LatticeVal &IV = ValueState[&I];
  if (IV.isOverdefined()) return;

  if (!PtrVal.isConstant() || I.isVolatile())
    return (void)markOverdefined(IV, &I);
@@ -1449,11 +1414,11 @@ void SCCPSolver::Solve() {
/// constraints on the condition of the branch, as that would impact other users
/// of the value.
///
/// This scan also checks for values that use undefs, whose results are actually
/// defined.  For example, 'zext i8 undef to i32' should produce all zeros
/// conservatively, as "(zext i8 X -> i32) & 0xFF00" must always return zero,
/// even if X isn't defined.
/// This scan also checks for values that use undefs. It conservatively marks
/// them as overdefined.
bool SCCPSolver::ResolvedUndefsIn(Function &F) {
  // Keep track of values that depend on a yet unknown tracked function call.
  // It only makes sense to resolve them once the call is resolved.
  SmallPtrSet<Value *, 8> DependsOnSkipped;
  for (BasicBlock &BB : F) {
    if (!BBExecutable.count(&BB))
      continue;
@@ -1468,14 +1433,15 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
        // Tracked calls must never be marked overdefined in ResolvedUndefsIn.
        if (CallSite CS = CallSite(&I))
          if (Function *F = CS.getCalledFunction())
            if (MRVFunctionsTracked.count(F))
            if (MRVFunctionsTracked.count(F)) {
              DependsOnSkipped.insert(&I);
              continue;
            }

        // extractvalue and insertvalue don't need to be marked; they are
        // tracked as precisely as their operands.
        if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I))
          continue;

        // Send the results of everything else to overdefined.  We could be
        // more precise than this but it isn't worth bothering.
        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
@@ -1495,195 +1461,22 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
      // 2. It could be constant-foldable.
      // Because of the way we solve return values, tracked calls must
      // never be marked overdefined in ResolvedUndefsIn.
      if (CallSite CS = CallSite(&I)) {
      if (CallSite CS = CallSite(&I))
        if (Function *F = CS.getCalledFunction())
          if (TrackedRetVals.count(F))
          if (TrackedRetVals.count(F)) {
            DependsOnSkipped.insert(&I);
            continue;

        // If the call is constant-foldable, we mark it overdefined because
        // we do not know what return values are valid.
        markOverdefined(&I);
        return true;
          }

      // extractvalue is safe; check here because the argument is a struct.
      if (isa<ExtractValueInst>(I))
      // Skip instructions that depend on results of calls we skipped earlier.
      // Otherwise we might mark I as overdefined too early when we would end
      // up discovering a constant value for I, if the call later resolves to
      // a constant.
      if (any_of(I.operands(), [&DependsOnSkipped](Value *V) {
                 return DependsOnSkipped.find(V) != DependsOnSkipped.end(); })) {
        DependsOnSkipped.insert(&I);
        continue;

      // Compute the operand LatticeVals, for convenience below.
      // Anything taking a struct is conservatively assumed to require
      // overdefined markings.
      if (I.getOperand(0)->getType()->isStructTy()) {
        markOverdefined(&I);
        return true;
      }
      LatticeVal Op0LV = getValueState(I.getOperand(0));
      LatticeVal Op1LV;
      if (I.getNumOperands() == 2) {
        if (I.getOperand(1)->getType()->isStructTy()) {
          markOverdefined(&I);
          return true;
      }

        Op1LV = getValueState(I.getOperand(1));
      }
      // If this is an instructions whose result is defined even if the input is
      // not fully defined, propagate the information.
      Type *ITy = I.getType();
      switch (I.getOpcode()) {
      case Instruction::Add:
      case Instruction::Sub:
      case Instruction::Trunc:
      case Instruction::FPTrunc:
      case Instruction::BitCast:
        break; // Any undef -> undef
      case Instruction::FSub:
      case Instruction::FAdd:
      case Instruction::FMul:
      case Instruction::FDiv:
      case Instruction::FRem:
        // Floating-point binary operation: be conservative.
        if (Op0LV.isUnknown() && Op1LV.isUnknown())
          markForcedConstant(&I, Constant::getNullValue(ITy));
        else
      markOverdefined(&I);
      return true;
      case Instruction::FNeg:
        break; // fneg undef -> undef
      case Instruction::ZExt:
      case Instruction::SExt:
      case Instruction::FPToUI:
      case Instruction::FPToSI:
      case Instruction::FPExt:
      case Instruction::PtrToInt:
      case Instruction::IntToPtr:
      case Instruction::SIToFP:
      case Instruction::UIToFP:
        // undef -> 0; some outputs are impossible
        markForcedConstant(&I, Constant::getNullValue(ITy));
        return true;
      case Instruction::Mul:
      case Instruction::And:
        // Both operands undef -> undef
        if (Op0LV.isUnknown() && Op1LV.isUnknown())
          break;
        // undef * X -> 0.   X could be zero.
        // undef & X -> 0.   X could be zero.
        markForcedConstant(&I, Constant::getNullValue(ITy));
        return true;
      case Instruction::Or:
        // Both operands undef -> undef
        if (Op0LV.isUnknown() && Op1LV.isUnknown())
          break;
        // undef | X -> -1.   X could be -1.
        markForcedConstant(&I, Constant::getAllOnesValue(ITy));
        return true;
      case Instruction::Xor:
        // undef ^ undef -> 0; strictly speaking, this is not strictly
        // necessary, but we try to be nice to people who expect this
        // behavior in simple cases
        if (Op0LV.isUnknown() && Op1LV.isUnknown()) {
          markForcedConstant(&I, Constant::getNullValue(ITy));
          return true;
        }
        // undef ^ X -> undef
        break;
      case Instruction::SDiv:
      case Instruction::UDiv:
      case Instruction::SRem:
      case Instruction::URem:
        // X / undef -> undef.  No change.
        // X % undef -> undef.  No change.
        if (Op1LV.isUnknown()) break;

        // X / 0 -> undef.  No change.
        // X % 0 -> undef.  No change.
        if (Op1LV.isConstant() && Op1LV.getConstant()->isZeroValue())
          break;

        // undef / X -> 0.   X could be maxint.
        // undef % X -> 0.   X could be 1.
        markForcedConstant(&I, Constant::getNullValue(ITy));
        return true;
      case Instruction::AShr:
        // X >>a undef -> undef.
        if (Op1LV.isUnknown()) break;

        // Shifting by the bitwidth or more is undefined.
        if (Op1LV.isConstant()) {
          if (auto *ShiftAmt = Op1LV.getConstantInt())
            if (ShiftAmt->getLimitedValue() >=
                ShiftAmt->getType()->getScalarSizeInBits())
              break;
        }

        // undef >>a X -> 0
        markForcedConstant(&I, Constant::getNullValue(ITy));
        return true;
      case Instruction::LShr:
      case Instruction::Shl:
        // X << undef -> undef.
        // X >> undef -> undef.
        if (Op1LV.isUnknown()) break;

        // Shifting by the bitwidth or more is undefined.
        if (Op1LV.isConstant()) {
          if (auto *ShiftAmt = Op1LV.getConstantInt())
            if (ShiftAmt->getLimitedValue() >=
                ShiftAmt->getType()->getScalarSizeInBits())
              break;
        }

        // undef << X -> 0
        // undef >> X -> 0
        markForcedConstant(&I, Constant::getNullValue(ITy));
        return true;
      case Instruction::Select:
        Op1LV = getValueState(I.getOperand(1));
        // undef ? X : Y  -> X or Y.  There could be commonality between X/Y.
        if (Op0LV.isUnknown()) {
          if (!Op1LV.isConstant())  // Pick the constant one if there is any.
            Op1LV = getValueState(I.getOperand(2));
        } else if (Op1LV.isUnknown()) {
          // c ? undef : undef -> undef.  No change.
          Op1LV = getValueState(I.getOperand(2));
          if (Op1LV.isUnknown())
            break;
          // Otherwise, c ? undef : x -> x.
        } else {
          // Leave Op1LV as Operand(1)'s LatticeValue.
        }

        if (Op1LV.isConstant())
          markForcedConstant(&I, Op1LV.getConstant());
        else
          markOverdefined(&I);
        return true;
      case Instruction::Load:
        // A load here means one of two things: a load of undef from a global,
        // a load from an unknown pointer.  Either way, having it return undef
        // is okay.
        break;
      case Instruction::ICmp:
        // X == undef -> undef.  Other comparisons get more complicated.
        Op0LV = getValueState(I.getOperand(0));
        Op1LV = getValueState(I.getOperand(1));

        if ((Op0LV.isUnknown() || Op1LV.isUnknown()) &&
            cast<ICmpInst>(&I)->isEquality())
          break;
        markOverdefined(&I);
        return true;
      case Instruction::Call:
      case Instruction::Invoke:
      case Instruction::CallBr:
        llvm_unreachable("Call-like instructions should have be handled early");
      default:
        // If we don't know what should happen here, conservatively mark it
        // overdefined.
        markOverdefined(&I);
        return true;
      }
    }

    // Check to see if we have a branch or switch on an undefined value.  If so
+6 −3
Original line number Diff line number Diff line
@@ -7,7 +7,9 @@ target triple = "x86_64-unknown-linux-gnu"
define i64 @fn2() {
; CHECK-LABEL: define {{[^@]+}}@fn2()
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[CALL2:%.*]] = call i64 @fn1(i64 undef)
; CHECK-NEXT:    [[CONV:%.*]] = sext i32 undef to i64
; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 8, [[CONV]]
; CHECK-NEXT:    [[CALL2:%.*]] = call i64 @fn1(i64 [[DIV]])
; CHECK-NEXT:    ret i64 [[CALL2]]
;
entry:
@@ -21,7 +23,8 @@ define internal i64 @fn1(i64 %p1) {
; CHECK-LABEL: define {{[^@]+}}@fn1
; CHECK-SAME: (i64 [[P1:%.*]])
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[COND:%.*]] = select i1 undef, i64 undef, i64 undef
; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne i64 [[P1]], 0
; CHECK-NEXT:    [[COND:%.*]] = select i1 [[TOBOOL]], i64 [[P1]], i64 [[P1]]
; CHECK-NEXT:    ret i64 [[COND]]
;
entry:
+4 −2
Original line number Diff line number Diff line
@@ -11,7 +11,8 @@ define void @fn2(i32* %P) {
; CHECK:       for.cond1:
; CHECK-NEXT:    br i1 false, label [[IF_END]], label [[IF_END]]
; CHECK:       if.end:
; CHECK-NEXT:    [[CALL:%.*]] = call i32 @fn1(i32 undef)
; CHECK-NEXT:    [[TMP0:%.*]] = load i32, i32* null, align 4
; CHECK-NEXT:    [[CALL:%.*]] = call i32 @fn1(i32 [[TMP0]])
; CHECK-NEXT:    store i32 [[CALL]], i32* [[P]]
; CHECK-NEXT:    br label [[FOR_COND1:%.*]]
;
@@ -33,7 +34,8 @@ define internal i32 @fn1(i32 %p1) {
; CHECK-LABEL: define {{[^@]+}}@fn1
; CHECK-SAME: (i32 [[P1:%.*]])
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[COND:%.*]] = select i1 undef, i32 undef, i32 undef
; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne i32 [[P1]], 0
; CHECK-NEXT:    [[COND:%.*]] = select i1 [[TOBOOL]], i32 [[P1]], i32 [[P1]]
; CHECK-NEXT:    ret i32 [[COND]]
;
entry:
+8 −4
Original line number Diff line number Diff line
; RUN: opt < %s -sccp -S | \
; RUN:   grep "ret i1 false"
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -sccp -S | FileCheck %s

define i1 @foo() {
; CHECK-LABEL: @foo(
; CHECK-NEXT:    [[X:%.*]] = and i1 false, undef
; CHECK-NEXT:    ret i1 [[X]]
;
  %X = and i1 false, undef		; <i1> [#uses=1]
  ret i1 %X
}
+23 −1
Original line number Diff line number Diff line
@@ -18,7 +18,13 @@ define i101 @array() {
}

; CHECK-LABEL: @large_aggregate
; CHECK-NEXT: ret i101 undef
; CHECK-NEXT:    %B = load i101, i101* undef
; CHECK-NEXT:    %D = and i101 %B, 1
; CHECK-NEXT:    %DD = or i101 %D, 1
; CHECK-NEXT:    %G = getelementptr i101, i101* getelementptr inbounds ([6 x i101], [6 x i101]* @Y, i32 0, i32 5), i101 %DD
; CHECK-NEXT:    %L3 = load i101, i101* %G
; CHECK-NEXT:    ret i101 %L3
;
define i101 @large_aggregate() {
  %B = load i101, i101* undef
  %D = and i101 %B, 1
@@ -29,6 +35,22 @@ define i101 @large_aggregate() {
  ret i101 %L3
}

; CHECK-LABEL: define i101 @large_aggregate_2() {
; CHECK-NEXT:     %D = and i101 undef, 1
; CHECK-NEXT:     %DD = or i101 %D, 1
; CHECK-NEXT:     %G = getelementptr i101, i101* getelementptr inbounds ([6 x i101], [6 x i101]* @Y, i32 0, i32 5), i101 %DD
; CHECK-NEXT:     %L3 = load i101, i101* %G
; CHECK-NEXT:     ret i101 %L3
;
define i101 @large_aggregate_2() {
  %D = and i101 undef, 1
  %DD = or i101 %D, 1
  %F = getelementptr [6 x i101], [6 x i101]* @Y, i32 0, i32 5
  %G = getelementptr i101, i101* %F, i101 %DD
  %L3 = load i101, i101* %G
  ret i101 %L3
}

; CHECK-LABEL: @index_too_large
; CHECK-NEXT: store i101* getelementptr (i101, i101* getelementptr ([6 x i101], [6 x i101]* @Y, i32 0, i32 -1), i101 9224497936761618431), i101** undef
; CHECK-NEXT: ret void
Loading