Commit f2ab1419 authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

refactor, add pydantic utils

parent cecd298c
Loading
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
import importlib.metadata

__version__ = importlib.metadata.version(__package__)
from .bindings_map import bindings_map

from typing import Any
__version__ = importlib.metadata.version(__package__)

bindings_map: Any = {}
__all__ = ["bindings_map"]
+16 −0
Original line number Diff line number Diff line
"""Module for bindings map ant it's utils."""

from typing import Any, Dict

from pydantic import BaseModel

from mvvm_lib.utils import rget_list_of_fields

bindings_map: Dict[str, Any] = {}


def update_bindings_map(source: Any, value: Any) -> None:
    if isinstance(source, BaseModel):
        fields = rget_list_of_fields(source)
        for field in fields:
            bindings_map[field] = value
+79 −0
Original line number Diff line number Diff line
"""Pydantic utils."""

import re
from typing import Any

from pydantic import BaseModel, ValidationError
from pydantic.fields import FieldInfo

from mvvm_lib import bindings_map


def get_nested_pydantic_field(model: BaseModel, field_path: str) -> FieldInfo:
    """
    Retrieves the Pydantic ModelField object for a nested field in a Pydantic model using a dot-separated field path.

    :param model: Pydantic model instance
    :param field_path: Dot-separated path to the field (e.g., "config.nested.nested2")
    :return: The Pydantic ModelField instance
    """
    fields = field_path.split(".")
    current_model: Any = model

    for field in fields:
        if "[" in field:
            base = field.split("[")[0]
            current_model = getattr(current_model, base)
            for _ in range(field.count("[")):
                current_model = current_model[0]
            continue
        if issubclass(type(getattr(current_model, field)), BaseModel):
            current_model = getattr(current_model, field)
        else:
            return current_model.__fields__[field]

    raise Exception(f"Cannot find field {field_path}")


def get_field_info(field_name: str) -> FieldInfo:
    binding = bindings_map.get(field_name, None)
    if not binding:
        raise Exception(f"Cannot find field {field_name}")
    return get_nested_pydantic_field(binding.viewmodel_linked_object, field_name)


def validate_pydantic_parameter(name: str, value: Any, index: int) -> str | None:
    if name not in bindings_map:
        print(f"cannot find {name} in bindings_map")  # no error, just do not validate for now
        return None
    binding = bindings_map[name]
    current_model = binding.viewmodel_linked_object
    # get list of nested fields (if any) and get the corresponding model
    fields = name.split(".")
    for field in fields[:-1]:
        if "[index]" in field:
            field = field.removesuffix("[index]")
            current_model = getattr(current_model, field)[index]
        elif "[" in field:
            base = field.split("[")[0]
            indices = re.findall(r"\[(\d+)\]", field)
            indices = [int(num) for num in indices]
            for i in indices:
                current_model = getattr(current_model, base)[i]
        else:
            current_model = getattr(current_model, field)
    final_field = fields[-1]
    # copy model so we do not modify the current one
    model = current_model.copy(deep=True)
    # force set field value
    setattr(model, final_field, value)
    # validate changed model
    try:
        model.__class__(**model.model_dump(warnings=False))
    except ValidationError as e:
        for error in e.errors():
            if (len(error["loc"]) > 0 and final_field in str(error["loc"][0])) or (
                len(error["loc"]) == 0 and e.title == current_model.__class__.__name__
            ):
                return error["msg"]
    return None
+4 −5
Original line number Diff line number Diff line
@@ -5,8 +5,7 @@ from typing import Any, Optional

from pydantic import BaseModel

from mvvm_lib import bindings_map

from ..bindings_map import update_bindings_map
from ..utils import rsetattr

try:
@@ -37,7 +36,6 @@ class Communicator(QObject):
    ) -> None:
        super().__init__()
        self.id = str(uuid.uuid4())
        bindings_map[self.id] = self
        self.viewmodel_linked_object = viewmodel_linked_object
        self.linked_object_attributes = linked_object_attributes
        self.callback_after_update = callback_after_update
@@ -74,9 +72,10 @@ class Communicator(QObject):
        else:
            return None

    def update_in_view(self, *args: Any, **kwargs: Any) -> Any:
    def update_in_view(self, value: Any, **kwargs: Any) -> Any:
        """Update a View (GUI) when called by a ViewModel."""
        return self.signal.emit(*args, **kwargs)
        update_bindings_map(value, self)
        return self.signal.emit(value, **kwargs)


class PyQtBinding(BindingInterface):
+8 −22
Original line number Diff line number Diff line
@@ -3,15 +3,13 @@
import asyncio
import inspect
import json
import uuid
from typing import Any, Callable, Optional, Union, cast

from pydantic import BaseModel
from trame_server.state import State
from typing_extensions import override

from mvvm_lib import bindings_map

from ..bindings_map import update_bindings_map
from ..interface import (
    BindingInterface,
    CallbackAfterUpdateType,
@@ -20,7 +18,7 @@ from ..interface import (
    LinkedObjectAttributesType,
    LinkedObjectType,
)
from ..utils import rgetattr, rsetattr
from ..utils import normalize_field_name, rget_list_of_fields, rgetattr, rsetattr


def is_async() -> bool:
@@ -35,18 +33,6 @@ def is_callable(var: Any) -> bool:
    return inspect.isfunction(var) or inspect.ismethod(var)


def _get_nested_attributes(obj: Any, prefix: str = "") -> Any:
    attributes = []
    for k, v in obj.__dict__.items():
        if not k.startswith("_"):  # Ignore private attributes
            full_key = f"{prefix}.{k}" if prefix else k
            if hasattr(v, "__dict__"):  # Check if the value is another object with attributes
                attributes.extend(_get_nested_attributes(v, prefix=full_key))
            else:
                attributes.append(full_key)
    return attributes


class TrameCommunicator(Communicator):
    """Communicator implementation for Trame."""

@@ -58,8 +44,8 @@ class TrameCommunicator(Communicator):
        callback_after_update: CallbackAfterUpdateType = None,
    ) -> None:
        self.state = state
        self.id = str(uuid.uuid4())
        bindings_map[self.id] = self
        update_bindings_map(viewmodel_linked_object, self)

        self.viewmodel_linked_object = viewmodel_linked_object
        self._set_linked_object_attributes(linked_object_attributes, viewmodel_linked_object)
        self.viewmodel_callback_after_update = callback_after_update
@@ -76,7 +62,7 @@ class TrameCommunicator(Communicator):
            and not is_callable(viewmodel_linked_object)
        ):
            if not linked_object_attributes:
                self.linked_object_attributes = _get_nested_attributes(viewmodel_linked_object)
                self.linked_object_attributes = rget_list_of_fields(viewmodel_linked_object)
            else:
                self.linked_object_attributes = linked_object_attributes

@@ -89,6 +75,7 @@ class TrameCommunicator(Communicator):
        return self.connection.get_callback()

    def update_in_view(self, value: Any) -> None:
        update_bindings_map(value, self)
        self.connection.update_in_view(value)


@@ -166,10 +153,9 @@ class StateConnection:
            self.state.dirty(name_in_state)

    def _get_name_in_state(self, attribute_name: str) -> str:
        name_in_state = normalize_field_name(attribute_name)
        if self.state_variable_name:
            name_in_state = f"{self.state_variable_name}_{attribute_name.replace('.', '_')}"
        else:
            name_in_state = attribute_name.replace(".", "_")
            name_in_state = f"{self.state_variable_name}_{name_in_state}"
        return name_in_state

    def _connect(self) -> None:
Loading