Commit be983daa authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

cudaPackages.libnvshmem: make the build customizable

parent 0ed85b55
Loading
Loading
Loading
Loading
+60 −37
Original line number Diff line number Diff line
@@ -27,6 +27,14 @@
  ucx,
  # passthru.updateScript
  gitUpdater,

  withGdrcopy ? true,
  withIbgda ? true,
  withLibfabric ? true,
  withMpi ? true,
  withNccl ? true,
  withPmix ? true,
  withUcx ? true,
}:
let
  inherit (lib)
@@ -38,6 +46,8 @@ let
    getLib
    licenses
    maintainers
    optional
    optionals
    teams
    ;
in
@@ -64,13 +74,14 @@ backendStdenv.mkDerivation (finalAttrs: {
    cmake
    ninja

    # NOTE: mpi is in nativeBuildInputs because it contains compilers and is only discoverable by CMake
    # when a nativeBuildInput.
    mpi

    # NOTE: Python is required even if not building nvshmem4py:
    # https://github.com/NVIDIA/nvshmem/blob/131da55f643ac87c810ba0bc51d359258bf433a1/CMakeLists.txt#L173
    python3Packages.python
  ]
  ++ optionals withMpi [
    # NOTE: mpi is in nativeBuildInputs because it contains compilers and is only discoverable by CMake
    # when a nativeBuildInput.
    mpi
  ];

  # NOTE: Hardcoded standard versions mean CMake doesn't respect values we provide, so we need to patch the files.
@@ -98,12 +109,22 @@ backendStdenv.mkDerivation (finalAttrs: {
    cuda_nvml_dev
    cuda_nvrtc
    cuda_nvtx
    gdrcopy
    libfabric
    libnvjitlink
    rdma-core
  ]
  ++ optionals withLibfabric [
    libfabric
  ]
  ++ optionals withGdrcopy [
    gdrcopy
  ]
  ++ optionals withNccl [
    nccl
  ]
  ++ optionals withPmix [
    pmix
    rdma-core
  ]
  ++ optionals withUcx [
    ucx
  ];

@@ -113,7 +134,8 @@ backendStdenv.mkDerivation (finalAttrs: {
  env.CUDA_HOME = (getBin cuda_nvcc).outPath;

  # https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/nvshmem-install-proc.html#other-distributions
  cmakeFlags = [
  cmakeFlags = lib.concatLists [
    [
      (cmakeFeature "NVSHMEM_PREFIX" (placeholder "out"))

      (cmakeFeature "CUDA_HOME" (getBin cuda_nvcc).outPath)
@@ -121,29 +143,6 @@ backendStdenv.mkDerivation (finalAttrs: {

      (cmakeFeature "CMAKE_CUDA_ARCHITECTURES" flags.cmakeCudaArchitecturesString)

    (cmakeBool "NVSHMEM_USE_NCCL" true)
    (cmakeFeature "NCCL_HOME" (getDev nccl).outPath)

    (cmakeBool "NVSHMEM_USE_GDRCOPY" true)
    (cmakeFeature "GDRCOPY_HOME" (getDev gdrcopy).outPath)

    # NOTE: Make sure to use mpi from buildPackages to match the spliced version created through nativeBuildInputs.
    (cmakeBool "NVSHMEM_MPI_SUPPORT" true)
    (cmakeFeature "MPI_HOME" (getLib buildPackages.mpi).outPath)

    # TODO: Doesn't UCX need to be built with some argument when we want to use it with libnvshmem?
    (cmakeBool "NVSHMEM_UCX_SUPPORT" true)
    (cmakeFeature "UCX_HOME" (getDev ucx).outPath)

    (cmakeBool "NVSHMEM_LIBFABRIC_SUPPORT" true)
    (cmakeFeature "LIBFABRIC_HOME" (getDev libfabric).outPath)

    (cmakeBool "NVSHMEM_IBGDA_SUPPORT" true)
    # NOTE: no corresponding _HOME variable for IBGDA.

    (cmakeBool "NVSHMEM_PMIX_SUPPORT" true)
    (cmakeFeature "PMIX_HOME" (getDev pmix).outPath)

      (cmakeBool "NVSHMEM_BUILD_TESTS" true)
      (cmakeBool "NVSHMEM_BUILD_EXAMPLES" true)

@@ -156,6 +155,30 @@ backendStdenv.mkDerivation (finalAttrs: {

      # NOTE: unsupported because it requires Clang
      (cmakeBool "NVSHMEM_BUILD_BITCODE_LIBRARY" false)
    ]

    [ (cmakeBool "NVSHMEM_USE_NCCL" withNccl) ]
    (optional withNccl (cmakeFeature "NCCL_HOME" (getDev nccl).outPath))

    [ (cmakeBool "NVSHMEM_USE_GDRCOPY" withGdrcopy) ]
    (optional withGdrcopy (cmakeFeature "GDRCOPY_HOME" (getDev gdrcopy).outPath))

    # NOTE: Make sure to use mpi from buildPackages to match the spliced version created through nativeBuildInputs.
    [ (cmakeBool "NVSHMEM_MPI_SUPPORT" withMpi) ]
    (optional withMpi (cmakeFeature "MPI_HOME" (getLib buildPackages.mpi).outPath))

    # TODO: Doesn't UCX need to be built with some argument when we want to use it with libnvshmem?
    [ (cmakeBool "NVSHMEM_UCX_SUPPORT" withUcx) ]
    (optional withUcx (cmakeFeature "UCX_HOME" (getDev ucx).outPath))

    [ (cmakeBool "NVSHMEM_LIBFABRIC_SUPPORT" withLibfabric) ]
    (optional withLibfabric (cmakeFeature "LIBFABRIC_HOME" (getDev libfabric).outPath))

    # NOTE: no corresponding _HOME variable for IBGDA.
    [ (cmakeBool "NVSHMEM_IBGDA_SUPPORT" withIbgda) ]

    [ (cmakeBool "NVSHMEM_PMIX_SUPPORT" withPmix) ]
    (optional withPmix (cmakeFeature "PMIX_HOME" (getDev pmix).outPath))
  ];

  postInstall = ''