Unverified Commit 2ccfe3af authored by Pol Dellaiera's avatar Pol Dellaiera Committed by GitHub
Browse files

rabbit: unbreak the package (#390659)

parents 15474993 399baf3e
Loading
Loading
Loading
Loading
+47 −0
Original line number Diff line number Diff line
From 3488de815355051d2e369c7fe48a35dabf695cfc Mon Sep 17 00:00:00 2001
From: Pol Dellaiera <pol.dellaiera@protonmail.com>
Date: Mon, 17 Mar 2025 16:52:25 +0100
Subject: [PATCH] fix file loading

---
 rabbit.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/rabbit.py b/rabbit.py
index a1826d3..697c880 100644
--- a/rabbit.py
+++ b/rabbit.py
@@ -9,6 +9,7 @@ from sklearn.ensemble import GradientBoostingClassifier
 import joblib
 import site
 from tqdm import tqdm
+from importlib.resources import files
 
 import GenerateActivities as gat
 import ExtractEvent as eev
@@ -59,15 +60,13 @@ def get_model():
     '''
 
     model_file = 'bimbas.joblib'
-    for dir in site.getsitepackages():
-        if dir.endswith('site-packages'):
-            target_dir = dir
-        else:
-            target_dir = site.getsitepackages()[0]
-    bot_identification_model = joblib.load(f'{target_dir}/{model_file}')
-    # bot_identification_model = joblib.load(model_file)
-    
-    return(bot_identification_model)
+    try:
+      resource_path = files("rabbit").joinpath(model_file)
+      bot_identification_model = joblib.load(resource_path)
+    except Exception as e:
+      raise RuntimeError(f"Failed to load the model: {e}")
+
+    return bot_identification_model
 
 def compute_confidence(probability_value):
     '''
-- 
2.48.1
+38 −3
Original line number Diff line number Diff line
@@ -2,10 +2,39 @@
  lib,
  python3,
  fetchFromGitHub,
  fetchPypi,
}:

python3.pkgs.buildPythonApplication rec {
let
  python3' =
    let
      packageOverrides = self: super: {
        scikit-learn = super.scikit-learn.overridePythonAttrs (old: {
          version = "1.5.2";

          src = fetchPypi {
            pname = "scikit_learn";
            version = "1.5.2";
            hash = "sha256-tCN+17P90KSIJ5LmjvJUXVuqUKyju0WqffRoE4rY+U0=";
          };

          # There are 2 tests that are failing, disabling the tests for now.
          # - test_csr_polynomial_expansion_index_overflow[csr_array-False-True-2-65535]
          # - test_csr_polynomial_expansion_index_overflow[csr_array-False-True-3-2344]
          doCheck = false;
        });
      };
    in
    python3.override {
      inherit packageOverrides;
      self = python3;
    };
in
python3'.pkgs.buildPythonApplication rec {
  pname = "rabbit";
  # Make sure to check for which version of scikit-learn this project was built
  # Currently version 2.3.1 is made with scikit-learn 1.5.2
  # Upgrading to newer versions of scikit-learn break the project
  version = "2.3.1";
  pyproject = true;

@@ -16,6 +45,12 @@ python3.pkgs.buildPythonApplication rec {
    hash = "sha256-QmP6yfVnlYoNVa4EUtKR9xbCnQW2V6deV0+hN9IGtic=";
  };

  patches = [
    # Fix file loading, to be removed at the next bump.
    # The author has been notified about the issue and currently working on it.
    ./fix-file-loading.patch
  ];

  pythonRelaxDeps = [
    "numpy"
    "scikit-learn"
@@ -24,11 +59,11 @@ python3.pkgs.buildPythonApplication rec {
    "urllib3"
  ];

  build-system = with python3.pkgs; [
  build-system = with python3'.pkgs; [
    setuptools
  ];

  dependencies = with python3.pkgs; [
  dependencies = with python3'.pkgs; [
    joblib
    numpy
    pandas