Unverified Commit 9833fa8d authored by ruro's avatar ruro
Browse files

cudaPackages.tensorrt: remove old extension.nix

parent e9b255a8
Loading
Loading
Loading
Loading
+0 −223
Original line number Diff line number Diff line
final: prev:
let

  inherit (final) callPackage;
  inherit (prev)
    cudatoolkit
    cudaVersion
    lib
    pkgs
    ;

  ### TensorRT

  buildTensorRTPackage = args: callPackage ./generic.nix { } args;

  toUnderscore = str: lib.replaceStrings [ "." ] [ "_" ] str;

  majorMinorPatch = str: lib.concatStringsSep "." (lib.take 3 (lib.splitVersion str));

  tensorRTPackages =
    let
      # Check whether a file is supported for our cuda version
      isSupported = fileData: lib.elem cudaVersion fileData.supportedCudaVersions;
      # Return the first file that is supported. In practice there should only ever be one anyway.
      supportedFile = files: lib.findFirst isSupported null files;

      # Compute versioned attribute name to be used in this package set
      computeName = version: "tensorrt_${toUnderscore version}";

      # Supported versions with versions as keys and file as value
      supportedVersions =
        lib.recursiveUpdate
          {
            tensorrt = {
              enable = false;
              fileVersionCuda = null;
              fileVersionCudnn = null;
              fullVersion = "0.0.0";
              sha256 = null;
              tarball = null;
              supportedCudaVersions = [ ];
            };
          }
          (
            lib.mapAttrs' (version: attrs: lib.nameValuePair (computeName version) attrs) (
              lib.filterAttrs (version: file: file != null) (
                lib.mapAttrs (version: files: supportedFile files) tensorRTVersions
              )
            )
          );

      # Add all supported builds as attributes
      allBuilds = lib.mapAttrs (
        name: file: buildTensorRTPackage (lib.removeAttrs file [ "fileVersionCuda" ])
      ) supportedVersions;

      # Set the default attributes, e.g. tensorrt = tensorrt_8_4;
      defaultName = computeName tensorRTDefaultVersion;
      defaultBuild = lib.optionalAttrs (allBuilds ? ${defaultName}) {
        tensorrt = allBuilds.${computeName tensorRTDefaultVersion};
      };
    in
    {
      inherit buildTensorRTPackage;
    }
    // allBuilds
    // defaultBuild;

  tarballURL =
    {
      fullVersion,
      fileVersionCuda,
      fileVersionCudnn ? null,
    }:
    "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}"
    + lib.optionalString (fileVersionCudnn != null) ".cudnn${fileVersionCudnn}"
    + ".tar.gz";

  tensorRTVersions = {
    "8.6.1" = [
      rec {
        fileVersionCuda = "12.0";
        fullVersion = "8.6.1.6";
        sha256 = "sha256-D4FXpfxTKZQ7M4uJNZE3M1CvqQyoEjnNrddYDNHrolQ=";
        tarball = tarballURL { inherit fileVersionCuda fullVersion; };
        supportedCudaVersions = [
          "12.0"
          "12.1"
        ];
      }
    ];
    "8.5.3" = [
      rec {
        fileVersionCuda = "11.8";
        fileVersionCudnn = "8.6";
        fullVersion = "8.5.3.1";
        sha256 = "sha256-BNeuOYvPTUAfGxI0DVsNrX6Z/FAB28+SE0ptuGu7YDY=";
        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
        supportedCudaVersions = [
          "11.0"
          "11.1"
          "11.2"
          "11.3"
          "11.4"
          "11.5"
          "11.6"
          "11.7"
          "11.8"
        ];
      }
      rec {
        fileVersionCuda = "10.2";
        fileVersionCudnn = "8.6";
        fullVersion = "8.5.3.1";
        sha256 = "sha256-WCt6yfOmFbrjqdYCj6AE2+s2uFpISwk6urP+2I0BnGQ=";
        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
        supportedCudaVersions = [ "10.2" ];
      }
    ];
    "8.5.2" = [
      rec {
        fileVersionCuda = "11.8";
        fileVersionCudnn = "8.6";
        fullVersion = "8.5.2.2";
        sha256 = "sha256-Ov5irNS/JETpEz01FIFNMs9YVmjGHL7lSXmDpgCdgao=";
        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
        supportedCudaVersions = [
          "11.0"
          "11.1"
          "11.2"
          "11.3"
          "11.4"
          "11.5"
          "11.6"
          "11.7"
          "11.8"
        ];
      }
      rec {
        fileVersionCuda = "10.2";
        fileVersionCudnn = "8.6";
        fullVersion = "8.5.2.2";
        sha256 = "sha256-UruwQShYcHLY5d81lKNG7XaoUsZr245c+PUpUN6pC5E=";
        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
        supportedCudaVersions = [ "10.2" ];
      }
    ];
    "8.5.1" = [
      rec {
        fileVersionCuda = "11.8";
        fileVersionCudnn = "8.6";
        fullVersion = "8.5.1.7";
        sha256 = "sha256-Ocx/B3BX0TY3lOj/UcTPIaXb7M8RFrACC6Da4PMGMHY=";
        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
        supportedCudaVersions = [
          "11.0"
          "11.1"
          "11.2"
          "11.3"
          "11.4"
          "11.5"
          "11.6"
          "11.7"
          "11.8"
        ];
      }
      rec {
        fileVersionCuda = "10.2";
        fileVersionCudnn = "8.6";
        fullVersion = "8.5.1.7";
        sha256 = "sha256-CcFGJhw7nFdPnSYYSxcto2MHK3F84nLQlJYjdIw8dPM=";
        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
        supportedCudaVersions = [ "10.2" ];
      }
    ];
    "8.4.0" = [
      rec {
        fileVersionCuda = "11.6";
        fileVersionCudnn = "8.3";
        fullVersion = "8.4.0.6";
        sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
        supportedCudaVersions = [
          "11.0"
          "11.1"
          "11.2"
          "11.3"
          "11.4"
          "11.5"
          "11.6"
        ];
      }
      rec {
        fileVersionCuda = "10.2";
        fileVersionCudnn = "8.3";
        fullVersion = "8.4.0.6";
        sha256 = "sha256-aCzH0ZI6BrJ0v+e5Bnm7b8mNltA7NNuIa8qRKzAQv+I=";
        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
        supportedCudaVersions = [ "10.2" ];
      }
    ];
  };

  # Default attributes
  tensorRTDefaultVersion =
    {
      "10.2" = "8.4.0";
      "11.0" = "8.4.0";
      "11.1" = "8.4.0";
      "11.2" = "8.4.0";
      "11.3" = "8.4.0";
      "11.4" = "8.4.0";
      "11.5" = "8.4.0";
      "11.6" = "8.4.0";
      "11.7" = "8.5.3";
      "11.8" = "8.5.3";
      "12.0" = "8.6.1";
      "12.1" = "8.6.1";
    }
    .${cudaVersion} or "8.4.0";

in
tensorRTPackages