Commit a225c7d3 authored by Gaetan Lepage's avatar Gaetan Lepage Committed by Gaétan Lepage
Browse files

python312Packages.tinygrad: fix libnvrtc patching

parent 641e2640
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -36,6 +36,8 @@
  torch,
  tqdm,
  transformers,

  tinygrad,
}:

buildPythonPackage rec {
@@ -92,7 +94,13 @@ buildPythonPackage rec {
      # pyobjc-framework-metal
    ];

  pythonImportsCheck = [ "tinygrad" ];
  pythonImportsCheck =
    [
      "tinygrad"
    ]
    ++ lib.optionals cudaSupport [
      "tinygrad.runtime.ops_nv"
    ];

  nativeCheckInputs = [
    blobfile
@@ -175,6 +183,10 @@ buildPythonPackage rec {
    "extra/"
  ];

  passthru.tests = {
    withCuda = tinygrad.override { cudaSupport = true; };
  };

  meta = {
    description = "Simple and powerful neural network framework";
    homepage = "https://github.com/tinygrad/tinygrad";
+1 −1
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ index 6af74187..c5a6c6c4 100644
+    pass
+if libnvrtc is None:
+    raise RuntimeError(f"`libnvrtc.so` not found")
+_libraries['libnvrtc.so'] = ctypes.CDLL(libnvrtc)
+_libraries['libnvrtc.so'] = libnvrtc
 def string_cast(char_pointer, encoding='utf-8', errors='strict'):
     value = ctypes.cast(char_pointer, ctypes.c_char_p).value
     if value is not None and encoding is not None: