Commit ae560061 authored by SomeoneSerge's avatar SomeoneSerge
Browse files

python3Packages.triton: fix cuda (ptxas, cudart paths)

parent e262792b
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
From 2751c5de5c61c90b56e3e392a41847f4c47258fd Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:16:48 +0000
Subject: [PATCH 1/3] _build: allow extra cc flags

---
 python/triton/runtime/build.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
index d7baeb286..d334dce77 100644
--- a/python/triton/runtime/build.py
+++ b/python/triton/runtime/build.py
@@ -42,9 +42,17 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
     py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
     include_dirs = include_dirs + [srcdir, py_include_dir]
     cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
+
+    # Nixpkgs support branch
+    # Allows passing e.g. extra -Wl,-rpath
+    cc_cmd_extra_flags = "@ccCmdExtraFlags@"
+    if cc_cmd_extra_flags != ("@" + "ccCmdExtraFlags@"): # substituteAll hack
+        import shlex
+        cc_cmd.extend(shlex.split(cc_cmd_extra_flags))
+
     cc_cmd += [f'-l{lib}' for lib in libraries]
     cc_cmd += [f"-L{dir}" for dir in library_dirs]
-    cc_cmd += [f"-I{dir}" for dir in include_dirs]
+    cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
     ret = subprocess.check_call(cc_cmd)
     if ret == 0:
         return so
-- 
2.46.0
+70 −0
Original line number Diff line number Diff line
From 7407cb03eec82768e333909d87b7668b633bfe86 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:28:48 +0000
Subject: [PATCH 2/3] {nvidia,amd}/driver: short-circuit before ldconfig

---
 python/triton/runtime/build.py       | 6 +++---
 third_party/amd/backend/driver.py    | 7 +++++++
 third_party/nvidia/backend/driver.py | 3 +++
 3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
index d334dce77..a64e98da0 100644
--- a/python/triton/runtime/build.py
+++ b/python/triton/runtime/build.py
@@ -42,6 +42,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
     py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
     include_dirs = include_dirs + [srcdir, py_include_dir]
     cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
+    cc_cmd += [f'-l{lib}' for lib in libraries]
+    cc_cmd += [f"-L{dir}" for dir in library_dirs]
+    cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
 
     # Nixpkgs support branch
     # Allows passing e.g. extra -Wl,-rpath
@@ -50,9 +53,6 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
         import shlex
         cc_cmd.extend(shlex.split(cc_cmd_extra_flags))
 
-    cc_cmd += [f'-l{lib}' for lib in libraries]
-    cc_cmd += [f"-L{dir}" for dir in library_dirs]
-    cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
     ret = subprocess.check_call(cc_cmd)
     if ret == 0:
         return so
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
index 0a8cd7bed..aab8805f6 100644
--- a/third_party/amd/backend/driver.py
+++ b/third_party/amd/backend/driver.py
@@ -24,6 +24,13 @@ def _get_path_to_hip_runtime_dylib():
             return env_libhip_path
         raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
 
+    # ...on release/3.1.x:
+    #         return mmapped_path
+    #     raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")
+
+    if os.path.isdir("@libhipDir@"):
+        return ["@libhipDir@"]
+
     paths = []
 
     import site
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 90f71138b..30fbadb2a 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -21,6 +21,9 @@ def libcuda_dirs():
     if env_libcuda_path:
         return [env_libcuda_path]
 
+    if os.path.exists("@libcudaStubsDir@"):
+        return ["@libcudaStubsDir@"]
+
     libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
     # each line looks like the following:
     # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
-- 
2.46.0
+46 −0
Original line number Diff line number Diff line
From 6f92d54e5a544bc34bb07f2808d554a71cc0e4c3 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:30:19 +0000
Subject: [PATCH 3/3] nvidia: cudart a systempath

---
 third_party/nvidia/backend/driver.c  | 2 +-
 third_party/nvidia/backend/driver.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c
index 44524da27..fbdf0d156 100644
--- a/third_party/nvidia/backend/driver.c
+++ b/third_party/nvidia/backend/driver.c
@@ -1,4 +1,4 @@
-#include "cuda.h"
+#include <cuda.h>
 #include <dlfcn.h>
 #include <stdbool.h>
 #define PY_SSIZE_T_CLEAN
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 30fbadb2a..65c0562ed 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -10,7 +10,8 @@ from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
 
 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+import shlex
+include_dir = [*shlex.split("@cudaToolkitIncludeDirs@"), os.path.join(dirname, "include")]
 libdevice_dir = os.path.join(dirname, "lib")
 libraries = ['cuda']
 
