Unverified Commit 51d92d05 authored by Aleksana's avatar Aleksana Committed by GitHub
Browse files

Merge pull request #292750 from CertainLach/torchaudio-rocm

torchaudio: add rocm support
parents 970f689a 2eedfae4
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -495,7 +495,7 @@ in buildPythonPackage rec {
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    inherit cudaSupport cudaPackages;
    inherit cudaSupport cudaPackages rocmSupport rocmPackages;
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
+51 −1
Original line number Diff line number Diff line
@@ -9,10 +9,46 @@
, pybind11
, sox
, torch

, cudaSupport ? torch.cudaSupport
, cudaPackages
, rocmSupport ? torch.rocmSupport
, rocmPackages

, gpuTargets ? []
}:

let
  # TODO: Reuse one defined in torch?
  # Some of those dependencies are probbly not required,
  # but it breaks when the store path is different between torch and torchaudio
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core clr rccl miopen miopengemm rocrand rocblas
      rocsparse hipsparse rocthrust rocprim hipcub roctracer
      rocfft rocsolver hipfft hipsolver hipblas
      rocminfo rocm-thunk rocm-comgr rocm-device-libs
      rocm-runtime clr.icd hipify
    ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };
  # Only used for ROCm
  gpuTargetString = lib.strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
    # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if rocmSupport then
      rocmPackages.clr.gpuTargets
    else
      throw "No GPU targets specified"
  );
in
buildPythonPackage rec {
  pname = "torchaudio";
  version = "2.3.0";
@@ -33,6 +69,11 @@ buildPythonPackage rec {
    substituteInPlace setup.py \
      --replace 'print(" --- Initializing submodules")' "return" \
      --replace "_fetch_archives(_parse_sources())" "pass"
  ''
  + lib.optionalString rocmSupport ''
    # There is no .info/version-dev, only .info/version
    substituteInPlace cmake/LoadHIP.cmake \
      --replace "/.info/version-dev" "/.info/version"
  '';

  env = {
@@ -55,7 +96,11 @@ buildPythonPackage rec {
    ninja
  ] ++ lib.optionals cudaSupport [
    cudaPackages.cuda_nvcc
  ];
  ] ++ lib.optionals rocmSupport (with rocmPackages; [
    clr
    rocblas
    hipblas
  ]);

  buildInputs = [
    ffmpeg-full
@@ -73,6 +118,11 @@ buildPythonPackage rec {
  BUILD_RNNT=0;
  BUILD_CTC_DECODER=0;

  preConfigure = lib.optionalString rocmSupport ''
    export ROCM_PATH=${rocmtoolkit_joined}
    export PYTORCH_ROCM_ARCH="${gpuTargetString}"
  '';

  dontUseCmakeConfigure = true;

  doCheck = false; # requires sox backend