Unverified Commit 889708ee authored by Jan Leyonberg's avatar Jan Leyonberg Committed by GitHub
Browse files

[CIR][OpenMP] Enable emission of target functions (#193204)

This PR allows generation of target device functions for OpenMP. It also
handles filtering out host functions that do not contain target regions.

Assisted-by: Cursor / claude-4.6-opus-high
parent cf30e4b5
Loading
Loading
Loading
Loading
+41 −11
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@
#include "clang/AST/DeclOpenACC.h"
#include "clang/AST/GlobalDecl.h"
#include "clang/AST/RecordLayout.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/Basic/DiagnosticFrontend.h"
#include "clang/Basic/SourceManager.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
@@ -150,6 +151,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &mlirContext,

  if (langOpts.CUDA)
    createCUDARuntime();
  if (langOpts.OpenMP)
    createOpenMPRuntime();

  // Set the module name to be the name of the main file. TranslationUnitDecl
  // often contains invalid source locations and isn't a reliable source for the
@@ -183,6 +186,10 @@ void CIRGenModule::createCUDARuntime() {
  cudaRuntime.reset(createNVCUDARuntime(*this));
}

void CIRGenModule::createOpenMPRuntime() {
  openMPRuntime = std::make_unique<CIRGenOpenMPRuntime>(*this);
}

/// FIXME: this could likely be a common helper and not necessarily related
/// with codegen.
/// Return the best known alignment for an unknown pointer to a
@@ -455,12 +462,6 @@ void CIRGenModule::emitGlobal(clang::GlobalDecl gd) {
    return;
  }

  // TODO(OMP): The logic in this function for the 'rest' of the OpenMP
  // declarative declarations is complicated and needs to be done on a per-kind
  // basis, so all of that needs to be added when we implement the individual
  // global-allowed declarations. See uses of `cir::MissingFeatures::openMP
  // throughout this function.

  const auto *global = cast<ValueDecl>(gd.getDecl());

  // If this is CUDA, be selective about which declarations we emit.
@@ -493,6 +494,22 @@ void CIRGenModule::emitGlobal(clang::GlobalDecl gd) {
      return;
  }

  if (langOpts.OpenMP) {
    // If this is OpenMP, check if it is legal to emit this global normally.
    if (openMPRuntime && openMPRuntime->emitTargetGlobal(gd))
      return;
    if (auto *drd = dyn_cast<OMPDeclareReductionDecl>(global)) {
      if (mustBeEmitted(global))
        emitOMPDeclareReduction(drd);
      return;
    }
    if (auto *dmd = dyn_cast<OMPDeclareMapperDecl>(global)) {
      if (mustBeEmitted(global))
        emitOMPDeclareMapper(dmd);
      return;
    }
  }

  if (const auto *fd = dyn_cast<FunctionDecl>(global)) {
    // Update deferred annotations with the latest declaration if the function
    // was already used or defined.
@@ -601,6 +618,9 @@ void CIRGenModule::emitGlobalFunctionDefinition(clang::GlobalDecl gd,

  if (funcDecl->getAttr<AnnotateAttr>())
    deferredAnnotations[getMangledName(gd)] = funcDecl;

  if (getLangOpts().OpenMP && funcDecl->hasAttr<OMPDeclareTargetDeclAttr>())
    getOpenMPRuntime().emitDeclareTargetFunction(funcDecl, funcOp);
}

/// Track functions to be called before main() runs.
@@ -2837,11 +2857,21 @@ cir::FuncOp CIRGenModule::getOrCreateCIRFunction(
  const Decl *d = gd.getDecl();

  if (const auto *fd = cast_or_null<FunctionDecl>(d)) {
    // For the device mark the function as one that should be emitted.
    if (getLangOpts().OpenMPIsTargetDevice && fd->isDefined() && !dontDefer &&
        !isForDefinition)
      errorNYI(fd->getSourceRange(),
               "getOrCreateCIRFunction: OpenMP target function");
    // For the device, mark the function as one that should be emitted.
    if (getLangOpts().OpenMPIsTargetDevice && openMPRuntime &&
        !getOpenMPRuntime().markAsGlobalTarget(gd) && fd->isDefined() &&
        !dontDefer && !isForDefinition) {
      if (const FunctionDecl *fdDef = fd->getDefinition()) {
        GlobalDecl gdDef;
        if (const auto *cd = dyn_cast<CXXConstructorDecl>(fdDef))
          gdDef = GlobalDecl(cd, gd.getCtorType());
        else if (const auto *dd = dyn_cast<CXXDestructorDecl>(fdDef))
          gdDef = GlobalDecl(dd, gd.getDtorType());
        else
          gdDef = GlobalDecl(fdDef);
        emitGlobal(gdDef);
      }
    }

    // Any attempts to use a MultiVersion function should result in retrieving
    // the iFunc instead. Name mangling will handle the rest of the changes.
+10 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@
#include "CIRGenBuilder.h"
#include "CIRGenCUDARuntime.h"
#include "CIRGenCall.h"
#include "CIRGenOpenMPRuntime.h"
#include "CIRGenTypeCache.h"
#include "CIRGenTypes.h"
#include "CIRGenVTables.h"
@@ -95,6 +96,9 @@ private:
  /// Holds the CUDA runtime
  std::unique_ptr<CIRGenCUDARuntime> cudaRuntime;

  /// Holds the OpenMP runtime
  std::unique_ptr<CIRGenOpenMPRuntime> openMPRuntime;

  /// Per-function codegen information. Updated everytime emitCIR is called
  /// for FunctionDecls's.
  CIRGenFunction *curCGF = nullptr;
@@ -130,6 +134,7 @@ private:
  std::vector<const CXXRecordDecl *> opportunisticVTables;

  void createCUDARuntime();
  void createOpenMPRuntime();

  /// A helper for constructAttributeList that handles return attributes.
  void constructFunctionReturnAttributes(const CIRGenFunctionInfo &info,
@@ -741,6 +746,11 @@ public:
    return *cudaRuntime;
  }

  CIRGenOpenMPRuntime &getOpenMPRuntime() {
    assert(openMPRuntime != nullptr);
    return *openMPRuntime;
  }

  mlir::IntegerAttr getSize(CharUnits size) {
    return builder.getSizeFromCharUnits(size);
  }
+194 −0
Original line number Diff line number Diff line
//===--- CIRGenOpenMPRuntime.cpp - OpenMP code generation helpers ------=--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helpers for OpenMP-specific CIR code generation.
//
////===--------------------------------------------------------------------===//

#include "CIRGenOpenMPRuntime.h"
#include "CIRGenModule.h"

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "clang/AST/OpenMPClause.h"

using namespace clang;
using namespace clang::CIRGen;

static mlir::omp::DeclareTargetDeviceType
convertDeviceType(OMPDeclareTargetDeclAttr::DevTypeTy devTy) {
  switch (devTy) {
  case OMPDeclareTargetDeclAttr::DT_Host:
    return mlir::omp::DeclareTargetDeviceType::host;
  case OMPDeclareTargetDeclAttr::DT_NoHost:
    return mlir::omp::DeclareTargetDeviceType::nohost;
  case OMPDeclareTargetDeclAttr::DT_Any:
    return mlir::omp::DeclareTargetDeviceType::any;
  }
  llvm_unreachable("unexpected device type");
}

static mlir::omp::DeclareTargetCaptureClause
convertCaptureClause(OMPDeclareTargetDeclAttr::MapTypeTy mapTy) {
  switch (mapTy) {
  case OMPDeclareTargetDeclAttr::MT_To:
    return mlir::omp::DeclareTargetCaptureClause::to;
  case OMPDeclareTargetDeclAttr::MT_Enter:
    return mlir::omp::DeclareTargetCaptureClause::enter;
  case OMPDeclareTargetDeclAttr::MT_Link:
    return mlir::omp::DeclareTargetCaptureClause::link;
  case OMPDeclareTargetDeclAttr::MT_Local:
    return mlir::omp::DeclareTargetCaptureClause::none;
  }
  llvm_unreachable("unexpected map type");
}

/// Returns true if the declaration should be skipped based on its
/// device_type attribute and the current compilation mode.
static bool isAssumedToBeNotEmitted(const ValueDecl *vd, bool isDevice) {
  std::optional<OMPDeclareTargetDeclAttr::DevTypeTy> devTy =
      OMPDeclareTargetDeclAttr::getDeviceType(vd);
  if (!devTy)
    return false;
  // Do not emit device_type(nohost) functions for the host.
  if (!isDevice && *devTy == OMPDeclareTargetDeclAttr::DT_NoHost)
    return true;
  // Do not emit device_type(host) functions for the device.
  if (isDevice && *devTy == OMPDeclareTargetDeclAttr::DT_Host)
    return true;
  return false;
}

/// Recursively check whether the statement tree contains any OpenMP target
/// execution directive (e.g. 'omp target', 'omp target parallel', etc.).
/// Used to identify host functions that must be emitted on the device because
/// they contain target regions that will be outlined during MLIR lowering.
static bool containsTargetRegion(const Stmt *s) {
  if (!s)
    return false;
  if (const auto *e = dyn_cast<OMPExecutableDirective>(s))
    if (isOpenMPTargetExecutionDirective(e->getDirectiveKind()))
      return true;
  for (const Stmt *child : s->children())
    if (containsTargetRegion(child))
      return true;
  return false;
}

bool CIRGenOpenMPRuntime::emitTargetFunctions(GlobalDecl gd) {
  bool isDevice = cgm.getLangOpts().OpenMPIsTargetDevice;

  if (!isDevice) {
    if (const auto *fd = dyn_cast<FunctionDecl>(gd.getDecl()))
      if (isAssumedToBeNotEmitted(cast<ValueDecl>(fd), isDevice))
        return true;
    return false;
  }

  const auto *vd = cast<ValueDecl>(gd.getDecl());

  if (const auto *fd = dyn_cast<FunctionDecl>(vd))
    if (isAssumedToBeNotEmitted(cast<ValueDecl>(fd), isDevice))
      return true;

  // Do not emit function if it is not marked as declare target.
  if (OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(vd) ||
      alreadyEmittedTargetDecls.count(vd) != 0)
    return false;

  // We must also host functions that contain target regions,
  // because the omp.target ops are nested inside the host function rather than
  // being outlined early. The containsTargetRegion check handles this.
  if (const auto *fd = dyn_cast<FunctionDecl>(vd))
    if (fd->doesThisDeclarationHaveABody() &&
        containsTargetRegion(fd->getBody()))
      return false;

  return true;
}

bool CIRGenOpenMPRuntime::emitTargetGlobalVariable(GlobalDecl gd) {
  if (isAssumedToBeNotEmitted(cast<ValueDecl>(gd.getDecl()),
                              cgm.getLangOpts().OpenMPIsTargetDevice))
    return true;

  if (!cgm.getLangOpts().OpenMPIsTargetDevice)
    return false;

  // We do not need to scan for target regions since there's not early
  // outlining like in OGCG, they will be emitted as omp.target ops instead.

  const auto *vd = cast<VarDecl>(gd.getDecl());

  // Do not emit variable if it is not marked as declare target.
  // OGCG also defers link-clause and USM variables here; we emit errorNYI
  // for those since they are not yet supported.
  std::optional<OMPDeclareTargetDeclAttr::MapTypeTy> res =
      OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(vd);
  if (!res || *res == OMPDeclareTargetDeclAttr::MT_Link ||
      ((*res == OMPDeclareTargetDeclAttr::MT_To ||
        *res == OMPDeclareTargetDeclAttr::MT_Enter) &&
       false /* NYI: HasRequiresUnifiedSharedMemory */)) {
    if (res && *res == OMPDeclareTargetDeclAttr::MT_Link)
      cgm.errorNYI(vd->getSourceRange(),
                   "declare target global variable with link clause");
    // OGCG defers these variables for later emission. We skip them for now.
    return true;
  }
  return false;
}

// Mirrors CGOpenMPRuntime::emitTargetGlobal.
bool CIRGenOpenMPRuntime::emitTargetGlobal(GlobalDecl gd) {
  if (isa<FunctionDecl>(gd.getDecl()) ||
      isa<OMPDeclareReductionDecl>(gd.getDecl()))
    return emitTargetFunctions(gd);

  return emitTargetGlobalVariable(gd);
}

bool CIRGenOpenMPRuntime::markAsGlobalTarget(GlobalDecl gd) {
  if (!cgm.getLangOpts().OpenMPIsTargetDevice)
    return true;

  const auto *d = cast<FunctionDecl>(gd.getDecl());

  if (OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(d)) {
    if (d->hasBody() && alreadyEmittedTargetDecls.count(d) == 0) {
      auto f = dyn_cast_if_present<cir::FuncOp>(
          cgm.getGlobalValue(cgm.getMangledName(gd)));
      if (f)
        return !f.isDeclaration();
      return false;
    }
    return true;
  }

  return !alreadyEmittedTargetDecls.insert(d).second;
}

void CIRGenOpenMPRuntime::emitDeclareTargetFunction(const FunctionDecl *fd,
                                                    cir::FuncOp funcOp) {
  const auto *attr = fd->getAttr<OMPDeclareTargetDeclAttr>();
  assert(attr && "expected OMPDeclareTargetDeclAttr");

  // Handles the 'indirect' clause here by creating a global variable to hold
  // the device function address for runtime resolution of indirect calls on
  // the device.
  if (std::optional<OMPDeclareTargetDeclAttr *> activeAttr =
          OMPDeclareTargetDeclAttr::getActiveAttr(fd))
    if ((*activeAttr)->getIndirect())
      cgm.errorNYI(fd->getSourceRange(),
                   "declare target function with indirect clause");

  auto declTargetIface =
      llvm::cast<mlir::omp::DeclareTargetInterface>(funcOp.getOperation());
  declTargetIface.setDeclareTarget(convertDeviceType(attr->getDevType()),
                                   convertCaptureClause(attr->getMapType()),
                                   /*automap=*/false);
}
+58 −0
Original line number Diff line number Diff line
//===--- CIRGenOpenMPRuntime.h - OpenMP code generation helpers -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_LIB_CIR_CODEGEN_CIRGENOPENMPRUNTIME_H
#define LLVM_CLANG_LIB_CIR_CODEGEN_CIRGENOPENMPRUNTIME_H

#include "clang/AST/DeclOpenMP.h"
#include "clang/AST/GlobalDecl.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "llvm/ADT/DenseSet.h"

namespace clang::CIRGen {

class CIRGenModule;

class CIRGenOpenMPRuntime {
  CIRGenModule &cgm;

  /// Declarations that have been force-emitted for the target device because
  /// they are transitively referenced from declare target functions.
  llvm::DenseSet<CanonicalDeclPtr<const Decl>> alreadyEmittedTargetDecls;

  /// Returns false if the given function or declare reduction should be
  /// emitted. Returns true if it should eb skipped.
  /// emission).
  bool emitTargetFunctions(GlobalDecl gd);

  /// Returns false if given global variable should be emitted. Returns
  /// true if it should be skipped.
  bool emitTargetGlobalVariable(GlobalDecl gd);

public:
  explicit CIRGenOpenMPRuntime(CIRGenModule &cgm) : cgm(cgm) {}

  /// Check whether the given GlobalDecl needs special handling for device
  /// compilation. Returns false if it should be emitted, true if it should be
  /// skipped.
  bool emitTargetGlobal(GlobalDecl gd);

  /// Mark a function reference as one that should be emitted on the device.
  /// Returns false if it should be emitted, true if the function is already
  /// handled and should be skipped.
  bool markAsGlobalTarget(GlobalDecl gd);

  /// If the function has an OMPDeclareTargetDeclAttr, set the corresponding
  /// omp.declare_target attribute on the emitted cir.func op.
  void emitDeclareTargetFunction(const FunctionDecl *fd, cir::FuncOp funcOp);
};

} // namespace clang::CIRGen

#endif // LLVM_CLANG_LIB_CIR_CODEGEN_CIRGENOPENMPRUNTIME_H
+1 −0
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ add_clang_library(clangCIR
  CIRGenOpenACCClause.cpp
  CIRGenOpenACCRecipe.cpp
  CIRGenOpenMPClause.cpp
  CIRGenOpenMPRuntime.cpp
  CIRGenPointerAuth.cpp
  CIRGenRecordLayoutBuilder.cpp
  CIRGenStmt.cpp
Loading