Unverified Commit dfa42e6b authored by Gaétan Lepage's avatar Gaétan Lepage Committed by GitHub
Browse files

python3Packages.objax: drop (#448426)

parents b967613e a0e67364
Loading
Loading
Loading
Loading
+0 −83
Original line number Diff line number Diff line
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  jax,
  jaxlib,
  keras,
  numpy,
  parameterized,
  pillow,
  pytestCheckHook,
  pythonOlder,
  scipy,
  setuptools,
  tensorboard,
  tensorflow,
}:

buildPythonPackage rec {
  pname = "objax";
  version = "1.8.0";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "objax";
    tag = "v${version}";
    hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
  };

  patches = [
    # Issue reported upstream: https://github.com/google/objax/issues/270
    ./replace-deprecated-device_buffers.patch
  ];

  build-system = [ setuptools ];

  # Avoid propagating the dependency on `jaxlib`, see
  # https://github.com/NixOS/nixpkgs/issues/156767
  buildInputs = [ jaxlib ];

  dependencies = [
    jax
    numpy
    parameterized
    pillow
    scipy
    tensorboard
  ];

  pythonImportsCheck = [ "objax" ];

  # This is necessary to ignore the presence of two protobufs version (tensorflow is bringing an
  # older version).
  catchConflicts = false;

  nativeCheckInputs = [
    keras
    pytestCheckHook
    tensorflow
  ];

  enabledTestPaths = [ "tests/*.py" ];

  disabledTests = [
    # Test requires internet access for prefetching some weights
    "test_pretrained_keras_weight_0_ResNet50V2"
    # ModuleNotFoundError: No module named 'tree'
    "TestResNetV2Pretrained"
  ];

  meta = with lib; {
    description = "Machine learning framework that provides an Object Oriented layer for JAX";
    homepage = "https://github.com/google/objax";
    changelog = "https://github.com/google/objax/releases/tag/v${version}";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    # Tests test_syncbn_{0,1,2}d and other tests from tests/parallel.py fail
    broken = true;
  };
}
+0 −14
Original line number Diff line number Diff line
diff --git a/objax/util/util.py b/objax/util/util.py
index c31a356..344cf9a 100644
--- a/objax/util/util.py
+++ b/objax/util/util.py
@@ -117,7 +117,8 @@ def get_local_devices():
     if _local_devices is None:
         x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
         sharded_x = map_to_device(x)
-        _local_devices = [b.device() for b in sharded_x.device_buffers]
+        device_buffers = [buf.data for buf in sharded_x.addressable_shards]
+        _local_devices = [list(b.devices())[0] for b in device_buffers]
     return _local_devices
 
 
+1 −0
Original line number Diff line number Diff line
@@ -493,6 +493,7 @@ mapAliases ({
  ntlm-auth = throw "ntlm-auth has been removed, because it relies on the md4 implementation provided by openssl. Use pyspnego instead.";
  oauth = throw "oauth has been removed as it is unmaintained"; # added 2025-05-16
  oauth2 = throw "oauth2 has been removed as it is unmaintained"; # added 2025-05-16
  objax = throw "objax has been removed because the upstream project was archived."; # Added 2025-10-04
  openai-triton = triton; # added 2024-07-18
  openai-triton-bin = triton-bin; # added 2024-07-18
  openai-triton-cuda = triton-cuda; # added 2024-07-18
+0 −2
Original line number Diff line number Diff line
@@ -10768,8 +10768,6 @@ self: super: with self; {
  obfsproxy = callPackage ../development/python-modules/obfsproxy { };
  objax = callPackage ../development/python-modules/objax { };
  objexplore = callPackage ../development/python-modules/objexplore { };
  objgraph = callPackage ../development/python-modules/objgraph {