Commit e858bc6f authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

refactor, add comments

parent 751ae196
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ with open("../pyproject.toml", "rb") as toml_file:

sys.path.insert(0, os.path.abspath("../src"))

exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "__main__.py"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "__main__.py", "_internal"]
extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_rtd_theme"]


+86 −0
Original line number Diff line number Diff line
"""Pydantic utils."""

import logging
import re
from typing import Any, Tuple

from deepdiff import DeepDiff
from pydantic import BaseModel, ValidationError
from pydantic.fields import FieldInfo

logger = logging.getLogger(__name__)


def _format_field_name_from_tuple(input_tuple: Tuple) -> str:
    res = ""
    for item in input_tuple:
        if isinstance(item, int):
            formatted = f"[{item}]"
        elif isinstance(item, str):
            formatted = f".{item}" if res else item
        else:
            formatted = str(item)
        res += formatted
    return res


def get_errored_fields_from_validation_error(e: ValidationError) -> list[str]:
    """
    Get a list of Pydantic model fields from a Pydantic ValidationError.

    Args:
        e (ValidationError): The Pydantic ValidationError containing the validation errors.

    Returns
    -------
        list[str]: A list of nested field names (putting indices in brackets, and using dots for nested fields
        e.g. nested.ranges[0]) that failed validation.
    """
    res = []
    for error in e.errors():
        res.append(_format_field_name_from_tuple(error["loc"]))
    return res


def _remove_brackets_suffix(s: str) -> str:
    return re.sub(r"\[\d+\]$", "", s)


def get_updated_fields(old: BaseModel, new: BaseModel) -> list[str]:
    """
    Get a list of Pydantic model fields that were updated.

    Uses DeepDiff package to compare new and old models and
    then processed the results to build lists in a format we want.
    """
    diff = DeepDiff(old, new)
    updates: list[str] = []
    if "values_changed" in diff:
        # DeepDiff adds .root to the root object, we don't need that
        updates = [k.removeprefix("root.") for k in diff["values_changed"].keys()]
    for item in ["iterable_item_added", "iterable_item_removed"]:
        # for added/removed items DeepDiff adds its index, we don't need that as well
        if item in diff:
            updates += [_remove_brackets_suffix(k.removeprefix("root.")) for k in diff[item].keys()]

    return updates


def get_nested_pydantic_field(model: BaseModel, field_path: str) -> FieldInfo:
    """Retrieve a nested field's metadata from a Pydantic model using a dot-separated path."""
    fields = field_path.split(".")
    current_model: Any = model

    for field in fields:
        if "[" in field:
            base = field.split("[")[0]
            current_model = getattr(current_model, base)
            for _ in range(field.count("[")):
                current_model = current_model[0]
            continue
        if issubclass(type(getattr(current_model, field)), BaseModel):
            current_model = getattr(current_model, field)
        else:
            return current_model.model_fields[field]

    raise Exception(f"Cannot find field {field_path}")
+1 −4
Original line number Diff line number Diff line
"""Common utilities."""
"""Internal common functions tp be used within the package."""

import re
from typing import Any
@@ -8,9 +8,6 @@ def normalize_field_name(field: str) -> str:
    return field.replace(".", "_").replace("[", "_").replace("]", "")


#            .replace("_length", "n"))


def list_has_objects(v: list) -> bool:
    for elem in v:
        if isinstance(elem, list):
+6 −6
Original line number Diff line number Diff line
"""Module for bindings map ant it's utils."""
"""Module for storing and accessing MVVM bindings.

This module contains a global dictionary, `bindings_map`, which holds the MVVM (Model-View-ViewModel) bindings.
Each binding is stored with a name as the key, allowing easy lookup and access to the associated binding from GUI
by using a field name (first part of which would be the binding key).
"""

from typing import Any, Dict

bindings_map: Dict[str, Any] = {}


def update_bindings_map(key: str | None, value: Any) -> None:
    if key:
        bindings_map[key] = value
+1 −1
Original line number Diff line number Diff line
@@ -5,8 +5,8 @@ from typing import Any

import param

from .._internal.utils import rgetattr, rsetattr
from ..interface import BindingInterface
from ..utils import rgetattr, rsetattr


def is_parameterized(var: Any) -> bool:
Loading