Loading pkgs/development/python-modules/transformer-engine/cuda-libs-paths.patch 0 → 100644 +145 −0 Original line number Diff line number Diff line diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 02388d2e..f2eb337c 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -239,117 +239,7 @@ def _get_sys_extension() -> str: def _nvidia_cudart_include_dir() -> str: """Returns the include directory for cuda_runtime.h if exists in python environment.""" - try: - import nvidia - except ModuleNotFoundError: - return "" - - # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia" - # above doesn't through. However, they don't set "__file__" attribute. - if nvidia.__file__ is None: - return "" - - include_dir = Path(nvidia.__file__).parent / "cuda_runtime" - return str(include_dir) if include_dir.exists() else "" - - -@functools.lru_cache(maxsize=None) -def _load_cuda_library_from_python(lib_name: str, strict: bool = False): - """ - Attempts to load shared object file installed via python packages. - - `lib_name` : Name of package as found in the `nvidia` dir in python environment. - `strict` : If set to `True`, throw an error if lib is not found. - """ - - ext = _get_sys_extension() - nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia") - - # PyPI packages provided by nvidia libs exist - # in 4 possible locations inside `nvidia`. - # Check by order of priority. - path_found = False - if os.path.isdir(os.path.join(nvidia_dir, "cu13", lib_name)): - so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", lib_name, f"lib/lib*{ext}.*[0-9]")) - path_found = len(so_paths) > 0 - - if not path_found and os.path.isdir(os.path.join(nvidia_dir, "cu13")): - so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib{lib_name}*{ext}.*[0-9]")) - path_found = len(so_paths) > 0 - - if not path_found and os.path.isdir(os.path.join(nvidia_dir, lib_name)): - so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]")) - path_found = len(so_paths) > 0 - - if not path_found: - so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]")) - path_found = len(so_paths) > 0 - - ctypes_handles = [] - - if path_found: - for so_path in so_paths: - ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)) - - if strict and not path_found: - raise RuntimeError(f"{lib_name} shared object not found.") - - return path_found, ctypes_handles - - -@functools.lru_cache(maxsize=None) -def _load_cuda_library_from_system(lib_name: str): - """ - Attempts to load shared object file installed via system/cuda-toolkit. - - `lib_name`: Name of library to load without extension or `lib` prefix. - """ - - # Where to look for the shared lib in decreasing order of preference. - paths = ( - os.environ.get(f"{lib_name.upper()}_HOME"), - os.environ.get(f"{lib_name.upper()}_PATH"), - os.environ.get("CUDA_HOME"), - os.environ.get("CUDA_PATH"), - "/usr/local/cuda", - ) - - for path in paths: - if path is None: - continue - libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True) - libs = [lib for lib in libs if "stub" not in lib] - libs.sort(reverse=True, key=os.path.basename) - if libs: - return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - - # Search in LD_LIBRARY_PATH. - try: - _lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) - return True, _lib_handle - except OSError: - return False, None - - -@functools.lru_cache(maxsize=None) -def _load_cuda_library(lib_name: str): - """ - Load given shared library. - Prioritize loading from system/toolkit - before checking python packages. - """ - - # Attempt to locate library in system. - found, handle = _load_cuda_library_from_system(lib_name) - if found: - return True, handle - - # Attempt to locate library in Python dist-packages. - found, handle = _load_cuda_library_from_python(lib_name) - if found: - return False, handle - - raise RuntimeError(f"{lib_name} shared object not found.") + return "@cudart_include_dir@" @functools.lru_cache(maxsize=None) @@ -364,18 +254,9 @@ if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE # `_load_cuda_library` is used for packages that must be loaded # during runtime. Both system and pypi packages are searched # and an error is thrown if not found. - _, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn") - system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc") - system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand") - - # This additional step is necessary to be able to install TE wheels - # and import TE (without any guards) in an environment where the cuda - # toolkit might be absent without being guarded - load_libs_for_no_ctk = not system_nvrtc and not system_curand - if load_libs_for_no_ctk: - _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True) - _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True) - _CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True) + _CUDNN_LIB_CTYPES = ctypes.CDLL("@libcudnn_so@", mode=ctypes.RTLD_GLOBAL) + _NVRTC_LIB_CTYPES = ctypes.CDLL("@libnvrtc_so@", mode=ctypes.RTLD_GLOBAL) + _CURAND_LIB_CTYPES = ctypes.CDLL("@libcurand_so@", mode=ctypes.RTLD_GLOBAL) _TE_LIB_CTYPES = _load_core_library() pkgs/development/python-modules/transformer-engine/default.nix 0 → 100644 +273 −0 Original line number Diff line number Diff line { lib, config, buildPythonPackage, fetchFromGitHub, replaceVars, fetchpatch, python, cudaPackages, # nativeBuildInputs autoAddDriverRunpath, autoPatchelfHook, mpi, # build-system cmake, ninja, pybind11, setuptools, # jax-only flax, jax, # pytorch-only: torch, # dependencies importlib-metadata, packaging, pydantic, # pytorch-only: einops, nvdlfw-inspect, onnx, onnxscript, cudaSupport ? config.cudaSupport, cudaCapabilities ? if withPytorch then torch.cudaCapabilities else cudaPackages.flags.cudaCapabilities, withMpi ? false, withPytorch ? true, withJax ? true, withNvshmem ? false, }: let inherit (lib) cmakeFeature concatStringsSep getInclude getLib optional optionalString optionals strings subtractLists ; inherit (cudaPackages) backendStdenv flags; frameworks = if (withJax || withPytorch) then concatStringsSep "," (optional withJax "jax" ++ optional withPytorch "pytorch") else "none"; cudaCapabilities' = subtractLists [ # Compilation will fail when providing those architectures: # error: static assertion failed with "Compiled for the generic architecture, while utilizing # family-specific features. # Please compile for smXXXf architecture instead of smXXX architecture." # Providing 10.0 and 12.0 respectively is enough as the CMake file will automatically add the # correct compilation flags for supporting those architectures. "10.3" "12.1" ] cudaCapabilities; in buildPythonPackage.override { stdenv = backendStdenv; } (finalAttrs: { pname = "transformer-engine"; version = "2.12"; pyproject = true; src = fetchFromGitHub { owner = "NVIDIA"; repo = "TransformerEngine"; tag = "v${finalAttrs.version}"; # Their CMakeLists.txt does not easily let us inject dependencies fetchSubmodules = true; hash = "sha256-/e11kacSYPKdjVEKAo3x/CarzKhO3tiTsMjYWLzHbls="; }; patches = optionals cudaSupport [ (replaceVars ./cuda-libs-paths.patch { libcudnn_so = "${getLib cudaPackages.cudnn}/lib/libcudnn.so"; libnvrtc_so = "${getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"; libcurand_so = "${getLib cudaPackages.libcurand}/lib/libcurand.so"; cudart_include_dir = "${getInclude cudaPackages.cuda_cudart}/include"; }) # https://github.com/NVIDIA/TransformerEngine/pull/2832 (fetchpatch { name = "fix-cuda-arch-cmake-logic"; url = "https://github.com/GaetanLepage/TransformerEngine/commit/a3cf63e0d03dd9af1d494854949387f1ae677bf0.patch"; hash = "sha256-g2aIF0fROsExEjuNiyI62/rrCOXYyOjyQIOn6rCrUyI="; }) ] ++ optionals withNvshmem [ # https://github.com/NVIDIA/TransformerEngine/pull/2815 (fetchpatch { name = "fix-nvshmem-build"; url = "https://github.com/NVIDIA/TransformerEngine/commit/e83c09742166dfef3f871cfa1407605feafb3afe.patch"; hash = "sha256-5pf0Dg1XL7oAQjR1JZcdgbeaGj9qw9G5+i9Ac0iff64="; }) ] ++ optionals (withMpi && withJax) [ # https://github.com/NVIDIA/TransformerEngine/pull/2835 (fetchpatch { name = "fix-jax-extension-build-with-mpi"; url = "https://github.com/GaetanLepage/TransformerEngine/commit/f68cd3cab34972a899ad0069e2c4ee806e8bc6fb.patch"; hash = "sha256-u0ljg1FwY0QjR+ETswpzWV+Sbv00JHI5CSrNQ/9zsuA="; }) ]; postPatch = # Patch build-system requirements: # - pybind11[global] doesn't exist in nixpkgs, just use regular pybind11 # - pip is not required for building this package # - torch, jax and flax should not been unconditionally required, but depending on the selected # 'frameworks' '' substituteInPlace pyproject.toml \ --replace-fail "pybind11[global]" "pybind11" \ --replace-fail '"pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"' "" '' # Harcode the path to the output store path that transformer_engine will use to import # - libtransformer_engine.so # - transformer_engine_jax.cpython-313-x86_64-linux-gnu.so # - transformer_engine_torch.cpython-313-x86_64-linux-gnu.so # This skips their impure find logic. + '' substituteInPlace transformer_engine/common/__init__.py \ --replace-fail \ 'te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent' \ 'te_path = Path("${placeholder "out"}/${python.sitePackages}")' ''; # https://github.com/NVIDIA/TransformerEngine/blob/main/docs/envvars.rst env = { NVTE_RELEASE_BUILD = 0; # Do not include the git commit hash in the version string NVTE_NO_LOCAL_VERSION = 1; # Use the nixpkgs triton package NVTE_USE_PYTORCH_TRITON = 0; NVTE_FRAMEWORK = frameworks; NVTE_CUDA_ARCHS = strings.concatMapStringsSep ";" flags.dropDots cudaCapabilities'; NVTE_CMAKE_EXTRA_ARGS = toString [ (cmakeFeature "CUDNN_FRONTEND_INCLUDE_DIR" "${getInclude cudaPackages.cudnn-frontend}/include") ]; NVTE_UB_WITH_MPI = if withMpi then 1 else 0; # NOTE: Make sure to use mpi from buildPackages to match the spliced version created through nativeBuildInputs. MPI_HOME = optionalString withMpi (getLib mpi).outPath; NVTE_ENABLE_NVSHMEM = if withNvshmem then 1 else 0; NVSHMEM_HOME = optionalString withNvshmem cudaPackages.libnvshmem.outPath; }; build-system = [ cmake ninja pybind11 setuptools ] ++ optionals withJax [ flax jax ] ++ optionals withPytorch [ # Required to build extensions torch ]; dontUseCmakeConfigure = true; nativeBuildInputs = [ autoAddDriverRunpath autoPatchelfHook cudaPackages.cuda_nvcc ] ++ optionals withMpi [ # NOTE: mpi is in nativeBuildInputs because it contains compilers and is only discoverable by # CMake when a nativeBuildInput. mpi ]; buildInputs = [ cudaPackages.cuda_cudart # cuda_runtime.h cudaPackages.cuda_nvml_dev # nvml.h cudaPackages.cuda_nvrtc # nvrtc.h cudaPackages.cuda_nvtx # nvToolsExt.h cudaPackages.cuda_profiler_api # cuda_profiler_api.h cudaPackages.cudnn # cudnn.h cudaPackages.libcublas cudaPackages.libcurand # curand.h cudaPackages.libcusolver # cusolverDn.h cudaPackages.libcusparse # cusparse.h cudaPackages.nccl # nccl.h pybind11 # pybind11/pybind11.h ] ++ optionals withMpi [ mpi # mpi.h ]; runtimeDependencies = optionals withNvshmem [ # libnvshmem is already provided at build time by `$NVSHMEM_HOME` # We add it here so that it gets picked up by autoPatchelfHook (getLib cudaPackages.libnvshmem) ]; preBuild = '' export NVTE_BUILD_MAX_JOBS=$NIX_BUILD_CORES ''; dependencies = [ importlib-metadata packaging pydantic ] ++ optionals withJax [ flax jax ] ++ optionals withPytorch [ einops nvdlfw-inspect onnx onnxscript torch ]; # When built with nvshmem support `dlopen`ing libtransformer_engine.so `dlopen`s # libnvidia-ml.so.1 which is provided by the GPU driver at run time: # OSError: libnvidia-ml.so.1: cannot open shared object file: No such file or directory pythonImportsCheck = optionals (!withNvshmem) ( [ "transformer_engine" ] ++ optionals withJax [ "transformer_engine_jax" ] ++ optionals withPytorch [ "transformer_engine_torch" ] ); # Almost all tests require GPU access doCheck = false; meta = { description = "Library for accelerating Transformer models on NVIDIA GPUs"; homepage = "https://github.com/NVIDIA/TransformerEngine"; changelog = "https://github.com/NVIDIA/TransformerEngine/releases/tag/${finalAttrs.src.tag}"; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ GaetanLepage ]; broken = !cudaSupport; }; }) pkgs/top-level/python-packages.nix +12 −0 Original line number Diff line number Diff line Loading @@ -19606,6 +19606,18 @@ self: super: with self; { transaction = callPackage ../development/python-modules/transaction { }; transformer-engine = callPackage ../development/python-modules/transformer-engine { }; transformer-engine-jax = transformer-engine.override { withJax = true; withPytorch = false; }; transformer-engine-pytorch = transformer-engine.override { withJax = false; withPytorch = true; }; transformers = callPackage ../development/python-modules/transformers { }; transformers_4 = callPackage ../development/python-modules/transformers/4.nix { }; Loading
pkgs/development/python-modules/transformer-engine/cuda-libs-paths.patch 0 → 100644 +145 −0 Original line number Diff line number Diff line diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 02388d2e..f2eb337c 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -239,117 +239,7 @@ def _get_sys_extension() -> str: def _nvidia_cudart_include_dir() -> str: """Returns the include directory for cuda_runtime.h if exists in python environment.""" - try: - import nvidia - except ModuleNotFoundError: - return "" - - # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia" - # above doesn't through. However, they don't set "__file__" attribute. - if nvidia.__file__ is None: - return "" - - include_dir = Path(nvidia.__file__).parent / "cuda_runtime" - return str(include_dir) if include_dir.exists() else "" - - -@functools.lru_cache(maxsize=None) -def _load_cuda_library_from_python(lib_name: str, strict: bool = False): - """ - Attempts to load shared object file installed via python packages. - - `lib_name` : Name of package as found in the `nvidia` dir in python environment. - `strict` : If set to `True`, throw an error if lib is not found. - """ - - ext = _get_sys_extension() - nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia") - - # PyPI packages provided by nvidia libs exist - # in 4 possible locations inside `nvidia`. - # Check by order of priority. - path_found = False - if os.path.isdir(os.path.join(nvidia_dir, "cu13", lib_name)): - so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", lib_name, f"lib/lib*{ext}.*[0-9]")) - path_found = len(so_paths) > 0 - - if not path_found and os.path.isdir(os.path.join(nvidia_dir, "cu13")): - so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib{lib_name}*{ext}.*[0-9]")) - path_found = len(so_paths) > 0 - - if not path_found and os.path.isdir(os.path.join(nvidia_dir, lib_name)): - so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]")) - path_found = len(so_paths) > 0 - - if not path_found: - so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]")) - path_found = len(so_paths) > 0 - - ctypes_handles = [] - - if path_found: - for so_path in so_paths: - ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)) - - if strict and not path_found: - raise RuntimeError(f"{lib_name} shared object not found.") - - return path_found, ctypes_handles - - -@functools.lru_cache(maxsize=None) -def _load_cuda_library_from_system(lib_name: str): - """ - Attempts to load shared object file installed via system/cuda-toolkit. - - `lib_name`: Name of library to load without extension or `lib` prefix. - """ - - # Where to look for the shared lib in decreasing order of preference. - paths = ( - os.environ.get(f"{lib_name.upper()}_HOME"), - os.environ.get(f"{lib_name.upper()}_PATH"), - os.environ.get("CUDA_HOME"), - os.environ.get("CUDA_PATH"), - "/usr/local/cuda", - ) - - for path in paths: - if path is None: - continue - libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True) - libs = [lib for lib in libs if "stub" not in lib] - libs.sort(reverse=True, key=os.path.basename) - if libs: - return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - - # Search in LD_LIBRARY_PATH. - try: - _lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) - return True, _lib_handle - except OSError: - return False, None - - -@functools.lru_cache(maxsize=None) -def _load_cuda_library(lib_name: str): - """ - Load given shared library. - Prioritize loading from system/toolkit - before checking python packages. - """ - - # Attempt to locate library in system. - found, handle = _load_cuda_library_from_system(lib_name) - if found: - return True, handle - - # Attempt to locate library in Python dist-packages. - found, handle = _load_cuda_library_from_python(lib_name) - if found: - return False, handle - - raise RuntimeError(f"{lib_name} shared object not found.") + return "@cudart_include_dir@" @functools.lru_cache(maxsize=None) @@ -364,18 +254,9 @@ if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE # `_load_cuda_library` is used for packages that must be loaded # during runtime. Both system and pypi packages are searched # and an error is thrown if not found. - _, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn") - system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc") - system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand") - - # This additional step is necessary to be able to install TE wheels - # and import TE (without any guards) in an environment where the cuda - # toolkit might be absent without being guarded - load_libs_for_no_ctk = not system_nvrtc and not system_curand - if load_libs_for_no_ctk: - _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True) - _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True) - _CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True) + _CUDNN_LIB_CTYPES = ctypes.CDLL("@libcudnn_so@", mode=ctypes.RTLD_GLOBAL) + _NVRTC_LIB_CTYPES = ctypes.CDLL("@libnvrtc_so@", mode=ctypes.RTLD_GLOBAL) + _CURAND_LIB_CTYPES = ctypes.CDLL("@libcurand_so@", mode=ctypes.RTLD_GLOBAL) _TE_LIB_CTYPES = _load_core_library()
pkgs/development/python-modules/transformer-engine/default.nix 0 → 100644 +273 −0 Original line number Diff line number Diff line { lib, config, buildPythonPackage, fetchFromGitHub, replaceVars, fetchpatch, python, cudaPackages, # nativeBuildInputs autoAddDriverRunpath, autoPatchelfHook, mpi, # build-system cmake, ninja, pybind11, setuptools, # jax-only flax, jax, # pytorch-only: torch, # dependencies importlib-metadata, packaging, pydantic, # pytorch-only: einops, nvdlfw-inspect, onnx, onnxscript, cudaSupport ? config.cudaSupport, cudaCapabilities ? if withPytorch then torch.cudaCapabilities else cudaPackages.flags.cudaCapabilities, withMpi ? false, withPytorch ? true, withJax ? true, withNvshmem ? false, }: let inherit (lib) cmakeFeature concatStringsSep getInclude getLib optional optionalString optionals strings subtractLists ; inherit (cudaPackages) backendStdenv flags; frameworks = if (withJax || withPytorch) then concatStringsSep "," (optional withJax "jax" ++ optional withPytorch "pytorch") else "none"; cudaCapabilities' = subtractLists [ # Compilation will fail when providing those architectures: # error: static assertion failed with "Compiled for the generic architecture, while utilizing # family-specific features. # Please compile for smXXXf architecture instead of smXXX architecture." # Providing 10.0 and 12.0 respectively is enough as the CMake file will automatically add the # correct compilation flags for supporting those architectures. "10.3" "12.1" ] cudaCapabilities; in buildPythonPackage.override { stdenv = backendStdenv; } (finalAttrs: { pname = "transformer-engine"; version = "2.12"; pyproject = true; src = fetchFromGitHub { owner = "NVIDIA"; repo = "TransformerEngine"; tag = "v${finalAttrs.version}"; # Their CMakeLists.txt does not easily let us inject dependencies fetchSubmodules = true; hash = "sha256-/e11kacSYPKdjVEKAo3x/CarzKhO3tiTsMjYWLzHbls="; }; patches = optionals cudaSupport [ (replaceVars ./cuda-libs-paths.patch { libcudnn_so = "${getLib cudaPackages.cudnn}/lib/libcudnn.so"; libnvrtc_so = "${getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"; libcurand_so = "${getLib cudaPackages.libcurand}/lib/libcurand.so"; cudart_include_dir = "${getInclude cudaPackages.cuda_cudart}/include"; }) # https://github.com/NVIDIA/TransformerEngine/pull/2832 (fetchpatch { name = "fix-cuda-arch-cmake-logic"; url = "https://github.com/GaetanLepage/TransformerEngine/commit/a3cf63e0d03dd9af1d494854949387f1ae677bf0.patch"; hash = "sha256-g2aIF0fROsExEjuNiyI62/rrCOXYyOjyQIOn6rCrUyI="; }) ] ++ optionals withNvshmem [ # https://github.com/NVIDIA/TransformerEngine/pull/2815 (fetchpatch { name = "fix-nvshmem-build"; url = "https://github.com/NVIDIA/TransformerEngine/commit/e83c09742166dfef3f871cfa1407605feafb3afe.patch"; hash = "sha256-5pf0Dg1XL7oAQjR1JZcdgbeaGj9qw9G5+i9Ac0iff64="; }) ] ++ optionals (withMpi && withJax) [ # https://github.com/NVIDIA/TransformerEngine/pull/2835 (fetchpatch { name = "fix-jax-extension-build-with-mpi"; url = "https://github.com/GaetanLepage/TransformerEngine/commit/f68cd3cab34972a899ad0069e2c4ee806e8bc6fb.patch"; hash = "sha256-u0ljg1FwY0QjR+ETswpzWV+Sbv00JHI5CSrNQ/9zsuA="; }) ]; postPatch = # Patch build-system requirements: # - pybind11[global] doesn't exist in nixpkgs, just use regular pybind11 # - pip is not required for building this package # - torch, jax and flax should not been unconditionally required, but depending on the selected # 'frameworks' '' substituteInPlace pyproject.toml \ --replace-fail "pybind11[global]" "pybind11" \ --replace-fail '"pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"' "" '' # Harcode the path to the output store path that transformer_engine will use to import # - libtransformer_engine.so # - transformer_engine_jax.cpython-313-x86_64-linux-gnu.so # - transformer_engine_torch.cpython-313-x86_64-linux-gnu.so # This skips their impure find logic. + '' substituteInPlace transformer_engine/common/__init__.py \ --replace-fail \ 'te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent' \ 'te_path = Path("${placeholder "out"}/${python.sitePackages}")' ''; # https://github.com/NVIDIA/TransformerEngine/blob/main/docs/envvars.rst env = { NVTE_RELEASE_BUILD = 0; # Do not include the git commit hash in the version string NVTE_NO_LOCAL_VERSION = 1; # Use the nixpkgs triton package NVTE_USE_PYTORCH_TRITON = 0; NVTE_FRAMEWORK = frameworks; NVTE_CUDA_ARCHS = strings.concatMapStringsSep ";" flags.dropDots cudaCapabilities'; NVTE_CMAKE_EXTRA_ARGS = toString [ (cmakeFeature "CUDNN_FRONTEND_INCLUDE_DIR" "${getInclude cudaPackages.cudnn-frontend}/include") ]; NVTE_UB_WITH_MPI = if withMpi then 1 else 0; # NOTE: Make sure to use mpi from buildPackages to match the spliced version created through nativeBuildInputs. MPI_HOME = optionalString withMpi (getLib mpi).outPath; NVTE_ENABLE_NVSHMEM = if withNvshmem then 1 else 0; NVSHMEM_HOME = optionalString withNvshmem cudaPackages.libnvshmem.outPath; }; build-system = [ cmake ninja pybind11 setuptools ] ++ optionals withJax [ flax jax ] ++ optionals withPytorch [ # Required to build extensions torch ]; dontUseCmakeConfigure = true; nativeBuildInputs = [ autoAddDriverRunpath autoPatchelfHook cudaPackages.cuda_nvcc ] ++ optionals withMpi [ # NOTE: mpi is in nativeBuildInputs because it contains compilers and is only discoverable by # CMake when a nativeBuildInput. mpi ]; buildInputs = [ cudaPackages.cuda_cudart # cuda_runtime.h cudaPackages.cuda_nvml_dev # nvml.h cudaPackages.cuda_nvrtc # nvrtc.h cudaPackages.cuda_nvtx # nvToolsExt.h cudaPackages.cuda_profiler_api # cuda_profiler_api.h cudaPackages.cudnn # cudnn.h cudaPackages.libcublas cudaPackages.libcurand # curand.h cudaPackages.libcusolver # cusolverDn.h cudaPackages.libcusparse # cusparse.h cudaPackages.nccl # nccl.h pybind11 # pybind11/pybind11.h ] ++ optionals withMpi [ mpi # mpi.h ]; runtimeDependencies = optionals withNvshmem [ # libnvshmem is already provided at build time by `$NVSHMEM_HOME` # We add it here so that it gets picked up by autoPatchelfHook (getLib cudaPackages.libnvshmem) ]; preBuild = '' export NVTE_BUILD_MAX_JOBS=$NIX_BUILD_CORES ''; dependencies = [ importlib-metadata packaging pydantic ] ++ optionals withJax [ flax jax ] ++ optionals withPytorch [ einops nvdlfw-inspect onnx onnxscript torch ]; # When built with nvshmem support `dlopen`ing libtransformer_engine.so `dlopen`s # libnvidia-ml.so.1 which is provided by the GPU driver at run time: # OSError: libnvidia-ml.so.1: cannot open shared object file: No such file or directory pythonImportsCheck = optionals (!withNvshmem) ( [ "transformer_engine" ] ++ optionals withJax [ "transformer_engine_jax" ] ++ optionals withPytorch [ "transformer_engine_torch" ] ); # Almost all tests require GPU access doCheck = false; meta = { description = "Library for accelerating Transformer models on NVIDIA GPUs"; homepage = "https://github.com/NVIDIA/TransformerEngine"; changelog = "https://github.com/NVIDIA/TransformerEngine/releases/tag/${finalAttrs.src.tag}"; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ GaetanLepage ]; broken = !cudaSupport; }; })
pkgs/top-level/python-packages.nix +12 −0 Original line number Diff line number Diff line Loading @@ -19606,6 +19606,18 @@ self: super: with self; { transaction = callPackage ../development/python-modules/transaction { }; transformer-engine = callPackage ../development/python-modules/transformer-engine { }; transformer-engine-jax = transformer-engine.override { withJax = true; withPytorch = false; }; transformer-engine-pytorch = transformer-engine.override { withJax = false; withPytorch = true; }; transformers = callPackage ../development/python-modules/transformers { }; transformers_4 = callPackage ../development/python-modules/transformers/4.nix { };