Unverified Commit 61d7cc96 authored by Markus Kowalewski's avatar Markus Kowalewski Committed by GitHub
Browse files

ucc: CUDA fixups (#369956)

parents 1d8a3e9a 7844184e
Loading
Loading
Loading
Loading
+73 −35
Original line number Diff line number Diff line
{
  stdenv,
  lib,
  fetchFromGitHub,
  libtool,
  automake,
inputs@{
  autoconf,
  ucx,
  automake,
  config,
  enableCuda ? config.cudaSupport,
  cudaPackages,
  fetchFromGitHub,
  lib,
  libtool,
  stdenv,
  ucx,
  # Configuration options
  enableAvx ? stdenv.hostPlatform.avxSupport,
  enableCuda ? config.cudaSupport,
  enableSse41 ? stdenv.hostPlatform.sse4_1Support,
  enableSse42 ? stdenv.hostPlatform.sse4_2Support,
}:
let
  inherit (lib.attrsets) getLib;
  inherit (lib.lists) optionals;
  inherit (lib.strings) concatStringsSep;

  inherit (cudaPackages)
    cuda_cccl
    cuda_cudart
    cuda_nvcc
    cuda_nvml_dev
    cudaFlags
    nccl
    ;

  stdenv = throw "Use effectiveStdenv instead";
  effectiveStdenv = if enableCuda then cudaPackages.backendStdenv else inputs.stdenv;
in
effectiveStdenv.mkDerivation (finalAttrs: {
  __structuredAttrs = true;
  # TODO(@connorbaker):
  # When strictDeps is enabled, `cuda_nvcc` is required as the argument to `--with-cuda` in `configureFlags` or else
  # configurePhase fails with `checking for cuda_runtime.h... no`.
  # This is odd, especially given `cuda_runtime.h` is provided by `cuda_cudart.dev`, which is already in `buildInputs`.
  strictDeps = true;

stdenv.mkDerivation rec {
  pname = "ucc";
  version = "1.3.0";

  src = fetchFromGitHub {
    owner = "openucx";
    repo = "ucc";
    rev = "v${version}";
    sha256 = "sha256-xcJLYktkxNK2ewWRgm8zH/dMaIoI+9JexuswXi7MpAU=";
    tag = "v${finalAttrs.version}";
    hash = "sha256-xcJLYktkxNK2ewWRgm8zH/dMaIoI+9JexuswXi7MpAU=";
  };

  outputs = [
@@ -32,44 +56,58 @@ stdenv.mkDerivation rec {

  enableParallelBuilding = true;

  # NOTE: We use --replace-quiet because not all Makefile.am files contain /bin/bash.
  postPatch = ''

    for comp in $(find src/components -name Makefile.am); do
      substituteInPlace $comp \
        --replace "/bin/bash" "${stdenv.shell}"
      substituteInPlace "$comp" \
        --replace-quiet \
          "/bin/bash" \
          "${effectiveStdenv.shell}"
    done
  '';

  nativeBuildInputs = [
    libtool
    automake
    autoconf
  ] ++ lib.optionals enableCuda [ cudaPackages.cuda_nvcc ];
    automake
    libtool
  ] ++ optionals enableCuda [ cuda_nvcc ];

  buildInputs =
    [ ucx ]
    ++ lib.optionals enableCuda [
      cudaPackages.cuda_cccl
      cudaPackages.cuda_cudart
    ++ optionals enableCuda [
      cuda_cccl
      cuda_cudart
      cuda_nvml_dev
      nccl
    ];

  preConfigure =
    ''
  # NOTE: With `__structuredAttrs` enabled, `LDFLAGS` must be set under `env` so it is assured to be a string;
  # otherwise, we might have forgotten to convert it to a string and Nix would make LDFLAGS a shell variable
  # referring to an array!
  env.LDFLAGS = builtins.toString (
    optionals enableCuda [
      # Fake libnvidia-ml.so (the real one is deployed impurely)
      "-L${getLib cuda_nvml_dev}/lib/stubs"
    ]
  );

  preConfigure = ''
    ./autogen.sh
    ''
    + lib.optionalString enableCuda ''
      configureFlagsArray+=( "--with-nvcc-gencode=${builtins.concatStringsSep " " cudaPackages.cudaFlags.gencode}" )
  '';

  configureFlags =
    [ ]
    ++ lib.optional enableSse41 "--with-sse41"
    ++ lib.optional enableSse42 "--with-sse42"
    ++ lib.optional enableAvx "--with-avx"
    ++ lib.optional enableCuda "--with-cuda=${cudaPackages.cuda_cudart}";
    optionals enableSse41 [ "--with-sse41" ]
    ++ optionals enableSse42 [ "--with-sse42" ]
    ++ optionals enableAvx [ "--with-avx" ]
    ++ optionals enableCuda [
      "--with-cuda=${cuda_nvcc}"
      "--with-nvcc-gencode=${concatStringsSep " " cudaFlags.gencode}"
    ];

  postInstall = ''
    find $out/lib/ -name "*.la" -exec rm -f \{} \;
    find "$out/lib/" -name "*.la" -exec rm -f \{} \;

    moveToOutput bin/ucc_info $dev
    moveToOutput bin/ucc_info "$dev"
  '';

  meta = with lib; {
@@ -79,4 +117,4 @@ stdenv.mkDerivation rec {
    maintainers = [ maintainers.markuskowa ];
    platforms = platforms.linux;
  };
}
})