Unverified Commit 96432e9d authored by Gaétan Lepage's avatar Gaétan Lepage Committed by GitHub
Browse files

python3Packages.torch: build with nvshmem support (#508278)

parents 79373c01 b8c3e949
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@
  magma-cuda-static,
  # Use the system NCCL as long as we're targeting CUDA on a supported platform.
  useSystemNccl ? (cudaSupport && cudaPackages.nccl.meta.available || rocmSupport),
  withNvshmem ? (cudaSupport && cudaPackages.libnvshmem.meta.available),
  MPISupport ? false,
  mpi,
  buildDocs ? false,
@@ -469,6 +470,8 @@ buildPythonPackage.override { inherit stdenv; } (finalAttrs: {
    USE_SYSTEM_NCCL = finalAttrs.env.USE_NCCL;
    USE_STATIC_NCCL = finalAttrs.env.USE_NCCL;

    USE_NVSHMEM = setBool withNvshmem;

    # Set the correct Python library path, broken since
    # https://github.com/pytorch/pytorch/commit/3d617333e
    PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";
@@ -578,6 +581,9 @@ buildPythonPackage.override { inherit stdenv; } (finalAttrs: {
      (lib.getDev nccl) # Provides nccl.h
      (lib.getOutput "static" nccl) # Provides static library
    ]
    ++ lists.optionals withNvshmem [
      cudaPackages.libnvshmem
    ]
    ++ [
      cuda_profiler_api # <cuda_profiler_api.h>
    ]