Unverified Commit 1531e01e authored by Pavol Rusnak's avatar Pavol Rusnak Committed by GitHub
Browse files

python3Packages.cupy: add aarch64-linux support (#378958)

parents dbc08b5b 23b117dd
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@
      {
        version = "8.9.5.30";
        minCudaVersion = "12.0";
        maxCudaVersion = "12.2";
        maxCudaVersion = "12.4";
        url = "https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-aarch64/cudnn-linux-aarch64-8.9.5.30_cuda12-archive.tar.xz";
        hash = "sha256-BJH3sC9VwiB362eL8xTB+RdSS9UHz1tlgjm/mKRyM6E=";
      }
@@ -75,7 +75,7 @@
      {
        version = "8.9.7.29";
        minCudaVersion = "12.0";
        maxCudaVersion = "12.2";
        maxCudaVersion = "12.4";
        url = "https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-sbsa/cudnn-linux-sbsa-8.9.7.29_cuda12-archive.tar.xz";
        hash = "sha256-6Yt8gAEHheXVygHuTOm1sMjHNYfqb4ZIvjTT+NHUe9E=";
      }
@@ -183,7 +183,7 @@
      {
        version = "8.9.7.29";
        minCudaVersion = "12.0";
        maxCudaVersion = "12.2";
        maxCudaVersion = "12.4";
        url = "https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-8.9.7.29_cuda12-archive.tar.xz";
        hash = "sha256-R1MzYlx+QqevPKCy91BqEG4wyTsaoAgc2cE++24h47s=";
      }
+40 −27
Original line number Diff line number Diff line
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  cython_0,
@@ -13,11 +14,19 @@
  addDriverRunpath,
  pythonOlder,
  symlinkJoin,
  fetchpatch
}:

let
  inherit (cudaPackages) cudnn cutensor nccl;
  inherit (cudaPackages) cudnn;

  shouldUsePkg =
    pkg: if pkg != null && lib.meta.availableOn stdenv.hostPlatform pkg then pkg else null;

  # some packages are not available on all platforms
  cuda_nvprof = shouldUsePkg (cudaPackages.nvprof or null);
  cutensor = shouldUsePkg (cudaPackages.cutensor or null);
  nccl = shouldUsePkg (cudaPackages.nccl or null);

  outpaths = with cudaPackages; [
    cuda_cccl # <nv/target>
    cuda_cudart
@@ -37,7 +46,14 @@ let
  ];
  cudatoolkit-joined = symlinkJoin {
    name = "cudatoolkit-joined-${cudaPackages.cudaVersion}";
    paths = outpaths ++ lib.concatMap (f: lib.map f outpaths) [lib.getLib lib.getDev (lib.getOutput "static") (lib.getOutput "stubs")];
    paths =
      outpaths
      ++ lib.concatMap (f: lib.map f outpaths) [
        lib.getLib
        lib.getDev
        (lib.getOutput "static")
        (lib.getOutput "stubs")
      ];
  };
in
buildPythonPackage rec {
@@ -47,6 +63,8 @@ buildPythonPackage rec {

  disabled = pythonOlder "3.7";

  stdenv = cudaPackages.backendStdenv;

  src = fetchFromGitHub {
    owner = "cupy";
    repo = "cupy";
@@ -55,14 +73,6 @@ buildPythonPackage rec {
    fetchSubmodules = true;
  };

  patches = [
    (fetchpatch {
      url =
        "https://github.com/cfhammill/cupy/commit/67526c756e4a0a70f0420bf0e7f081b8a35a8ee5.patch";
      hash = "sha256-WZgexBdM9J0ep5s+9CGZriVq0ZidCRccox+g0iDDywQ=";
    })
  ];

  # See https://docs.cupy.dev/en/v10.2.0/reference/environment.html. Seting both
  # CUPY_NUM_BUILD_JOBS and CUPY_NUM_NVCC_THREADS to NIX_BUILD_CORES results in
  # a small amount of thrashing but it turns out there are a large number of
@@ -118,7 +128,10 @@ buildPythonPackage rec {
    homepage = "https://cupy.chainer.org/";
    changelog = "https://github.com/cupy/cupy/releases/tag/v${version}";
    license = licenses.mit;
    platforms = [ "x86_64-linux" ];
    platforms = [
      "aarch64-linux"
      "x86_64-linux"
    ];
    maintainers = with maintainers; [ hyphon81 ];
  };
}
+1 −2
Original line number Diff line number Diff line
@@ -2832,8 +2832,7 @@ self: super: with self; {
  cufflinks = callPackage ../development/python-modules/cufflinks { };
  cupy = callPackage ../development/python-modules/cupy {
    # cupy 12.2.0 possibly incompatible with cutensor 2.0 that comes with cudaPackages_12
    cudaPackages = pkgs.cudaPackages_11.overrideScope (cu-fi: _: {
    cudaPackages = pkgs.cudaPackages.overrideScope (cu-fi: _: {
      # CuDNN 9 is not supported:
      # https://github.com/cupy/cupy/issues/8215
      cudnn = cu-fi.cudnn_8_9;