Loading pkgs/development/python-modules/objax/default.nixdeleted 100644 → 0 +0 −83 Original line number Diff line number Diff line { lib, buildPythonPackage, fetchFromGitHub, jax, jaxlib, keras, numpy, parameterized, pillow, pytestCheckHook, pythonOlder, scipy, setuptools, tensorboard, tensorflow, }: buildPythonPackage rec { pname = "objax"; version = "1.8.0"; pyproject = true; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = "objax"; tag = "v${version}"; hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk="; }; patches = [ # Issue reported upstream: https://github.com/google/objax/issues/270 ./replace-deprecated-device_buffers.patch ]; build-system = [ setuptools ]; # Avoid propagating the dependency on `jaxlib`, see # https://github.com/NixOS/nixpkgs/issues/156767 buildInputs = [ jaxlib ]; dependencies = [ jax numpy parameterized pillow scipy tensorboard ]; pythonImportsCheck = [ "objax" ]; # This is necessary to ignore the presence of two protobufs version (tensorflow is bringing an # older version). catchConflicts = false; nativeCheckInputs = [ keras pytestCheckHook tensorflow ]; enabledTestPaths = [ "tests/*.py" ]; disabledTests = [ # Test requires internet access for prefetching some weights "test_pretrained_keras_weight_0_ResNet50V2" # ModuleNotFoundError: No module named 'tree' "TestResNetV2Pretrained" ]; meta = with lib; { description = "Machine learning framework that provides an Object Oriented layer for JAX"; homepage = "https://github.com/google/objax"; changelog = "https://github.com/google/objax/releases/tag/v${version}"; license = licenses.asl20; maintainers = with maintainers; [ ndl ]; # Tests test_syncbn_{0,1,2}d and other tests from tests/parallel.py fail broken = true; }; } pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patchdeleted 100644 → 0 +0 −14 Original line number Diff line number Diff line diff --git a/objax/util/util.py b/objax/util/util.py index c31a356..344cf9a 100644 --- a/objax/util/util.py +++ b/objax/util/util.py @@ -117,7 +117,8 @@ def get_local_devices(): if _local_devices is None: x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32) sharded_x = map_to_device(x) - _local_devices = [b.device() for b in sharded_x.device_buffers] + device_buffers = [buf.data for buf in sharded_x.addressable_shards] + _local_devices = [list(b.devices())[0] for b in device_buffers] return _local_devices pkgs/top-level/python-aliases.nix +1 −0 Original line number Diff line number Diff line Loading @@ -493,6 +493,7 @@ mapAliases ({ ntlm-auth = throw "ntlm-auth has been removed, because it relies on the md4 implementation provided by openssl. Use pyspnego instead."; oauth = throw "oauth has been removed as it is unmaintained"; # added 2025-05-16 oauth2 = throw "oauth2 has been removed as it is unmaintained"; # added 2025-05-16 objax = throw "objax has been removed because the upstream project was archived."; # Added 2025-10-04 openai-triton = triton; # added 2024-07-18 openai-triton-bin = triton-bin; # added 2024-07-18 openai-triton-cuda = triton-cuda; # added 2024-07-18 Loading pkgs/top-level/python-packages.nix +0 −2 Original line number Diff line number Diff line Loading @@ -10768,8 +10768,6 @@ self: super: with self; { obfsproxy = callPackage ../development/python-modules/obfsproxy { }; objax = callPackage ../development/python-modules/objax { }; objexplore = callPackage ../development/python-modules/objexplore { }; objgraph = callPackage ../development/python-modules/objgraph { Loading Loading
pkgs/development/python-modules/objax/default.nixdeleted 100644 → 0 +0 −83 Original line number Diff line number Diff line { lib, buildPythonPackage, fetchFromGitHub, jax, jaxlib, keras, numpy, parameterized, pillow, pytestCheckHook, pythonOlder, scipy, setuptools, tensorboard, tensorflow, }: buildPythonPackage rec { pname = "objax"; version = "1.8.0"; pyproject = true; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = "objax"; tag = "v${version}"; hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk="; }; patches = [ # Issue reported upstream: https://github.com/google/objax/issues/270 ./replace-deprecated-device_buffers.patch ]; build-system = [ setuptools ]; # Avoid propagating the dependency on `jaxlib`, see # https://github.com/NixOS/nixpkgs/issues/156767 buildInputs = [ jaxlib ]; dependencies = [ jax numpy parameterized pillow scipy tensorboard ]; pythonImportsCheck = [ "objax" ]; # This is necessary to ignore the presence of two protobufs version (tensorflow is bringing an # older version). catchConflicts = false; nativeCheckInputs = [ keras pytestCheckHook tensorflow ]; enabledTestPaths = [ "tests/*.py" ]; disabledTests = [ # Test requires internet access for prefetching some weights "test_pretrained_keras_weight_0_ResNet50V2" # ModuleNotFoundError: No module named 'tree' "TestResNetV2Pretrained" ]; meta = with lib; { description = "Machine learning framework that provides an Object Oriented layer for JAX"; homepage = "https://github.com/google/objax"; changelog = "https://github.com/google/objax/releases/tag/v${version}"; license = licenses.asl20; maintainers = with maintainers; [ ndl ]; # Tests test_syncbn_{0,1,2}d and other tests from tests/parallel.py fail broken = true; }; }
pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patchdeleted 100644 → 0 +0 −14 Original line number Diff line number Diff line diff --git a/objax/util/util.py b/objax/util/util.py index c31a356..344cf9a 100644 --- a/objax/util/util.py +++ b/objax/util/util.py @@ -117,7 +117,8 @@ def get_local_devices(): if _local_devices is None: x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32) sharded_x = map_to_device(x) - _local_devices = [b.device() for b in sharded_x.device_buffers] + device_buffers = [buf.data for buf in sharded_x.addressable_shards] + _local_devices = [list(b.devices())[0] for b in device_buffers] return _local_devices
pkgs/top-level/python-aliases.nix +1 −0 Original line number Diff line number Diff line Loading @@ -493,6 +493,7 @@ mapAliases ({ ntlm-auth = throw "ntlm-auth has been removed, because it relies on the md4 implementation provided by openssl. Use pyspnego instead."; oauth = throw "oauth has been removed as it is unmaintained"; # added 2025-05-16 oauth2 = throw "oauth2 has been removed as it is unmaintained"; # added 2025-05-16 objax = throw "objax has been removed because the upstream project was archived."; # Added 2025-10-04 openai-triton = triton; # added 2024-07-18 openai-triton-bin = triton-bin; # added 2024-07-18 openai-triton-cuda = triton-cuda; # added 2024-07-18 Loading
pkgs/top-level/python-packages.nix +0 −2 Original line number Diff line number Diff line Loading @@ -10768,8 +10768,6 @@ self: super: with self; { obfsproxy = callPackage ../development/python-modules/obfsproxy { }; objax = callPackage ../development/python-modules/objax { }; objexplore = callPackage ../development/python-modules/objexplore { }; objgraph = callPackage ../development/python-modules/objgraph { Loading