Unverified Commit ef7ef716 authored by Martin Weinelt's avatar Martin Weinelt Committed by GitHub
Browse files

onnxruntime: add ROCm support (#454399)

parents 68990df0 27767283
Loading
Loading
Loading
Loading
+58 −4
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@
  howard-hinnant-date,
  libpng,
  nlohmann_json,
  perl,
  pkg-config,
  python3Packages,
  re2,
@@ -24,8 +25,10 @@
  pythonSupport ? true,
  cudaSupport ? config.cudaSupport,
  ncclSupport ? cudaSupport && cudaPackages.nccl.meta.available,
  rocmSupport ? config.rocmSupport,
  withFullProtobuf ? false,
  cudaPackages ? { },
  rocmPackages,
}@inputs:

let
@@ -121,6 +124,9 @@ effectiveStdenv.mkDerivation rec {
  ]
  ++ lib.optionals isCudaJetson [
    cudaPackages.autoAddCudaCompatRunpath
  ]
  ++ lib.optionals rocmSupport [
    perl # for tools/ci_build/hipify-perl
  ];

  buildInputs = [
@@ -156,6 +162,22 @@ effectiveStdenv.mkDerivation rec {
    ]
    ++ lib.optionals ncclSupport [ nccl ]
  )
  ++ lib.optionals rocmSupport [
    rocmPackages.clr
    rocmPackages.hipblas
    rocmPackages.hipcub
    rocmPackages.hipfft
    rocmPackages.hiprand
    rocmPackages.hipsparse
    rocmPackages.rocblas
    rocmPackages.rocprim
    rocmPackages.rocrand
    rocmPackages.rocthrust
    rocmPackages.miopen
    rocmPackages.rccl
    rocmPackages.rocm-smi
    rocmPackages.roctracer
  ]
  ++ lib.optionals effectiveStdenv.hostPlatform.isDarwin [
    (darwinMinVersionHook "13.3")
  ];
@@ -203,6 +225,7 @@ effectiveStdenv.mkDerivation rec {
    (lib.cmakeBool "onnxruntime_USE_FULL_PROTOBUF" withFullProtobuf)
    (lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport)
    (lib.cmakeBool "onnxruntime_USE_NCCL" (cudaSupport && ncclSupport))
    (lib.cmakeBool "onnxruntime_USE_ROCM" rocmSupport)
    (lib.cmakeBool "onnxruntime_ENABLE_LTO" (!cudaSupport || cudaPackages.cudaOlder "12.8"))
  ]
  ++ lib.optionals pythonSupport [
@@ -213,15 +236,43 @@ effectiveStdenv.mkDerivation rec {
    (lib.cmakeFeature "onnxruntime_CUDNN_HOME" "${cudaPackages.cudnn}")
    (lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString)
    (lib.cmakeFeature "onnxruntime_NVCC_THREADS" "1")
  ]
  ++ lib.optionals rocmSupport [
    # Werror combines with rocprim header issues to cause errors (warp size const deprecation)
    "--compile-no-warning-as-error"
    (lib.cmakeFeature "CMAKE_HIP_ARCHITECTURES" (
      builtins.concatStringsSep ";" rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets
    ))
    (lib.cmakeFeature "onnxruntime_ROCM_HOME" "${rocmPackages.clr}")
    # Incompatible with packaged version, far too slow to build vendored version
    (lib.cmakeBool "onnxruntime_USE_COMPOSABLE_KERNEL" false)
    (lib.cmakeBool "onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE" false)
  ];

  env = lib.optionalAttrs effectiveStdenv.cc.isClang {
  env =
    lib.optionalAttrs effectiveStdenv.cc.isClang {
      NIX_CFLAGS_COMPILE = "-Wno-error";
    }
    // lib.optionalAttrs rocmSupport {
      MIOPEN_PATH = rocmPackages.miopen;
      # HIP steps fail to find ROCm libs when not in HIPFLAGS, causing
      # fatal error: 'rocrand/rocrand.h' file not found
      HIPFLAGS = lib.concatMapStringsSep " " (pkg: "-I${lib.getInclude pkg}/include") [
        rocmPackages.hipblas
        rocmPackages.hipcub
        rocmPackages.hiprand
        rocmPackages.hipsparse
        rocmPackages.rocblas
        rocmPackages.rocprim
        rocmPackages.rocrand
        rocmPackages.rocthrust
      ];
    };

  doCheck =
    !(
      cudaSupport
      || rocmSupport
      || builtins.elem effectiveStdenv.buildPlatform.system [
        # aarch64-linux fails cpuinfo test, because /sys/devices/system/cpu/ does not exist in the sandbox
        "aarch64-linux"
@@ -231,7 +282,7 @@ effectiveStdenv.mkDerivation rec {
      ]
    );

  requiredSystemFeatures = lib.optionals cudaSupport [ "big-parallel" ];
  requiredSystemFeatures = lib.optionals (cudaSupport || rocmSupport) [ "big-parallel" ];

  hardeningEnable = lib.optionals (effectiveStdenv.hostPlatform.system == "loongarch64-linux") [
    "nostrictaliasing"
@@ -247,6 +298,9 @@ effectiveStdenv.mkDerivation rec {
      "GetRuntimePath() const { return PathString(); }" \
      "GetRuntimePath() const { return PathString(\"$out/lib/\"); }"
  ''
  + lib.optionalString rocmSupport ''
    patchShebangs tools/ci_build/hipify-perl
  ''
  + lib.optionalString (effectiveStdenv.hostPlatform.system == "aarch64-linux") ''
    # https://github.com/NixOS/nixpkgs/pull/226734#issuecomment-1663028691
    rm -v onnxruntime/test/optimizer/nhwc_transformer_test.cc