Commit 82018339 authored by Someone Serge's avatar Someone Serge
Browse files

treewide: cuda: use propagatedBuildInputs, lib.getOutput

parent f1ddae47
Loading
Loading
Loading
Loading
+3 −6
Original line number Diff line number Diff line
@@ -47,12 +47,9 @@ stdenv.mkDerivation (finalAttrs: {
  ] ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        cuda_cccl.dev
        cuda_cudart.dev
        cuda_cudart.lib
        cuda_cudart.static
        libcublas.dev
        libcublas.lib
        cuda_cccl
        cuda_cudart
        libcublas
      ]);

  cmakeFlags = [
+3 −7
Original line number Diff line number Diff line
@@ -46,16 +46,12 @@ let
    ++ optionals metalSupport [ MetalKit ];

   cudaBuildInputs = with cudaPackages; [
    cuda_cccl.dev # <nv/target>
    cuda_cccl # <nv/target>

    # A temporary hack for reducing the closure size, remove once cudaPackages
    # have stopped using lndir: https://github.com/NixOS/nixpkgs/issues/271792
    cuda_cudart.dev
    cuda_cudart.lib
    cuda_cudart.static
    libcublas.dev
    libcublas.lib
    libcublas.static
    cuda_cudart
    libcublas
  ];

  rocmBuildInputs = with rocmPackages; [
+8 −9
Original line number Diff line number Diff line
@@ -101,12 +101,12 @@ let
  };

  cudaToolkit = buildEnv {
    name = "cuda-toolkit";
    ignoreCollisions = true; # FIXME: find a cleaner way to do this without ignoring collisions
    name = "cuda-merged";
    paths = [
      cudaPackages.cudatoolkit
      cudaPackages.cuda_cudart
      cudaPackages.cuda_cudart.static
      (lib.getBin (cudaPackages.cuda_nvcc.__spliced.buildHost or cudaPackages.cuda_nvcc))
      (lib.getLib cudaPackages.cuda_cudart)
      (lib.getOutput "static" cudaPackages.cuda_cudart)
      (lib.getLib cudaPackages.libcublas)
    ];
  };

@@ -140,10 +140,6 @@ in
goBuild ((lib.optionalAttrs enableRocm {
  ROCM_PATH = rocmPath;
  CLBlast_DIR = "${clblast}/lib/cmake/CLBlast";
}) // (lib.optionalAttrs enableCuda {
  CUDA_LIB_DIR = "${cudaToolkit}/lib";
  CUDACXX = "${cudaToolkit}/bin/nvcc";
  CUDAToolkit_ROOT = cudaToolkit;
}) // {
  inherit pname version src vendorHash;

@@ -151,6 +147,8 @@ goBuild ((lib.optionalAttrs enableRocm {
    cmake
  ] ++ lib.optionals enableRocm [
    rocmPackages.llvm.bintools
  ] ++ lib.optionals enableCuda [
    cudaPackages.cuda_nvcc
  ] ++ lib.optionals (enableRocm || enableCuda) [
    makeWrapper
  ] ++ lib.optionals stdenv.isDarwin
@@ -160,6 +158,7 @@ goBuild ((lib.optionalAttrs enableRocm {
    (rocmLibs ++ [ libdrm ])
  ++ lib.optionals enableCuda [
    cudaPackages.cuda_cudart
    cudaPackages.libcublas
  ] ++ lib.optionals stdenv.isDarwin
    metalFrameworks;

+11 −11
Original line number Diff line number Diff line
@@ -44,7 +44,7 @@ filterAndCreateOverrides {
    }:
    prevAttrs: {
      buildInputs = prevAttrs.buildInputs ++ [
        libcublas.lib
        libcublas
        numactl
        rdma-core
      ];
@@ -66,17 +66,17 @@ filterAndCreateOverrides {
      buildInputs =
        prevAttrs.buildInputs
        # Always depends on this
        ++ [ libcublas.lib ]
        ++ [ libcublas ]
        # Dependency from 12.0 and on
        ++ lib.lists.optionals (cudaAtLeast "12.0") [ libnvjitlink.lib ]
        ++ lib.lists.optionals (cudaAtLeast "12.0") [ libnvjitlink ]
        # Dependency from 12.1 and on
        ++ lib.lists.optionals (cudaAtLeast "12.1") [ libcusparse.lib ];
        ++ lib.lists.optionals (cudaAtLeast "12.1") [ libcusparse ];

      brokenConditions = prevAttrs.brokenConditions // {
        "libnvjitlink missing (CUDA >= 12.0)" =
          !(cudaAtLeast "12.0" -> (libnvjitlink != null && libnvjitlink.lib != null));
          !(cudaAtLeast "12.0" -> (libnvjitlink != null && libnvjitlink != null));
        "libcusparse missing (CUDA >= 12.1)" =
          !(cudaAtLeast "12.1" -> (libcusparse != null && libcusparse.lib != null));
          !(cudaAtLeast "12.1" -> (libcusparse != null && libcusparse != null));
      };
    };

@@ -90,16 +90,16 @@ filterAndCreateOverrides {
      buildInputs =
        prevAttrs.buildInputs
        # Dependency from 12.0 and on
        ++ lib.lists.optionals (cudaAtLeast "12.0") [ libnvjitlink.lib ];
        ++ lib.lists.optionals (cudaAtLeast "12.0") [ libnvjitlink ];

      brokenConditions = prevAttrs.brokenConditions // {
        "libnvjitlink missing (CUDA >= 12.0)" =
          !(cudaAtLeast "12.0" -> (libnvjitlink != null && libnvjitlink.lib != null));
          !(cudaAtLeast "12.0" -> (libnvjitlink != null && libnvjitlink != null));
      };
    };

  # TODO(@connorbaker): cuda_cudart.dev depends on crt/host_config.h, which is from
  # cuda_nvcc.dev. It would be nice to be able to encode that.
  # (getDev cuda_nvcc). It would be nice to be able to encode that.
  cuda_cudart =
    { addDriverRunpath, lib }:
    prevAttrs: {
@@ -248,8 +248,8 @@ filterAndCreateOverrides {
    prevAttrs: {
      buildInputs = prevAttrs.buildInputs ++ [
        freeglut
        libcufft.lib
        libcurand.lib
        libcufft
        libcurand
        libGLU
        libglvnd
        mesa
+0 −1
Original line number Diff line number Diff line
@@ -307,7 +307,6 @@ backendStdenv.mkDerivation (finalAttrs: {
  # Make the CUDA-patched stdenv available
  passthru.stdenv = backendStdenv;


  meta = {
    description = "${redistribRelease.name}. By downloading and using the packages you accept the terms and conditions of the ${finalAttrs.meta.license.shortName}";
    sourceProvenance = [ sourceTypes.binaryNativeCode ];
Loading