Commit f5d89523 authored by Filipp Zhinkin's avatar Filipp Zhinkin Committed by Sanjay Patel
Browse files

[InstCombine] Transform X == 0 ? 0 : X * Y --> X * freeze(Y)

Enabled mul folding optimization that was previously disabled
by being incorrect.
To preserve correctness, mul's operand that is not compared
with zero in select's condition is now frozen.

Related bug: https://bugs.llvm.org/show_bug.cgi?id=51286

Correctness:
https://alive2.llvm.org/ce/z/bHef7J
https://alive2.llvm.org/ce/z/QcR7sf
https://alive2.llvm.org/ce/z/vvBLzt
https://alive2.llvm.org/ce/z/jGDXgq
https://alive2.llvm.org/ce/z/3Pe8Z4
https://alive2.llvm.org/ce/z/LGga8M
https://alive2.llvm.org/ce/z/CTG5fs

Differential Revision: https://reviews.llvm.org/D108408
parent be102805
Loading
Loading
Loading
Loading
+54 −0
Original line number Diff line number Diff line
@@ -723,6 +723,58 @@ static Instruction *foldSetClearBits(SelectInst &Sel,
  return nullptr;
}

//   select (x == 0), 0, x * y --> freeze(y) * x
//   select (y == 0), 0, x * y --> freeze(x) * y
//   select (x == 0), undef, x * y --> freeze(y) * x
//   select (x == undef), 0, x * y --> freeze(y) * x
// Usage of mul instead of 0 will make the result more poisonous,
// so the operand that was not checked in the condition should be frozen.
// The latter folding is applied only when a constant compared with x is
// is a vector consisting of 0 and undefs. If a constant compared with x
// is a scalar undefined value or undefined vector then an expression
// should be already folded into a constant.
static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) {
  auto *CondVal = SI.getCondition();
  auto *TrueVal = SI.getTrueValue();
  auto *FalseVal = SI.getFalseValue();
  Value *X, *Y;
  ICmpInst::Predicate Predicate;

  // Assuming that constant compared with zero is not undef (but it may be
  // a vector with some undef elements). Otherwise (when a constant is undef)
  // the select expression should be already simplified.
  if (!match(CondVal, m_ICmp(Predicate, m_Value(X), m_Zero())) ||
      !ICmpInst::isEquality(Predicate))
    return nullptr;

  if (Predicate == ICmpInst::ICMP_NE)
    std::swap(TrueVal, FalseVal);

  // Check that TrueVal is a constant instead of matching it with m_Zero()
  // to handle the case when it is a scalar undef value or a vector containing
  // non-zero elements that are masked by undef elements in the compare
  // constant.
  auto *TrueValC = dyn_cast<Constant>(TrueVal);
  if (TrueValC == nullptr ||
      !match(FalseVal, m_c_Mul(m_Specific(X), m_Value(Y))) ||
      !isa<Instruction>(FalseVal))
    return nullptr;

  auto *ZeroC = cast<Constant>(cast<Instruction>(CondVal)->getOperand(1));
  auto *MergedC = Constant::mergeUndefsWith(TrueValC, ZeroC);
  // If X is compared with 0 then TrueVal could be either zero or undef.
  // m_Zero match vectors containing some undef elements, but for scalars
  // m_Undef should be used explicitly.
  if (!match(MergedC, m_Zero()) && !match(MergedC, m_Undef()))
    return nullptr;

  auto *FalseValI = cast<Instruction>(FalseVal);
  auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"),
                                     *FalseValI);
  IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY);
  return IC.replaceInstUsesWith(SI, FalseValI);
}

/// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b).
/// There are 8 commuted/swapped variants of this pattern.
/// TODO: Also support a - UMIN(a,b) patterns.
@@ -2930,6 +2982,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
    return Add;
  if (Instruction *Or = foldSetClearBits(SI, Builder))
    return Or;
  if (Instruction *Mul = foldSelectZeroOrMul(SI, *this))
    return Mul;

  // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
  auto *TI = dyn_cast<Instruction>(TrueVal);
+52 −37
Original line number Diff line number Diff line
@@ -2844,12 +2844,12 @@ define <2 x i1> @partial_false_undef_condval(<2 x i1> %x) {
  ret <2 x i1> %r
}

; select (x == 0), 0, x * y --> freeze(y) * x
define i32 @mul_select_eq_zero(i32 %x, i32 %y) {
; CHECK-LABEL: @mul_select_eq_zero(
; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 0
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X]], [[Y:%.*]]
; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 0, i32 [[M]]
; CHECK-NEXT:    ret i32 [[R]]
; CHECK-NEXT:    [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[Y_FR]], [[X:%.*]]
; CHECK-NEXT:    ret i32 [[M]]
;
  %c = icmp eq i32 %x, 0
  %m = mul i32 %x, %y