@@ -149,7 +150,7 @@ def make_launcher(constants, signature, ids):
     # generate glue code
     params = [i for i in signature.keys() if i not in constants]
     src = f"""
-#include \"cuda.h\"
+#include <cuda.h>
 #include <stdbool.h>
 #include <Python.h>
 #include <dlfcn.h>
-- 
2.46.0
+26 −0
Original line number Diff line number Diff line
From e503e572b6d444cd27f1cdf124aaf553aa3a8665 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Mon, 14 Oct 2024 00:12:05 +0000
Subject: [PATCH 4/4] nvidia: allow static ptxas path

---
 third_party/nvidia/backend/compiler.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 6d7994923..6720e8f97 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -20,6 +20,9 @@ def _path_to_binary(binary: str):
         os.path.join(os.path.dirname(__file__), "bin", binary),
     ]
 
+    import shlex
+    paths.extend(shlex.split("@nixpkgsExtraBinaryPaths@"))
+
     for bin in paths:
         if os.path.exists(bin) and os.path.isfile(bin):
             result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
-- 
2.46.0
+83 −51
Original line number Diff line number Diff line
{
  lib,
  addDriverRunpath,
  buildPythonPackage,
  cmake,
  config,
@@ -15,10 +16,13 @@
  pybind11,
  python,
  runCommand,
  substituteAll,
  setuptools,
  torchWithRocm,
  zlib,
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
  rocmPackages,
}:

buildPythonPackage {
@@ -34,12 +38,36 @@ buildPythonPackage {
    hash = "sha256-L5KqiR+TgSyKjEBlkE0yOU1pemMHFk2PhEmxLdbbxUU=";
  };

  patches = [
  patches =
    [
      ./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
      (substituteAll {
        src = ./0001-_build-allow-extra-cc-flags.patch;
        ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
      })
      (substituteAll (
        {
          src = ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch;
        }
        // lib.optionalAttrs rocmSupport { libhipDir = "${lib.getLib rocmPackages.clr}/lib"; }
        // lib.optionalAttrs cudaSupport {
          libcudaStubsDir = "${lib.getLib cudaPackages.cuda_cudart}/lib/stubs";
          ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
        }
      ))
    ]
    ++ lib.optionals cudaSupport [
      (substituteAll {
        src = ./0003-nvidia-cudart-a-systempath.patch;
        cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
      })
      (substituteAll {
        src = ./0004-nvidia-allow-static-ptxas-path.patch;
        nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
      })
    ];

  postPatch =
    ''
  postPatch = ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    # remove any downloads
    substituteInPlace python/setup.py \
@@ -54,9 +82,9 @@ buildPythonPackage {
      --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
  '';

  build-system = [ setuptools ];

  nativeBuildInputs = [
    setuptools
    # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
    cmake
    ninja

@@ -76,7 +104,7 @@ buildPythonPackage {
    zlib
  ];

  propagatedBuildInputs = [
  dependencies = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
@@ -106,26 +134,40 @@ buildPythonPackage {
    cd python
  '';

  env = {
  env =
    {
      TRITON_BUILD_PROTON = "OFF";
      TRITON_OFFLINE_BUILD = true;
  } // lib.optionalAttrs cudaSupport {
    CC = "${cudaPackages.backendStdenv.cc}/bin/cc";
    CXX = "${cudaPackages.backendStdenv.cc}/bin/c++";
    }
    // lib.optionalAttrs cudaSupport {
      CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
      CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

      # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
      TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
    TRITON_CUOBJDUMP_PATH = cudaPackages.cuda_cuobjdump;
    TRITON_NVDISASM_PATH = cudaPackages.cuda_nvdisasm;
    TRITON_CUDACRT_PATH = cudaPackages.cuda_nvcc;
    TRITON_CUDART_PATH = cudaPackages.cuda_cudart;
      TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
      TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
      TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
      TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
      TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
    };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;
  checkInputs = [ cmake ]; # ctest
  dontUseSetuptoolsCheck = true;

  nativeCheckInputs = [
    cmake
    # Requires torch (circular dependency) and GPU access: pytestCheckHook
  ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)
@@ -134,11 +176,10 @@ buildPythonPackage {
    cd test/unit
  '';

  # Circular dependency on torch
  # pythonImportsCheck = [
  #   "triton"
  #   "triton.language"
  # ];
  pythonImportsCheck = [
    "triton"
    "triton.language"
  ];

  # Ultimately, torch is our test suite:
  passthru.tests = {
@@ -157,15 +198,6 @@ buildPythonPackage {
        '';
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  meta = with lib; {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/triton-lang/triton";