Commit c8a14c2d authored by Sanjay Patel's avatar Sanjay Patel
Browse files

[IR] fix potential crash in Constant::isElementWiseEqual()

There's only one user of this API currently, and it seems
impossible that it would compare values with different types.

But that's not true in general, so we need to make sure the
types are the same.

As denoted by the FIXME comments, we will also crash on FP
values. That's what brought me here, but we can make that a
follow-up patch.
parent cfd366ba
Loading
Loading
Loading
Loading
+9 −4
Original line number Diff line number Diff line
@@ -280,12 +280,17 @@ bool Constant::isElementWiseEqual(Value *Y) const {
  // Are they fully identical?
  if (this == Y)
    return true;
  // They may still be identical element-wise (if they have `undef`s).
  auto *Cy = dyn_cast<Constant>(Y);
  if (!Cy)

  // The input value must be a vector constant with the same type.
  Type *Ty = getType();
  if (!isa<Constant>(Y) || !Ty->isVectorTy() || Ty != Y->getType())
    return false;

  // They may still be identical element-wise (if they have `undef`s).
  // FIXME: This crashes on FP vector constants.
  return match(ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_EQ,
                                     const_cast<Constant *>(this), Cy),
                                     const_cast<Constant *>(this),
                                     cast<Constant>(Y)),
               m_One());
}

+39 −0
Original line number Diff line number Diff line
@@ -585,5 +585,44 @@ TEST(ConstantsTest, FoldGlobalVariablePtr) {
      Instruction::And, TheConstantExpr, TheConstant)->isNullValue());
}

// Check that undefined elements in vector constants are matched
// correctly for both integer and floating-point types.

TEST(ConstantsTest, isElementWiseEqual) {
  LLVMContext Context;

  Type *Int32Ty = Type::getInt32Ty(Context);
  Constant *CU = UndefValue::get(Int32Ty);
  Constant *C1 = ConstantInt::get(Int32Ty, 1);
  Constant *C2 = ConstantInt::get(Int32Ty, 2);

  Constant *C1211 = ConstantVector::get({C1, C2, C1, C1});
  Constant *C12U1 = ConstantVector::get({C1, C2, CU, C1});
  Constant *C12U2 = ConstantVector::get({C1, C2, CU, C2});
  Constant *C12U21 = ConstantVector::get({C1, C2, CU, C2, C1});

  EXPECT_TRUE(C1211->isElementWiseEqual(C12U1));
  EXPECT_TRUE(C12U1->isElementWiseEqual(C1211));
  EXPECT_FALSE(C12U2->isElementWiseEqual(C12U1));
  EXPECT_FALSE(C12U1->isElementWiseEqual(C12U2));
  EXPECT_FALSE(C12U21->isElementWiseEqual(C12U2));

/* FIXME: This will crash.
  Type *FltTy = Type::getFloatTy(Context);
  Constant *CFU = UndefValue::get(FltTy);
  Constant *CF1 = ConstantFP::get(FltTy, 1.0);
  Constant *CF2 = ConstantFP::get(FltTy, 2.0);

  Constant *CF1211 = ConstantVector::get({CF1, CF2, CF1, CF1});
  Constant *CF12U1 = ConstantVector::get({CF1, CF2, CFU, CF1});
  Constant *CF12U2 = ConstantVector::get({CF1, CF2, CFU, CF2});

  EXPECT_TRUE(CF1211->isElementWiseEqual(CF12U1));
  EXPECT_TRUE(CF12U1->isElementWiseEqual(CF1211));
  EXPECT_FALSE(CF12U2->isElementWiseEqual(CF12U1));
  EXPECT_FALSE(CF12U1->isElementWiseEqual(CF12U2));
*/
}

}  // end anonymous namespace
}  // end namespace llvm