@@ -2857,12 +2857,12 @@ define i32 @mul_select_eq_zero(i32 %x, i32 %y) {
  ret i32 %r
}

; select (y == 0), 0, x * y --> freeze(x) * y
define i32 @mul_select_eq_zero_commute(i32 %x, i32 %y) {
; CHECK-LABEL: @mul_select_eq_zero_commute(
; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[Y:%.*]], 0
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X:%.*]], [[Y]]
; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 0, i32 [[M]]
; CHECK-NEXT:    ret i32 [[R]]
; CHECK-NEXT:    [[X_FR:%.*]] = freeze i32 [[X:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X_FR]], [[Y:%.*]]
; CHECK-NEXT:    ret i32 [[M]]
;
  %c = icmp eq i32 %y, 0
  %m = mul i32 %x, %y
@@ -2870,12 +2870,12 @@ define i32 @mul_select_eq_zero_commute(i32 %x, i32 %y) {
  ret i32 %r
}

; Check that mul's flags preserved during the transformation.
define i32 @mul_select_eq_zero_copy_flags(i32 %x, i32 %y) {
; CHECK-LABEL: @mul_select_eq_zero_copy_flags(
; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 0
; CHECK-NEXT:    [[M:%.*]] = mul nuw nsw i32 [[X]], [[Y:%.*]]
; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 0, i32 [[M]]
; CHECK-NEXT:    ret i32 [[R]]
; CHECK-NEXT:    [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul nuw nsw i32 [[Y_FR]], [[X:%.*]]
; CHECK-NEXT:    ret i32 [[M]]
;
  %c = icmp eq i32 %x, 0
  %m = mul nuw nsw i32 %x, %y
@@ -2883,25 +2883,31 @@ define i32 @mul_select_eq_zero_copy_flags(i32 %x, i32 %y) {
  ret i32 %r
}

; Check that the transformation could be applied after condition's inversion.
; select (x != 0), x * y, 0 --> freeze(y) * x
define i32 @mul_select_ne_zero(i32 %x, i32 %y) {
; CHECK-LABEL: @mul_select_ne_zero(
; CHECK-NEXT:    [[C_NOT:%.*]] = icmp eq i32 [[X:%.*]], 0
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X]], [[Y:%.*]]
; CHECK-NEXT:    [[R:%.*]] = select i1 [[C_NOT]], i32 0, i32 [[M]]
; CHECK-NEXT:    ret i32 [[R]]
; CHECK-NEXT:    [[C:%.*]] = icmp ne i32 [[X:%.*]], 0
; CHECK-NEXT:    [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[Y_FR]], [[X]]
; CHECK-NEXT:    call void @use(i1 [[C]])
; CHECK-NEXT:    ret i32 [[M]]
;
  %c = icmp ne i32 %x, 0
  %m = mul i32 %x, %y
  %r = select i1 %c, i32 %m, i32 0
  call void @use(i1 %c)
  ret i32 %r
}

; Check that if one of a select's branches returns undef then
; an expression could be folded into mul as if there was a 0 instead of undef.
; select (x == 0), undef, x * y --> freeze(y) * x
define i32 @mul_select_eq_zero_sel_undef(i32 %x, i32 %y) {
; CHECK-LABEL: @mul_select_eq_zero_sel_undef(
; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 0
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X]], [[Y:%.*]]
; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 undef, i32 [[M]]
; CHECK-NEXT:    ret i32 [[R]]
; CHECK-NEXT:    [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[Y_FR]], [[X:%.*]]
; CHECK-NEXT:    ret i32 [[M]]
;
  %c = icmp eq i32 %x, 0
  %m = mul i32 %x, %y
@@ -2909,15 +2915,16 @@ define i32 @mul_select_eq_zero_sel_undef(i32 %x, i32 %y) {
  ret i32 %r
}

; Check that the transformation is applied disregard to a number
; of expression's users.
define i32 @mul_select_eq_zero_multiple_users(i32 %x, i32 %y) {
; CHECK-LABEL: @mul_select_eq_zero_multiple_users(
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT:    [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul i32 [[Y_FR]], [[X:%.*]]
; CHECK-NEXT:    call void @use_i32(i32 [[M]])
; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X]], 0
; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 0, i32 [[M]]
; CHECK-NEXT:    call void @use_i32(i32 [[M]])
; CHECK-NEXT:    call void @use_i32(i32 [[R]])
; CHECK-NEXT:    ret i32 [[R]]
; CHECK-NEXT:    call void @use_i32(i32 [[M]])
; CHECK-NEXT:    ret i32 [[M]]
;
  %m = mul i32 %x, %y
  call void @use_i32(i32 %m)
@@ -2928,6 +2935,8 @@ define i32 @mul_select_eq_zero_multiple_users(i32 %x, i32 %y) {
  ret i32 %r
}

; Negative test: select's condition is unrelated to multiplied values,
; so the transformation should not be applied.
define i32 @mul_select_eq_zero_unrelated_condition(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: @mul_select_eq_zero_unrelated_condition(
; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[Z:%.*]], 0
@@ -2941,12 +2950,12 @@ define i32 @mul_select_eq_zero_unrelated_condition(i32 %x, i32 %y, i32 %z) {
  ret i32 %r
}

; select (<k x elt> x == 0), <k x elt> 0, <k x elt> x * y --> freeze(y) * x
define <4 x i32> @mul_select_eq_zero_vector(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: @mul_select_eq_zero_vector(
; CHECK-NEXT:    [[C:%.*]] = icmp eq <4 x i32> [[X:%.*]], zeroinitializer
; CHECK-NEXT:    [[M:%.*]] = mul <4 x i32> [[X]], [[Y:%.*]]
; CHECK-NEXT:    [[R:%.*]] = select <4 x i1> [[C]], <4 x i32> zeroinitializer, <4 x i32> [[M]]
; CHECK-NEXT:    ret <4 x i32> [[R]]
; CHECK-NEXT:    [[Y_FR:%.*]] = freeze <4 x i32> [[Y:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul <4 x i32> [[Y_FR]], [[X:%.*]]
; CHECK-NEXT:    ret <4 x i32> [[M]]
;
  %c = icmp eq <4 x i32> %x, zeroinitializer
  %m = mul <4 x i32> %x, %y
@@ -2954,12 +2963,14 @@ define <4 x i32> @mul_select_eq_zero_vector(<4 x i32> %x, <4 x i32> %y) {
  ret <4 x i32> %r
}

; Check that a select is folded into multiplication if condition's operand
; is a vector consisting of zeros and undefs.
; select (<k x elt> x == {0, undef, ...}), <k x elt> 0, <k x elt> x * y --> freeze(y) * x
define <2 x i32> @mul_select_eq_undef_vector(<2 x i32> %x, <2 x i32> %y) {
; CHECK-LABEL: @mul_select_eq_undef_vector(
; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i32> [[X:%.*]], <i32 0, i32 undef>
; CHECK-NEXT:    [[M:%.*]] = mul <2 x i32> [[X]], [[Y:%.*]]
; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> <i32 0, i32 42>, <2 x i32> [[M]]
; CHECK-NEXT:    ret <2 x i32> [[R]]
; CHECK-NEXT:    [[Y_FR:%.*]] = freeze <2 x i32> [[Y:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul <2 x i32> [[Y_FR]], [[X:%.*]]
; CHECK-NEXT:    ret <2 x i32> [[M]]
;
  %c = icmp eq <2 x i32> %x, <i32 0, i32 undef>
  %m = mul <2 x i32> %x, %y
@@ -2967,12 +2978,14 @@ define <2 x i32> @mul_select_eq_undef_vector(<2 x i32> %x, <2 x i32> %y) {
  ret <2 x i32> %r
}

; Check that a select is folded into multiplication if other select's operand
; is a vector consisting of zeros and undefs.
; select (<k x elt> x == 0), <k x elt> {0, undef, ...}, <k x elt> x * y --> freeze(y) * x
define <2 x i32> @mul_select_eq_zero_sel_undef_vector(<2 x i32> %x, <2 x i32> %y) {
; CHECK-LABEL: @mul_select_eq_zero_sel_undef_vector(
; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i32> [[X:%.*]], zeroinitializer
; CHECK-NEXT:    [[M:%.*]] = mul <2 x i32> [[X]], [[Y:%.*]]
; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> <i32 0, i32 undef>, <2 x i32> [[M]]
; CHECK-NEXT:    ret <2 x i32> [[R]]
; CHECK-NEXT:    [[Y_FR:%.*]] = freeze <2 x i32> [[Y:%.*]]
; CHECK-NEXT:    [[M:%.*]] = mul <2 x i32> [[Y_FR]], [[X:%.*]]
; CHECK-NEXT:    ret <2 x i32> [[M]]
;
  %c = icmp eq <2 x i32> %x, zeroinitializer
  %m = mul <2 x i32> %x, %y
@@ -2980,6 +2993,8 @@ define <2 x i32> @mul_select_eq_zero_sel_undef_vector(<2 x i32> %x, <2 x i32> %y
  ret <2 x i32> %r
}

; Negative test: select should not be folded into mul because
; condition's operand and select's operand do not merge into zero vector.
define <2 x i32> @mul_select_eq_undef_vector_not_merging_to_zero(<2 x i32> %x, <2 x i32> %y) {
; CHECK-LABEL: @mul_select_eq_undef_vector_not_merging_to_zero(
; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i32> [[X:%.*]], <i32 0, i32 undef>