Unverified Commit 7c342c55 authored by Gaétan Lepage's avatar Gaétan Lepage Committed by GitHub
Browse files

cudaPackages.cuda_crt: patch math_functions.h signatures,...

cudaPackages.cuda_crt: patch math_functions.h signatures, python3Packages.deep-ep: cleanup, fix build on CUDA>=13.0 (#512796)
parents 76f759a9 6e46f918
Loading
Loading
Loading
Loading
+30 −1
Original line number Diff line number Diff line
{ backendStdenv, buildRedist }:
{
  lib,
  backendStdenv,
  buildRedist,
  glibc,
}:
buildRedist {
  redistName = "cuda";
  pname = "cuda_crt";

  outputs = [ "out" ];

  # Fix compatibility with glibc 2.42:
  # CUDA >= 13.0 fixed sinpi/cospi (using __NV_IEC_60559_FUNCS_EXCEPTION_SPECIFIER), but
  # rsqrt/rsqrtf in math_functions.h still lack noexcept, conflicting with glibc 2.42's
  # declarations.
  postInstall = lib.optionalString (lib.versionAtLeast glibc.version "2.42") ''
    nixLog "Patching math_functions.h rsqrt signatures to match glibc's ones"
    substituteInPlace "''${!outputInclude:?}/include/crt/math_functions.h" \
      --replace-fail \
        "rsqrt(double x);" \
        "rsqrt(double x) noexcept (true);" \
      --replace-fail \
        "rsqrtf(float x);" \
        "rsqrtf(float x) noexcept (true);"

    nixLog "Patching math_functions.hpp rsqrt signatures to match glibc's ones"
    substituteInPlace "''${!outputInclude:?}/include/crt/math_functions.hpp" \
      --replace-fail \
        "__func__(double rsqrt(const double a))" \
        "__func__(double rsqrt(const double a) throw())" \
      --replace-fail \
        "__func__(float rsqrtf(const float a))" \
        "__func__(float rsqrtf(const float a) throw())"
  '';

  brokenAssertions = [
    # TODO(@connorbaker): Build fails on x86 when using pkgsLLVM.
    #  .../include/crt/host_defines.h:67:2:
+6 −0
Original line number Diff line number Diff line
@@ -168,6 +168,11 @@ buildRedist (finalAttrs: {
      # The cospi|sinpi|rsqrt function signatures in include/common/math_functions.h do not match
      # glibc 2.42's.
      # Indeed, there they are declared with noexcept(true) which is not the case in cuda_nvcc.
      # - In CUDA < 13.0, sinpi/cospi/rsqrt all lack exception specifiers.
      # - In CUDA >= 13.0, NVIDIA fixed sinpi/cospi (using __NV_IEC_60559_FUNCS_EXCEPTION_SPECIFIER)
      #   but rsqrt/rsqrtf still lack noexcept, so we only patch those.
      #   As the CRT headers (including math_functions.h) moved to the cuda_crt package, the glibc
      #   2.42 compatibility patch is applied there instead.
      + lib.optionalString (cudaOlder "13.0" && lib.versionAtLeast glibc.version "2.42") ''
        nixLog "Patching math_functions.h signatures to match glibc's ones"
        substituteInPlace "''${!outputInclude:?}/include/crt/math_functions.h" \
@@ -213,6 +218,7 @@ buildRedist (finalAttrs: {
            "__func__(float cospif(const float a))" \
            "__func__(float cospif(const float a) throw())"
      ''

      # Fix clang CUDA compilation: host_defines.h redefines __noinline__ as
      # __attribute__((noinline)), which conflicts with libstdc++ >=12 using
      # __attribute__((__noinline__)) — the macro expands to
+17 −20
Original line number Diff line number Diff line
@@ -8,7 +8,6 @@
  setuptools,

  # env
  symlinkJoin,
  cudaPackages,

  # buildInputs
@@ -20,11 +19,6 @@
  cudaSupport ? config.cudaSupport,
}:
let
  inherit (lib)
    getBin
    getInclude
    ;

  minSupportedCudaCapability = "8.0"; # build fails with 7.5

  minCudaCapability = builtins.head (
@@ -66,31 +60,34 @@ buildPythonPackage.override { inherit (torch) stdenv; } (finalAttrs: {

      DISABLE_SM90_FEATURES = if disableSm90Features then "1" else "0";

      CUDA_HOME = symlinkJoin {
        name = "cuda-redist";
        paths = with cudaPackages; [
          (getBin cuda_nvcc)

          (getInclude cuda_cccl) # <nv/target>
          (getInclude cuda_cudart) # cuda_runtime.h
          (getInclude libcublas) # cublas_v2.h
          (getInclude libcusolver) # cusolverDn.h
          (getInclude libcusparse) # cusparse.h
        ];
      };
      CUDA_HOME = (lib.getBin cudaPackages.cuda_nvcc).outPath;
    }

    # nvshmem must be disabled (unsetting NVSHMEM_DIR) when supporting <9.0 capabilities
    # https://github.com/deepseek-ai/DeepEP/blob/v1.2.1/setup.py#L65
    // lib.optionalAttrs (!disableSm90Features) {
      NVSHMEM_DIR = (getInclude cudaPackages.libnvshmem).outPath;
      NVSHMEM_DIR = (lib.getInclude cudaPackages.libnvshmem).outPath;
    }
  );

  buildInputs = [
    pybind11
    rdma-core
  ];
  ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cccl # <nv/target>
      cuda_cudart # cuda_runtime.h
      libcublas # cublas_v2.h
      libcusolver # cusolverDn.h
      libcusparse # cusparse.h
    ]
    # On CUDA >=13.0, crt/host_config.h is shipped in cudaPackages.cuda_crt
    ++ lib.optionals cuda_crt.meta.available [
      cuda_crt # crt/host_config.h
    ]
  );

  pythonImportsCheck = [ "deep_ep" ];