Unverified Commit 81751905 authored by Michal Paszkowski's avatar Michal Paszkowski Committed by GitHub
Browse files

[SPIR-V] Emit proper pointer type for OpenCL kernel arguments (#67726)

parent b858309d
Loading
Loading
Loading
Loading
+52 −55
Original line number Diff line number Diff line
@@ -2010,60 +2010,6 @@ static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
  llvm_unreachable("Unable to recognize type!");
}

static const TargetExtType *parseToTargetExtType(const Type *OpaqueType,
                                                 MachineIRBuilder &MIRBuilder) {
  assert(isSpecialOpaqueType(OpaqueType) &&
         "Not a SPIR-V/OpenCL special opaque type!");
  assert(!OpaqueType->isTargetExtTy() &&
         "This already is SPIR-V/OpenCL TargetExtType!");

  StringRef NameWithParameters = OpaqueType->getStructName();

  // Pointers-to-opaque-structs representing OpenCL types are first translated
  // to equivalent SPIR-V types. OpenCL builtin type names should have the
  // following format: e.g. %opencl.event_t
  if (NameWithParameters.startswith("opencl.")) {
    const SPIRV::OpenCLType *OCLTypeRecord =
        SPIRV::lookupOpenCLType(NameWithParameters);
    if (!OCLTypeRecord)
      report_fatal_error("Missing TableGen record for OpenCL type: " +
                         NameWithParameters);
    NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
    // Continue with the SPIR-V builtin type...
  }

  // Names of the opaque structs representing a SPIR-V builtins without
  // parameters should have the following format: e.g. %spirv.Event
  assert(NameWithParameters.startswith("spirv.") &&
         "Unknown builtin opaque type!");

  // Parameterized SPIR-V builtins names follow this format:
  // e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
  if (NameWithParameters.find('_') == std::string::npos)
    return TargetExtType::get(OpaqueType->getContext(), NameWithParameters);

  SmallVector<StringRef> Parameters;
  unsigned BaseNameLength = NameWithParameters.find('_') - 1;
  SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");

  SmallVector<Type *, 1> TypeParameters;
  bool HasTypeParameter = !isDigit(Parameters[0][0]);
  if (HasTypeParameter)
    TypeParameters.push_back(parseTypeString(
        Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
  SmallVector<unsigned> IntParameters;
  for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
    unsigned IntParameter = 0;
    bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
    assert(ValidLiteral &&
           "Invalid format of SPIR-V builtin parameter literal!");
    IntParameters.push_back(IntParameter);
  }
  return TargetExtType::get(OpaqueType->getContext(),
                            NameWithParameters.substr(0, BaseNameLength),
                            TypeParameters, IntParameters);
}

//===----------------------------------------------------------------------===//
// Implementation functions for builtin types.
//===----------------------------------------------------------------------===//
@@ -2127,6 +2073,56 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
}

