Commit e9e1d475 authored by Wang, Pengfei's avatar Wang, Pengfei
Browse files

[X86] Refactor GetSSETypeAtOffset to fix pr51813

D105263 adds support for the _Float16 type. It introduced a bug (pr51813) that generates a <4 x half> type instead of the default double when passing a blank structure in SSE registers.

Although I doubt it may expose a bug somewhere other than D105263, it's good to avoid returning a half type when there is no half type in the arguments.

Reviewed By: LuoYuanke

Differential Revision: https://reviews.llvm.org/D109607
parent aaf00f3f
Loading
Loading
Loading
Loading
+39 −75
Original line number Diff line number Diff line
@@ -3407,52 +3407,18 @@ static bool BitsContainNoUserData(QualType Ty, unsigned StartBit,
  return false;
}

/// ContainsFloatAtOffset - Return true if the specified LLVM IR type has a
/// float member at the specified offset.  For example, {int,{float}} has a
/// float at offset 4.  It is conservatively correct for this routine to return
/// false.
static bool ContainsFloatAtOffset(llvm::Type *IRType, unsigned IROffset,
/// getFPTypeAtOffset - Return a floating point type at the specified offset.
static llvm::Type *getFPTypeAtOffset(llvm::Type *IRType, unsigned IROffset,
                                     const llvm::DataLayout &TD) {
  // Base case if we find a float.
  if (IROffset == 0 && IRType->isFloatTy())
    return true;

  // If this is a struct, recurse into the field at the specified offset.
  if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
    const llvm::StructLayout *SL = TD.getStructLayout(STy);
    unsigned Elt = SL->getElementContainingOffset(IROffset);
    IROffset -= SL->getElementOffset(Elt);
    return ContainsFloatAtOffset(STy->getElementType(Elt), IROffset, TD);
  }

  // If this is an array, recurse into the field at the specified offset.
  if (llvm::ArrayType *ATy = dyn_cast<llvm::ArrayType>(IRType)) {
    llvm::Type *EltTy = ATy->getElementType();
    unsigned EltSize = TD.getTypeAllocSize(EltTy);
    IROffset -= IROffset/EltSize*EltSize;
    return ContainsFloatAtOffset(EltTy, IROffset, TD);
  }

  return false;
}

/// ContainsHalfAtOffset - Return true if the specified LLVM IR type has a
/// half member at the specified offset.  For example, {int,{half}} has a
/// half at offset 4.  It is conservatively correct for this routine to return
/// false.
/// FIXME: Merge with ContainsFloatAtOffset
static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
                                 const llvm::DataLayout &TD) {
  // Base case if we find a float.
  if (IROffset == 0 && IRType->isHalfTy())
    return true;
  if (IROffset == 0 && IRType->isFloatingPointTy())
    return IRType;

  // If this is a struct, recurse into the field at the specified offset.
  if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
    const llvm::StructLayout *SL = TD.getStructLayout(STy);
    unsigned Elt = SL->getElementContainingOffset(IROffset);
    IROffset -= SL->getElementOffset(Elt);
    return ContainsHalfAtOffset(STy->getElementType(Elt), IROffset, TD);
    return getFPTypeAtOffset(STy->getElementType(Elt), IROffset, TD);
  }

  // If this is an array, recurse into the field at the specified offset.
@@ -3460,10 +3426,10 @@ static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
    llvm::Type *EltTy = ATy->getElementType();
    unsigned EltSize = TD.getTypeAllocSize(EltTy);
    IROffset -= IROffset / EltSize * EltSize;
    return ContainsHalfAtOffset(EltTy, IROffset, TD);
    return getFPTypeAtOffset(EltTy, IROffset, TD);
  }

  return false;
  return nullptr;
}

