Commit 9c87c558 authored by Christopher Tetreault
Browse files

[SVE] Make cstfp_pred_ty and cst_pred_ty work with scalable splats

Reviewers: efriedma, lebedev.ri, fhahn, c-rhodes, david-arm

Reviewed By: efriedma, david-arm

Subscribers: tschuett, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D83001
parent fcf0f75a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -308,6 +308,7 @@ public:
  /// Return true if Ty is big enough to represent V.
  static bool isValueValidForType(Type *Ty, const APFloat &V);
  /// Return a reference to the APFloat value stored in this constant.
  const APFloat &getValueAPF() const { return Val; }
  /// Return a reference to the stored APFloat value; identical to
  /// getValueAPF(). This spelling matches the getValue() call that the generic
  /// cstval_pred_ty pattern matcher makes on its ConstantVal type.
  const APFloat &getValue() const { return Val; }

  /// Return true if the value is positive or negative zero.
  bool isZero() const { return Val.isZero(); }
+25 −49
Original line number Diff line number Diff line
@@ -262,17 +262,23 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
  return constantint_match<Val>();
}

/// This helper class is used to match scalar and fixed width vector integer
/// constants that satisfy a specified predicate.
/// For vector constants, undefined elements are ignored.
template <typename Predicate> struct cst_pred_ty : public Predicate {
/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, undefined elements are ignored.
template <typename Predicate, typename ConstantVal>
struct cstval_pred_ty : public Predicate {
  template <typename ITy> bool match(ITy *V) {
    if (const auto *CI = dyn_cast<ConstantInt>(V))
      return this->isValue(CI->getValue());
    if (const auto *FVTy = dyn_cast<FixedVectorType>(V->getType())) {
    if (const auto *CV = dyn_cast<ConstantVal>(V))
      return this->isValue(CV->getValue());
    if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
      if (const auto *C = dyn_cast<Constant>(V)) {
        if (const auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
          return this->isValue(CI->getValue());
        if (const auto *CV = dyn_cast_or_null<ConstantVal>(C->getSplatValue()))
          return this->isValue(CV->getValue());

        // Number of elements of a scalable vector unknown at compile time
        auto *FVTy = dyn_cast<FixedVectorType>(VTy);
        if (!FVTy)
          return false;

        // Non-splat vector constant: check each element for a match.
        unsigned NumElts = FVTy->getNumElements();
@@ -284,8 +290,8 @@ template <typename Predicate> struct cst_pred_ty : public Predicate {
            return false;
          if (isa<UndefValue>(Elt))
            continue;
          auto *CI = dyn_cast<ConstantInt>(Elt);
          if (!CI || !this->isValue(CI->getValue()))
          auto *CV = dyn_cast<ConstantVal>(Elt);
          if (!CV || !this->isValue(CV->getValue()))
            return false;
          HasNonUndefElements = true;
        }
@@ -296,6 +302,14 @@ template <typename Predicate> struct cst_pred_ty : public Predicate {
  }
};

/// Alias template that fixes the ConstantVal parameter of cstval_pred_ty to
/// ConstantInt, so only the predicate has to be supplied. The predicate's
/// isValue() is called with the result of ConstantInt::getValue().
template <typename Predicate>
using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;

/// Alias template that fixes the ConstantVal parameter of cstval_pred_ty to
/// ConstantFP, so only the predicate has to be supplied. The predicate's
/// isValue() is called with the result of ConstantFP::getValue().
template <typename Predicate>
using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;

/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
template <typename Predicate> struct api_pred_ty : public Predicate {
@@ -321,44 +335,6 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
  }
};

/// Matcher for floating-point constants whose value satisfies \p Predicate
/// (via Predicate::isValue(const APFloat &)).
///
/// A value matches when it is:
///  - a scalar ConstantFP accepted by the predicate, or
///  - a vector constant whose splat value is a ConstantFP accepted by the
///    predicate, or
///  - a non-scalable, non-splat vector constant in which every element that
///    is not undef is a ConstantFP accepted by the predicate, and at least one
///    such element exists.
template <typename Predicate> struct cstfp_pred_ty : public Predicate {
  template <typename ITy> bool match(ITy *V) {
    // Scalar constant: ask the predicate directly.
    if (const auto *ScalarFP = dyn_cast<ConstantFP>(V))
      return this->isValue(ScalarFP->getValueAPF());

    // Anything else has to be a constant of vector type.
    if (!V->getType()->isVectorTy())
      return false;
    const auto *VecC = dyn_cast<Constant>(V);
    if (!VecC)
      return false;

    // Splat vector: one representative element decides the match.
    if (const auto *SplatFP =
            dyn_cast_or_null<ConstantFP>(VecC->getSplatValue()))
      return this->isValue(SplatFP->getValueAPF());

    // The element count of a scalable vector is not known at compile time,
    // so its elements cannot be inspected one by one.
    if (isa<ScalableVectorType>(V->getType()))
      return false;

    // Non-splat vector constant: every defined element must satisfy the
    // predicate; undef elements are skipped.
    const unsigned NumElts =
        cast<VectorType>(V->getType())->getNumElements();
    assert(NumElts != 0 && "Constant vector with no elements?");
    bool SawDefinedElt = false;
    for (unsigned Idx = 0; Idx < NumElts; ++Idx) {
      Constant *Elt = VecC->getAggregateElement(Idx);
      if (!Elt)
        return false;
      if (isa<UndefValue>(Elt))
        continue;
      const auto *EltFP = dyn_cast<ConstantFP>(Elt);
      if (!EltFP || !this->isValue(EltFP->getValueAPF()))
        return false;
      SawDefinedElt = true;
    }
    // A vector made up only of undef elements does not match.
    return SawDefinedElt;
  }
};

///////////////////////////////////////////////////////////////////////////////
//
// Encapsulate constant value queries for use in templated predicate matchers.
+9 −0
Original line number Diff line number Diff line
@@ -1164,3 +1164,12 @@ define double @fmul_sqrt_select(double %x, i1 %c) {
  %mul = fmul fast double %sqr, %sel
  ret double %mul
}

; With the `fast` flag on the fmul, z * splat(0.0) is expected to simplify to
; splat(0.0) (zeroinitializer), even for scalable vectors.
; The scalable splat operand is written as a shufflevector of an insertelement
; constant expression with an all-zero shuffle mask.
define <vscale x 2 x float> @mul_scalable_splat_zero(<vscale x 2 x float> %z) {
; CHECK-LABEL: @mul_scalable_splat_zero(
; CHECK-NEXT:    ret <vscale x 2 x float> zeroinitializer
  %shuf = shufflevector <vscale x 2 x float> insertelement (<vscale x 2 x float> undef, float 0.0, i32 0), <vscale x 2 x float> undef, <vscale x 2 x i32> zeroinitializer
  %t3 = fmul fast <vscale x 2 x float> %shuf, %z
  ret <vscale x 2 x float> %t3
}
+9 −0
Original line number Diff line number Diff line
@@ -857,3 +857,12 @@ define <4 x i32> @combine_mul_nabs_v4i32(<4 x i32> %0) {
  %m = mul <4 x i32> %r, %r
  ret <4 x i32> %m
}

; Integer multiply: z * splat(0) is expected to simplify to splat(0)
; (zeroinitializer), even for scalable vectors.
; The scalable splat operand is written as a shufflevector of an insertelement
; constant expression with an all-zero shuffle mask.
define <vscale x 2 x i64> @mul_scalable_splat_zero(<vscale x 2 x i64> %z) {
; CHECK-LABEL: @mul_scalable_splat_zero(
; CHECK-NEXT:    ret <vscale x 2 x i64> zeroinitializer
  %shuf = shufflevector <vscale x 2 x i64> insertelement (<vscale x 2 x i64> undef, i64 0, i32 0), <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
  %t3 = mul <vscale x 2 x i64> %shuf, %z
  ret <vscale x 2 x i64> %t3
}
+177 −0
Original line number Diff line number Diff line
@@ -1325,6 +1325,183 @@ TEST_F(PatternMatchTest, IntrinsicMatcher) {
                            m_SpecificInt(10))));
}

namespace {

struct is_unsigned_zero_pred {
  bool isValue(const APInt &C) { return C.isNullValue(); }
};

struct is_float_zero_pred {
  bool isValue(const APFloat &C) { return C.isZero(); }
};

/// Test predicate that accepts every value it is given.
/// \tparam ValueT type of the value passed to isValue(); the value itself is
/// never inspected.
template <class ValueT> struct always_true_pred {
  bool isValue(const ValueT & /*Ignored*/) { return true; }
};

/// Test predicate that rejects every value it is given.
/// \tparam ValueT type of the value passed to isValue(); the value itself is
/// never inspected.
template <class ValueT> struct always_false_pred {
  bool isValue(const ValueT & /*Ignored*/) { return false; }
};

struct is_unsigned_max_pred {
  bool isValue(const APInt &C) { return C.isMaxValue(); }
};

struct is_float_nan_pred {
  bool isValue(const APFloat &C) { return C.isNaN(); }
};

} // namespace

// Exercises cst_pred_ty (integer) and cstfp_pred_ty (floating point) with a
// set of simple predicates against:
//   - scalar constants,
//   - splat vectors built with both a fixed and a scalable element count,
//   - non-splat fixed vectors, with and without undef elements.
TEST_F(PatternMatchTest, ConstantPredicateType) {

  // Scalar integer
  // Three 32-bit values: all ones, zero, and a value that is neither.
  APInt U32Max = APInt::getAllOnesValue(32);
  APInt U32Zero = APInt::getNullValue(32);
  APInt U32DeadBeef(32, 0xDEADBEEF);

  Type *U32Ty = Type::getInt32Ty(Ctx);

  Constant *CU32Max = Constant::getIntegerValue(U32Ty, U32Max);
  Constant *CU32Zero = Constant::getIntegerValue(U32Ty, U32Zero);
  Constant *CU32DeadBeef = Constant::getIntegerValue(U32Ty, U32DeadBeef);

  // Each constant is checked against all four integer predicates.
  EXPECT_TRUE(match(CU32Max, cst_pred_ty<is_unsigned_max_pred>()));
  EXPECT_FALSE(match(CU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
  EXPECT_TRUE(match(CU32Max, cst_pred_ty<always_true_pred<APInt>>()));
  EXPECT_FALSE(match(CU32Max, cst_pred_ty<always_false_pred<APInt>>()));

  EXPECT_FALSE(match(CU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
  EXPECT_TRUE(match(CU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
  EXPECT_TRUE(match(CU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
  EXPECT_FALSE(match(CU32Zero, cst_pred_ty<always_false_pred<APInt>>()));

  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
  EXPECT_TRUE(match(CU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));

  // Scalar float
  // Three single-precision values: NaN, zero, and an ordinary finite value.
  APFloat F32NaN = APFloat::getNaN(APFloat::IEEEsingle());
  APFloat F32Zero = APFloat::getZero(APFloat::IEEEsingle());
  APFloat F32Pi(3.14f);

  Type *F32Ty = Type::getFloatTy(Ctx);

  Constant *CF32NaN = ConstantFP::get(F32Ty, F32NaN);
  Constant *CF32Zero = ConstantFP::get(F32Ty, F32Zero);
  Constant *CF32Pi = ConstantFP::get(F32Ty, F32Pi);

  // Each constant is checked against all four floating-point predicates.
  EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
  EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
  EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
  EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));

  EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
  EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
  EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
  EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));

  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
  EXPECT_TRUE(match(CF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));

  // Element counts used to build the splats below: a fixed count of 4 and its
  // scalable counterpart (second constructor argument set to true).
  ElementCount FixedEC(4, false);
  ElementCount ScalableEC(4, true);

  // Vector splat
  // A splat is expected to match exactly when its scalar element matches, for
  // both the fixed and the scalable element count.

  for (auto EC : {FixedEC, ScalableEC}) {
    // integer

    Constant *CSplatU32Max = ConstantVector::getSplat(EC, CU32Max);
    Constant *CSplatU32Zero = ConstantVector::getSplat(EC, CU32Zero);
    Constant *CSplatU32DeadBeef = ConstantVector::getSplat(EC, CU32DeadBeef);

    EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<is_unsigned_max_pred>()));
    EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
    EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<always_true_pred<APInt>>()));
    EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<always_false_pred<APInt>>()));

    EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
    EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
    EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
    EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<always_false_pred<APInt>>()));

    EXPECT_FALSE(match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
    EXPECT_FALSE(
        match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
    EXPECT_TRUE(
        match(CSplatU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
    EXPECT_FALSE(
        match(CSplatU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));

    // float

    Constant *CSplatF32NaN = ConstantVector::getSplat(EC, CF32NaN);
    Constant *CSplatF32Zero = ConstantVector::getSplat(EC, CF32Zero);
    Constant *CSplatF32Pi = ConstantVector::getSplat(EC, CF32Pi);

    EXPECT_TRUE(match(CSplatF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
    EXPECT_FALSE(match(CSplatF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
    EXPECT_TRUE(
        match(CSplatF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
    EXPECT_FALSE(
        match(CSplatF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));

    EXPECT_FALSE(match(CSplatF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
    EXPECT_TRUE(match(CSplatF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
    EXPECT_TRUE(
        match(CSplatF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
    EXPECT_FALSE(
        match(CSplatF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));

    EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
    EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
    EXPECT_TRUE(match(CSplatF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
    EXPECT_FALSE(
        match(CSplatF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
  }

  // Int arbitrary vector
  // CMixedU32 holds three different values, so only the always-true predicate
  // is expected to match. CU32MaxWithUndef has a single defined element
  // (all ones) surrounded by undef elements.

  Constant *CMixedU32 = ConstantVector::get({CU32Max, CU32Zero, CU32DeadBeef});
  Constant *CU32Undef = UndefValue::get(U32Ty);
  Constant *CU32MaxWithUndef =
      ConstantVector::get({CU32Undef, CU32Max, CU32Undef});

  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_max_pred>()));
  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_zero_pred>()));
  EXPECT_TRUE(match(CMixedU32, cst_pred_ty<always_true_pred<APInt>>()));
  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<always_false_pred<APInt>>()));

  // Undef elements are expected to be ignored, so the vector behaves like its
  // one defined element.
  EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_max_pred>()));
  EXPECT_FALSE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_zero_pred>()));
  EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<always_true_pred<APInt>>()));
  EXPECT_FALSE(
      match(CU32MaxWithUndef, cst_pred_ty<always_false_pred<APInt>>()));

  // Float arbitrary vector
  // Same two shapes as the integer case: a vector of three different values,
  // and a single defined element (NaN) surrounded by undef elements.

  Constant *CMixedF32 = ConstantVector::get({CF32NaN, CF32Zero, CF32Pi});
  Constant *CF32Undef = UndefValue::get(F32Ty);
  Constant *CF32NaNWithUndef =
      ConstantVector::get({CF32Undef, CF32NaN, CF32Undef});

  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_nan_pred>()));
  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_zero_pred>()));
  EXPECT_TRUE(match(CMixedF32, cstfp_pred_ty<always_true_pred<APFloat>>()));
  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<always_false_pred<APFloat>>()));

  EXPECT_TRUE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_nan_pred>()));
  EXPECT_FALSE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_zero_pred>()));
  EXPECT_TRUE(
      match(CF32NaNWithUndef, cstfp_pred_ty<always_true_pred<APFloat>>()));
  EXPECT_FALSE(
      match(CF32NaNWithUndef, cstfp_pred_ty<always_false_pred<APFloat>>()));
}

/// Fixture template derived from PatternMatchTest, parameterized on T; it adds
/// no members of its own.
template <class T> struct MutableConstTest : public PatternMatchTest {};

typedef ::testing::Types<std::tuple<Value*, Instruction*>,