namespace SPIRV {
const TargetExtType *
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
                                    MachineIRBuilder &MIRBuilder) {
  StringRef NameWithParameters = TypeName;

  // Pointers-to-opaque-structs representing OpenCL types are first translated
  // to equivalent SPIR-V types. OpenCL builtin type names should have the
  // following format: e.g. %opencl.event_t
  if (NameWithParameters.startswith("opencl.")) {
    const SPIRV::OpenCLType *OCLTypeRecord =
        SPIRV::lookupOpenCLType(NameWithParameters);
    if (!OCLTypeRecord)
      report_fatal_error("Missing TableGen record for OpenCL type: " +
                         NameWithParameters);
    NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
    // Continue with the SPIR-V builtin type...
  }

  // Names of the opaque structs representing a SPIR-V builtins without
  // parameters should have the following format: e.g. %spirv.Event
  assert(NameWithParameters.startswith("spirv.") &&
         "Unknown builtin opaque type!");

  // Parameterized SPIR-V builtins names follow this format:
  // e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
  if (NameWithParameters.find('_') == std::string::npos)
    return TargetExtType::get(MIRBuilder.getContext(), NameWithParameters);

  SmallVector<StringRef> Parameters;
  unsigned BaseNameLength = NameWithParameters.find('_') - 1;
  SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");

  SmallVector<Type *, 1> TypeParameters;
  bool HasTypeParameter = !isDigit(Parameters[0][0]);
  if (HasTypeParameter)
    TypeParameters.push_back(parseTypeString(
        Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
  SmallVector<unsigned> IntParameters;
  for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
    unsigned IntParameter = 0;
    bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
    assert(ValidLiteral &&
           "Invalid format of SPIR-V builtin parameter literal!");
    IntParameters.push_back(IntParameter);
  }
  return TargetExtType::get(MIRBuilder.getContext(),
                            NameWithParameters.substr(0, BaseNameLength),
                            TypeParameters, IntParameters);
}

SPIRVType *lowerBuiltinType(const Type *OpaqueType,
                            SPIRV::AccessQualifier::AccessQualifier AccessQual,
                            MachineIRBuilder &MIRBuilder,
@@ -2141,7 +2137,8 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
  // will be removed in the future release of LLVM.
  const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
  if (!BuiltinType)
    BuiltinType = parseToTargetExtType(OpaqueType, MIRBuilder);
    BuiltinType = parseBuiltinTypeNameToTargetExtType(
        OpaqueType->getStructName().str(), MIRBuilder);

  unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();

+12 −0
Original line number Diff line number Diff line
@@ -37,6 +37,18 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                 const Register OrigRet, const Type *OrigRetTy,
                                 const SmallVectorImpl<Register> &Args,
                                 SPIRVGlobalRegistry *GR);

/// Translates a string representing a SPIR-V or OpenCL builtin type to a
/// TargetExtType that can be further lowered with lowerBuiltinType().
///
/// \return A TargetExtType representing the builtin SPIR-V type.
///
/// \p TypeName is the full string representation of the SPIR-V or OpenCL
/// builtin type.
const TargetExtType *
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
                                    MachineIRBuilder &MIRBuilder);

/// Handles the translation of the provided special opaque/builtin type \p Type
/// to SPIR-V type. Generates the corresponding machine instructions for the
/// target type or gets the already existing OpType<...> register from the
+28 −15
Original line number Diff line number Diff line
@@ -194,23 +194,38 @@ getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
  return {};
}

static Type *getArgType(const Function &F, unsigned ArgIdx) {
static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
                                  SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder &MIRBuilder) {
  // Read argument's access qualifier from metadata or default.
  SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
      getArgAccessQual(F, ArgIdx);

  Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);

  // In case of non-kernel SPIR-V function or already TargetExtType, use the
  // original IR type.
  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
      isSpecialOpaqueType(OriginalArgType))
    return OriginalArgType;
    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

  MDString *MDKernelArgType =
      getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
  if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
    return OriginalArgType;
  if (!MDKernelArgType || (MDKernelArgType->getString().ends_with("*") &&
                           MDKernelArgType->getString().ends_with("_t")))
    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

  std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
  Type *ExistingOpaqueType =
      StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
  return ExistingOpaqueType
             ? ExistingOpaqueType
             : StructType::create(F.getContext(), KernelArgTypeStr);
  if (MDKernelArgType->getString().ends_with("*"))
    return GR->getOrCreateSPIRVTypeByName(
        MDKernelArgType->getString(), MIRBuilder,
        addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));

  if (MDKernelArgType->getString().ends_with("_t"))
    return GR->getOrCreateSPIRVTypeByName(
        "opencl." + MDKernelArgType->getString().str(), MIRBuilder,
        SPIRV::StorageClass::Function, ArgAccessQual);

  llvm_unreachable("Unable to recognize argument type name.");
}

static bool isEntryPoint(const Function &F) {
@@ -262,10 +277,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
      // TODO: handle the case of multiple registers.
      if (VRegs[i].size() > 1)
        return false;
      SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
          getArgAccessQual(F, i);
      auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
                                           MIRBuilder, ArgAccessQual);
      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
      GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
      ArgTypeVRegs.push_back(SpirvTy);

      if (Arg.hasName())
+55 −13
Original line number Diff line number Diff line
@@ -956,40 +956,82 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
}

// TODO: maybe use tablegen to implement this.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
                                                MachineIRBuilder &MIRBuilder) {
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
    StringRef TypeStr, MachineIRBuilder &MIRBuilder,
    SPIRV::StorageClass::StorageClass SC,
    SPIRV::AccessQualifier::AccessQualifier AQ) {
  unsigned VecElts = 0;
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  // Parse strings representing either a SPIR-V or OpenCL builtin type.
  if (hasBuiltinTypePrefix(TypeStr))
    return getOrCreateSPIRVType(
        SPIRV::parseBuiltinTypeNameToTargetExtType(TypeStr.str(), MIRBuilder),
        MIRBuilder, AQ);

  // Parse type name in either "typeN" or "type vector[N]" format, where
  // N is the number of elements of the vector.
  Type *Type;
  Type *Ty;

  if (TypeStr.starts_with("atomic_"))
    TypeStr = TypeStr.substr(strlen("atomic_"));

  if (TypeStr.startswith("void")) {
    Type = Type::getVoidTy(Ctx);
    Ty = Type::getVoidTy(Ctx);
    TypeStr = TypeStr.substr(strlen("void"));
  } else if (TypeStr.startswith("bool")) {
    Ty = Type::getIntNTy(Ctx, 1);
    TypeStr = TypeStr.substr(strlen("bool"));
  } else if (TypeStr.startswith("char") || TypeStr.startswith("uchar")) {
    Ty = Type::getInt8Ty(Ctx);
    TypeStr = TypeStr.startswith("char") ? TypeStr.substr(strlen("char"))
                                         : TypeStr.substr(strlen("uchar"));
  } else if (TypeStr.startswith("short") || TypeStr.startswith("ushort")) {
    Ty = Type::getInt16Ty(Ctx);
    TypeStr = TypeStr.startswith("short") ? TypeStr.substr(strlen("short"))
                                          : TypeStr.substr(strlen("ushort"));
  } else if (TypeStr.startswith("int") || TypeStr.startswith("uint")) {
    Type = Type::getInt32Ty(Ctx);
    Ty = Type::getInt32Ty(Ctx);
    TypeStr = TypeStr.startswith("int") ? TypeStr.substr(strlen("int"))
                                        : TypeStr.substr(strlen("uint"));
  } else if (TypeStr.startswith("float")) {
    Type = Type::getFloatTy(Ctx);
    TypeStr = TypeStr.substr(strlen("float"));
  } else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) {
    Ty = Type::getInt64Ty(Ctx);
    TypeStr = TypeStr.startswith("long") ? TypeStr.substr(strlen("long"))
                                         : TypeStr.substr(strlen("ulong"));
  } else if (TypeStr.startswith("half")) {
    Type = Type::getHalfTy(Ctx);
    Ty = Type::getHalfTy(Ctx);
    TypeStr = TypeStr.substr(strlen("half"));
  } else if (TypeStr.startswith("opencl.sampler_t")) {
    Type = StructType::create(Ctx, "opencl.sampler_t");
  } else if (TypeStr.startswith("float")) {
    Ty = Type::getFloatTy(Ctx);
    TypeStr = TypeStr.substr(strlen("float"));
  } else if (TypeStr.startswith("double")) {
    Ty = Type::getDoubleTy(Ctx);
    TypeStr = TypeStr.substr(strlen("double"));
  } else
    llvm_unreachable("Unable to recognize SPIRV type name.");

  auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);

  // Handle "type*" or  "type* vector[N]".
  if (TypeStr.starts_with("*")) {
    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
    TypeStr = TypeStr.substr(strlen("*"));
  }

  // Handle "typeN*" or  "type vector[N]*".
  bool IsPtrToVec = TypeStr.consume_back("*");

  if (TypeStr.startswith(" vector[")) {
    TypeStr = TypeStr.substr(strlen(" vector["));
    TypeStr = TypeStr.substr(0, TypeStr.find(']'));
  }
  TypeStr.getAsInteger(10, VecElts);
  auto SpirvTy = getOrCreateSPIRVType(Type, MIRBuilder);
  if (VecElts > 0)
    SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);

  if (IsPtrToVec)
    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);

  return SpirvTy;
}

+5 −2
Original line number Diff line number Diff line
@@ -138,8 +138,11 @@ public:

  // Either generate a new OpTypeXXX instruction or return an existing one
  // corresponding to the given string containing the name of the builtin type.
  SPIRVType *getOrCreateSPIRVTypeByName(StringRef TypeStr,
                                        MachineIRBuilder &MIRBuilder);
  SPIRVType *getOrCreateSPIRVTypeByName(
      StringRef TypeStr, MachineIRBuilder &MIRBuilder,
      SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,
      SPIRV::AccessQualifier::AccessQualifier AQ =
          SPIRV::AccessQualifier::ReadWrite);

  // Return the SPIR-V type instruction corresponding to the given VReg, or
  // nullptr if no such type instruction exists.
Loading