Commit 76430227 authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files
parent e5074dff
Loading
Loading
Loading
Loading
+34 −32
Original line number Diff line number Diff line
@@ -11,38 +11,43 @@
  ocl-icd,
  stdenv,
  rocmPackages,

  # build-system
  setuptools,
  wheel,

  # dependencies
  numpy,
  tqdm,
  # nativeCheckInputs

  # tests
  blobfile,
  bottle,
  clang,
  hexdump,
  hypothesis,
  librosa,
  onnx,
  pillow,
  pydot,
  pytest-xdist,
  pytestCheckHook,
  safetensors,
  sentencepiece,
  tiktoken,
  torch,
  tqdm,
  transformers,
}:

buildPythonPackage rec {
  pname = "tinygrad";
  version = "0.9.0";
  version = "0.9.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "tinygrad";
    repo = "tinygrad";
    rev = "refs/tags/v${version}";
    hash = "sha256-opBxciETZruZjHqz/3vO7rogzjvVJKItulIiok/Zs2Y=";
    hash = "sha256-fCKtJhZtqq6yjc6m41uvikzM9GArUlB8Q7jN/Np8+SM=";
  };

  patches = [
@@ -62,29 +67,20 @@ buildPythonPackage rec {
      substituteInPlace tinygrad/runtime/autogen/opencl.py \
        --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
    ''
    # hipGetDevicePropertiesR0600 is a symbol from rocm-6. We are currently at rocm-5.
    # We are not sure that this works. Remove when rocm gets updated to version 6.
    + lib.optionalString rocmSupport ''
      substituteInPlace extra/hip_gpu_driver/hip_ioctl.py \
        --replace-fail "processor = platform.processor()" "processor = ${stdenv.hostPlatform.linuxArch}"
      substituteInPlace tinygrad/runtime/autogen/hip.py \
        --replace-fail "/opt/rocm/lib/libamdhip64.so" "${rocmPackages.clr}/lib/libamdhip64.so" \
        --replace-fail "/opt/rocm/lib/libhiprtc.so" "${rocmPackages.clr}/lib/libhiprtc.so" \
        --replace-fail "hipGetDevicePropertiesR0600" "hipGetDeviceProperties"

      substituteInPlace tinygrad/runtime/autogen/comgr.py \
        --replace-fail "/opt/rocm/lib/libamd_comgr.so" "${rocmPackages.rocm-comgr}/lib/libamd_comgr.so"
    '';

  build-system = [
    setuptools
    wheel
  ];
  build-system = [ setuptools ];

  dependencies =
    [
      numpy
      tqdm
    ]
    ++ lib.optionals stdenv.isDarwin [
      # pyobjc-framework-libdispatch
@@ -94,18 +90,22 @@ buildPythonPackage rec {
  pythonImportsCheck = [ "tinygrad" ];

  nativeCheckInputs = [
    blobfile
    bottle
    clang
    hexdump
    hypothesis
    librosa
    onnx
    pillow
    pydot
    pytest-xdist
    pytestCheckHook
    safetensors
    sentencepiece
    tiktoken
    torch
    tqdm
    transformers
  ];

@@ -115,6 +115,10 @@ buildPythonPackage rec {

  disabledTests =
    [
      # flaky: https://github.com/tinygrad/tinygrad/issues/6542
      # TODO: re-enable when https://github.com/tinygrad/tinygrad/pull/6560 gets merged
      "test_broadcastdot"

      # Require internet access
      "test_benchmark_openpilot_model"
      "test_bn_alone"
@@ -129,12 +133,14 @@ buildPythonPackage rec {
      "test_e2e_big"
      "test_fetch_small"
      "test_huggingface_enet_safetensors"
      "test_index_mnist"
      "test_linear_mnist"
      "test_load_convnext"
      "test_load_enet"
      "test_load_enet_alt"
      "test_load_llama2bfloat"
      "test_load_resnet"
      "test_mnist_val"
      "test_openpilot_model"
      "test_resnet"
      "test_shufflenet"
@@ -148,32 +154,28 @@ buildPythonPackage rec {
    ]
    # Fail on aarch64-linux with AssertionError
    ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
      "test_casts_to"
      "test_casts_to"
      "test_int8_to_uint16_negative"
      "test_casts_to"
      "test_casts_to"
      "test_casts_from"
      "test_casts_to"
      "test_int8"
      "test_casts_to"
      "test_int8_to_uint16_negative"
    ];

  disabledTestPaths =
    [
  disabledTestPaths = [
    # Require internet access
    "test/models/test_mnist.py"
    "test/models/test_real_world.py"
    "test/testextra/test_lr_scheduler.py"
    ]
    ++ lib.optionals (!rocmSupport) [ "extra/hip_gpu_driver/" ];

  meta = with lib; {
    # Files under this directory are not considered as tests by upstream and should be skipped
    "extra/"
  ];

  meta = {
    description = "Simple and powerful neural network framework";
    homepage = "https://github.com/tinygrad/tinygrad";
    changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
    license = licenses.mit;
    maintainers = with maintainers; [ GaetanLepage ];
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    # Requires unpackaged pyobjc-framework-libdispatch and pyobjc-framework-metal
    broken = stdenv.isDarwin;
  };
+39 −22
Original line number Diff line number Diff line
diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index 359083a9..3cd5f7be 100644
index a30c8f53..e2078ff6 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -143,10 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'):
     return ctypes.cast(string, ctypes.POINTER(ctypes.c_char))
@@ -145,7 +145,19 @@ def char_pointer_cast(string, encoding='utf-8'):
 
 
+NAME_TO_PATHS = {
+    "libcuda.so": ["@driverLink@/lib/libcuda.so"],
+    "libnvrtc.so": ["@libnvrtc@"],
+}
+def _try_dlopen(name):
 _libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
+libcuda = None
+try:
+        return ctypes.CDLL(name)
+    libcuda = ctypes.CDLL('libcuda.so')
+except OSError:
+    pass
+    for candidate in NAME_TO_PATHS.get(name, []):
+try:
+            return ctypes.CDLL(candidate)
+    libcuda = ctypes.CDLL('@driverLink@/lib/libcuda.so')
+except OSError:
+    pass
+    raise RuntimeError(f"{name} not found")
 
 _libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+_libraries['libcuda.so'] = _try_dlopen('libcuda.so')
+_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so')
+if libcuda is None:
+    raise RuntimeError(f"`libcuda.so` not found")
+
+_libraries['libcuda.so'] = libcuda
 
 
 cuuint32_t = ctypes.c_uint32
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py
index 6af74187..c5a6c6c4 100644
--- a/tinygrad/runtime/autogen/nvrtc.py
+++ b/tinygrad/runtime/autogen/nvrtc.py
@@ -10,7 +10,18 @@ import ctypes, ctypes.util
 
 
 _libraries = {}
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+libnvrtc = None
+try:
+    libnvrtc = ctypes.CDLL('libnvrtc.so')
+except OSError:
+    pass
+try:
+    libnvrtc = ctypes.CDLL('@libnvrtc@')
+except OSError:
+    pass
+if libnvrtc is None:
+    raise RuntimeError(f"`libnvrtc.so` not found")
+_libraries['libnvrtc.so'] = ctypes.CDLL(libnvrtc)
 def string_cast(char_pointer, encoding='utf-8', errors='strict'):
     value = ctypes.cast(char_pointer, ctypes.c_char_p).value
     if value is not None and encoding is not None: