Commit 7844184e authored by Connor Baker's avatar Connor Baker
Browse files

ucc: CUDA fixups

parent 98cc20e3
Loading
Loading
Loading
Loading
+27 −8
Original line number Diff line number Diff line
{
inputs@{
  autoconf,
  automake,
  config,
@@ -15,6 +15,7 @@
  enableSse42 ? stdenv.hostPlatform.sse4_2Support,
}:
let
  inherit (lib.attrsets) getLib;
  inherit (lib.lists) optionals;
  inherit (lib.strings) concatStringsSep;

@@ -22,15 +23,21 @@ let
    cuda_cccl
    cuda_cudart
    cuda_nvcc
    cuda_nvml_dev
    cudaFlags
    nccl
    ;

  stdenv = throw "Use effectiveStdenv instead";
  effectiveStdenv = if enableCuda then cudaPackages.backendStdenv else inputs.stdenv;
in
stdenv.mkDerivation (finalAttrs: {
effectiveStdenv.mkDerivation (finalAttrs: {
  __structuredAttrs = true;
  # TODO: When strictDeps is enabled, the CUDA build fails during configurePhase because it can't find all the CUDA
  # dependencies. As such, we hold off on enabling strictDeps until CUDA compilation works.
  # https://github.com/openucx/ucc/blob/0c0fc21559835044ab107199e334f7157d6a0d3d/config/m4/cuda.m4
  strictDeps = false;
  # TODO(@connorbaker):
  # When strictDeps is enabled, `cuda_nvcc` is required as the argument to `--with-cuda` in `configureFlags` or else
  # configurePhase fails with `checking for cuda_runtime.h... no`.
  # This is odd, especially given `cuda_runtime.h` is provided by `cuda_cudart.dev`, which is already in `buildInputs`.
  strictDeps = true;

  pname = "ucc";
  version = "1.3.0";
@@ -55,7 +62,7 @@ stdenv.mkDerivation (finalAttrs: {
      substituteInPlace "$comp" \
        --replace-quiet \
          "/bin/bash" \
          "${stdenv.shell}"
          "${effectiveStdenv.shell}"
    done
  '';

@@ -70,8 +77,20 @@ stdenv.mkDerivation (finalAttrs: {
    ++ optionals enableCuda [
      cuda_cccl
      cuda_cudart
      cuda_nvml_dev
      nccl
    ];

  # NOTE: With `__structuredAttrs` enabled, `LDFLAGS` must be set under `env` so it is assured to be a string;
  # otherwise, we might have forgotten to convert it to a string and Nix would make LDFLAGS a shell variable
  # referring to an array!
  env.LDFLAGS = builtins.toString (
    optionals enableCuda [
      # Fake libnvidia-ml.so (the real one is deployed impurely)
      "-L${getLib cuda_nvml_dev}/lib/stubs"
    ]
  );

  preConfigure = ''
    ./autogen.sh
  '';
@@ -81,7 +100,7 @@ stdenv.mkDerivation (finalAttrs: {
    ++ optionals enableSse42 [ "--with-sse42" ]
    ++ optionals enableAvx [ "--with-avx" ]
    ++ optionals enableCuda [
      "--with-cuda=${cuda_cudart}"
      "--with-cuda=${cuda_nvcc}"
      "--with-nvcc-gencode=${concatStringsSep " " cudaFlags.gencode}"
    ];