Commit 030575ec authored by Connor Baker's avatar Connor Baker
Browse files

cuda-modules: fix and clean up multiplex builder package selection logic

parent 7109c07f
Loading
Loading
Loading
Loading
+50 −54
Original line number Diff line number Diff line
@@ -30,16 +30,7 @@
  shimsFn ? (throw "shimsFn must be provided"),
}:
let
  inherit (lib)
    attrsets
    lists
    modules
    strings
    ;

  inherit (stdenv) hostPlatform;

  evaluatedModules = modules.evalModules {
  evaluatedModules = lib.modules.evalModules {
    modules = [
      ../modules
      releasesModule
@@ -50,49 +41,55 @@ let
  # - Releases: ../modules/${pname}/releases/releases.nix
  # - Package: ../modules/${pname}/releases/package.nix

  # FIXME: do this at the module system level
  propagatePlatforms = lib.mapAttrs (
    redistArch: packages: map (p: { inherit redistArch; } // p) packages
  );
  # redistArch :: String
  # Value is `"unsupported"` if the platform is not supported.
  redistArch = flags.getRedistArch stdenv.hostPlatform.system;

  # All releases across all platforms
  # Check whether a package supports our CUDA version.
  # satisfiesCudaVersion :: Package -> Bool
  satisfiesCudaVersion =
    package:
    lib.versionAtLeast cudaMajorMinorVersion package.minCudaVersion
    && lib.versionAtLeast package.maxCudaVersion cudaMajorMinorVersion;

  # Releases for our platform and CUDA version.
  # See ../modules/${pname}/releases/releases.nix
  releaseSets = propagatePlatforms evaluatedModules.config.${pname}.releases;
  # allPackages :: List Package
  allPackages = lib.filter satisfiesCudaVersion (
    evaluatedModules.config.${pname}.releases.${redistArch} or [ ]
  );

  # Compute versioned attribute name to be used in this package set
  # Patch version changes should not break the build, so we only use major and minor
  # computeName :: Package -> String
  computeName = { version, ... }: mkVersionedPackageName pname version;

  # Check whether a package supports our CUDA version and platform.
  # isSupported :: Package -> Bool
  isSupported =
    package:
    redistArch == package.redistArch
    && strings.versionAtLeast cudaMajorMinorVersion package.minCudaVersion
    && strings.versionAtLeast package.maxCudaVersion cudaMajorMinorVersion;

  # Get all of the packages for our given platform.
  # redistArch :: String
  # Value is `"unsupported"` if the platform is not supported.
  redistArch = flags.getRedistArch hostPlatform.system;

  preferable =
    p1: p2: (isSupported p2 -> isSupported p1) && (strings.versionOlder p2.version p1.version);
  computeName = package: mkVersionedPackageName pname package.version;

  # All the supported packages we can build for our platform.
  # perSystemReleases :: List Package
  allReleases = lib.pipe releaseSets [
    (lib.attrValues)
    (lists.flatten)
    (lib.groupBy (p: lib.versions.majorMinor p.version))
    (lib.mapAttrs (_: builtins.sort preferable))
    (lib.mapAttrs (_: lib.take 1))
    (lib.attrValues)
    (lib.concatMap lib.trivial.id)
  ];

  newest = builtins.head (builtins.sort preferable allReleases);
  # The newest package for each major-minor version, with newest first.
  # newestPackages :: List Package
  newestPackages =
    let
      newestForEachMajorMinorVersion = lib.foldl' (
        newestPackages: package:
        let
          majorMinorVersion = lib.versions.majorMinor package.version;
          existingPackage = newestPackages.${majorMinorVersion} or null;
        in
        newestPackages
        // {
          ${majorMinorVersion} =
            # Only keep the existing package if it is newer than the one we are considering.
            if existingPackage != null && lib.versionOlder package.version existingPackage.version then
              existingPackage
            else
              package;
        }
      ) { } allPackages;
    in
    # Sort the packages by version so the newest is first.
    # NOTE: builtins.sort requires a strict weak ordering, so we must use versionOlder rather than versionAtLeast.
    lib.sort (p1: p2: lib.versionOlder p2.version p1.version) (
      lib.attrValues newestForEachMajorMinorVersion
    );

  extension =
    final: _:
@@ -102,25 +99,24 @@ let
      buildPackage =
        package:
        let
          shims = final.callPackage shimsFn {
            inherit package;
            inherit (package) redistArch;
          };
          shims = final.callPackage shimsFn { inherit package redistArch; };
          name = computeName package;
          drv = final.callPackage ./manifest.nix {
            inherit pname redistName;
            inherit (shims) redistribRelease featureRelease;
          };
        in
        attrsets.nameValuePair name drv;
        lib.nameValuePair name drv;

      # versionedDerivations :: AttrSet Derivation
      versionedDerivations = builtins.listToAttrs (lists.map buildPackage allReleases);
      versionedDerivations = builtins.listToAttrs (lib.map buildPackage newestPackages);

      defaultDerivation = {
        ${pname} = (buildPackage newest).value;
        ${pname} = (buildPackage (lib.head newestPackages)).value;
      };
    in
    versionedDerivations // defaultDerivation;
    # NOTE: Must condition on the length of newestPackages to avoid non-total function lib.head aborting if
    # newestPackages is empty.
    lib.optionalAttrs (lib.length newestPackages > 0) (versionedDerivations // defaultDerivation);
in
extension