/// GetSSETypeAtOffset - Return a type that will be passed by the backend in the
@@ -3471,39 +3437,37 @@ static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
llvm::Type *X86_64ABIInfo::
GetSSETypeAtOffset(llvm::Type *IRType, unsigned IROffset,
                   QualType SourceTy, unsigned SourceOffset) const {
  // If the high 32 bits are not used, we have three choices. Single half,
  // single float or two halfs.
  if (BitsContainNoUserData(SourceTy, SourceOffset * 8 + 32,
                            SourceOffset * 8 + 64, getContext())) {
    if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()))
      return llvm::Type::getFloatTy(getVMContext());
    if (ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()))
      return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()),
                                        2);

    return llvm::Type::getHalfTy(getVMContext());
  }

  // We want to pass as <2 x float> if the LLVM IR type contains a float at
  // offset+0 and offset+4. Walk the LLVM IR type to find out if this is the
  // case.
  if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) &&
      ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
    return llvm::FixedVectorType::get(llvm::Type::getFloatTy(getVMContext()),
                                      2);

  // We want to pass as <4 x half> if the LLVM IR type contains a half at
  // offset+0, +2, +4. Walk the LLVM IR type to find out if this is the case.
  if (ContainsHalfAtOffset(IRType, IROffset, getDataLayout()) &&
      ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()) &&
      ContainsHalfAtOffset(IRType, IROffset + 4, getDataLayout()))
    return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);
  const llvm::DataLayout &TD = getDataLayout();
  llvm::Type *T0 = getFPTypeAtOffset(IRType, IROffset, TD);
  if (!T0 || T0->isDoubleTy())
    return llvm::Type::getDoubleTy(getVMContext());

  // Get the adjacent FP type.
  llvm::Type *T1 =
      getFPTypeAtOffset(IRType, IROffset + TD.getTypeAllocSize(T0), TD);
  if (T1 == nullptr) {
    // Check if IRType is a half + float. float type will be in IROffset+4 due
    // to its alignment.
    if (T0->isHalfTy())
      T1 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
    // If we can't get a second FP type, return a simple half or float.
    // avx512fp16-abi.c:pr51813_2 shows it works to return float for
    // {float, i8} too.
    if (T1 == nullptr)
      return T0;
  }

  if (T0->isFloatTy() && T1->isFloatTy())
    return llvm::FixedVectorType::get(T0, 2);

  if (T0->isHalfTy() && T1->isHalfTy()) {
    llvm::Type *T2 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
    if (T2 == nullptr)
      return llvm::FixedVectorType::get(T0, 2);
    return llvm::FixedVectorType::get(T0, 4);
  }

  // We want to pass as <4 x half> if the LLVM IR type contains a mix of float
  // and half.
  // FIXME: Do we have a better representation for the mixed type?
  if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) ||
      ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
  if (T0->isHalfTy() || T1->isHalfTy())
    return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);

  return llvm::Type::getDoubleTy(getVMContext());
+61 −11
Original line number Diff line number Diff line
// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 < %s | FileCheck %s --check-prefixes=CHECK
// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-C
// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 -x c++ -std=c++11 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-CPP

struct half1 {
  _Float16 a;
};

struct half1 h1(_Float16 a) {
  // CHECK: define{{.*}}half @h1
  // CHECK: define{{.*}}half @
  struct half1 x;
  x.a = a;
  return x;
@@ -17,7 +18,7 @@ struct half2 {
};

struct half2 h2(_Float16 a, _Float16 b) {
  // CHECK: define{{.*}}<2 x half> @h2
  // CHECK: define{{.*}}<2 x half> @
  struct half2 x;
  x.a = a;
  x.b = b;
@@ -31,7 +32,7 @@ struct half3 {
};

struct half3 h3(_Float16 a, _Float16 b, _Float16 c) {
  // CHECK: define{{.*}}<4 x half> @h3
  // CHECK: define{{.*}}<4 x half> @
  struct half3 x;
  x.a = a;
  x.b = b;
@@ -47,7 +48,7 @@ struct half4 {
};

struct half4 h4(_Float16 a, _Float16 b, _Float16 c, _Float16 d) {
  // CHECK: define{{.*}}<4 x half> @h4
  // CHECK: define{{.*}}<4 x half> @
  struct half4 x;
  x.a = a;
  x.b = b;
@@ -62,7 +63,7 @@ struct floathalf {
};

struct floathalf fh(float a, _Float16 b) {
  // CHECK: define{{.*}}<4 x half> @fh
  // CHECK: define{{.*}}<4 x half> @
  struct floathalf x;
  x.a = a;
  x.b = b;
@@ -76,7 +77,7 @@ struct floathalf2 {
};

struct floathalf2 fh2(float a, _Float16 b, _Float16 c) {
  // CHECK: define{{.*}}<4 x half> @fh2
  // CHECK: define{{.*}}<4 x half> @
  struct floathalf2 x;
  x.a = a;
  x.b = b;
@@ -90,7 +91,7 @@ struct halffloat {
};

struct halffloat hf(_Float16 a, float b) {
  // CHECK: define{{.*}}<4 x half> @hf
  // CHECK: define{{.*}}<4 x half> @
  struct halffloat x;
  x.a = a;
  x.b = b;
@@ -104,7 +105,7 @@ struct half2float {
};

struct half2float h2f(_Float16 a, _Float16 b, float c) {
  // CHECK: define{{.*}}<4 x half> @h2f
  // CHECK: define{{.*}}<4 x half> @
  struct half2float x;
  x.a = a;
  x.b = b;
@@ -120,7 +121,7 @@ struct floathalf3 {
};

struct floathalf3 fh3(float a, _Float16 b, _Float16 c, _Float16 d) {
  // CHECK: define{{.*}}{ <4 x half>, half } @fh3
  // CHECK: define{{.*}}{ <4 x half>, half } @
  struct floathalf3 x;
  x.a = a;
  x.b = b;
@@ -138,7 +139,7 @@ struct half5 {
};

struct half5 h5(_Float16 a, _Float16 b, _Float16 c, _Float16 d, _Float16 e) {
  // CHECK: define{{.*}}{ <4 x half>, half } @h5
  // CHECK: define{{.*}}{ <4 x half>, half } @
  struct half5 x;
  x.a = a;
  x.b = b;
@@ -147,3 +148,52 @@ struct half5 h5(_Float16 a, _Float16 b, _Float16 c, _Float16 d, _Float16 e) {
  x.e = e;
  return x;
}

// Aggregate with an empty struct member placed before two floats; used as the
// parameter type of pr51813 below.
struct float2 {
  struct {} s;
  float a;
  float b;
};

// Regression test for pr51813: takes a float2 by value and returns its first
// float. The expectations below require the first parameter to be lowered as
// <2 x float> when compiled as C, and the parameters to be lowered as a double
// followed by a float when compiled as C++ (presumably because the empty
// struct member occupies storage only in C++ -- confirm).
float pr51813(struct float2 s) {
  // CHECK-C: define{{.*}} @pr51813(<2 x float>
  // CHECK-CPP: define{{.*}} @_Z7pr518136float2(double {{.*}}, float
  return s.a;
}

// Aggregate with an empty struct member placed between two floats; used as
// the parameter type of pr51813_2 below.
struct float3 {
  float a;
  struct {} s;
  float b;
};

// Second pr51813 regression case: takes a float3 by value and returns its
// first float. The expectations below require the first parameter to be
// lowered as <2 x float> when compiled as C, and the parameters to be lowered
// as a double followed by a float when compiled as C++.
float pr51813_2(struct float3 s) {
  // CHECK-C: define{{.*}} @pr51813_2(<2 x float>
  // CHECK-CPP: define{{.*}} @_Z9pr51813_26float3(double {{.*}}, float
  return s.a;
}

// Aggregate with an empty struct member placed before two _Float16 members;
// used as the parameter type of sf2 (and, as written, of fs2) below.
struct shalf2 {
  struct {} s;
  _Float16 a;
  _Float16 b;
};

// Takes a shalf2 by value and returns its first _Float16 member. The
// expectations below require the first parameter to be lowered as <2 x half>
// when compiled as C and as a double when compiled as C++.
_Float16 sf2(struct shalf2 s) {
  // CHECK-C: define{{.*}} @sf2(<2 x half>
  // CHECK-CPP: define{{.*}} @_Z3sf26shalf2(double {{.*}}
  return s.a;
};

// Aggregate with two _Float16 members, each followed by an empty struct
// member.
// NOTE(review): no function visible in this file takes this type; fs2 below
// is declared with struct shalf2 instead -- confirm whether that is intended.
struct halfs2 {
  _Float16 a;
  struct {} s1;
  _Float16 b;
  struct {} s2;
};

// Takes a shalf2 by value and returns its first _Float16 member. The
// expectations below require the first parameter to be lowered as <2 x half>
// when compiled as C and as a double when compiled as C++.
// NOTE(review): the parameter type is struct shalf2, the same type sf2 takes,
// so this duplicates sf2 and struct halfs2 is left unexercised. Presumably
// struct halfs2 was intended; switching to it would also require updating the
// mangled name (and possibly the lowered types) in the expectations below --
// confirm against actual compiler output before changing.
_Float16 fs2(struct shalf2 s) {
  // CHECK-C: define{{.*}} @fs2(<2 x half>
  // CHECK-CPP: define{{.*}} @_Z3fs26shalf2(double {{.*}}
  return s.a;
};