Commit cecd298c authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

initial changes for pydantic

parent 7926a3c0
Loading
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
import importlib.metadata

__version__ = importlib.metadata.version(__package__)

from typing import Any

bindings_map: Any = {}
+17 −1
Original line number Diff line number Diff line
"""Binding module for PyQt6 framework."""

import uuid
from typing import Any, Optional

from pydantic import BaseModel

from mvvm_lib import bindings_map

from ..utils import rsetattr

try:
@@ -31,12 +36,23 @@ class Communicator(QObject):
        callback_after_update: Any = None,
    ) -> None:
        super().__init__()
        self.id = str(uuid.uuid4())
        bindings_map[self.id] = self
        self.viewmodel_linked_object = viewmodel_linked_object
        self.linked_object_attributes = linked_object_attributes
        self.callback_after_update = callback_after_update

    def _update_viewmodel_callback(self, key: Optional[str] = None, value: Any = None) -> None:
        if isinstance(self.viewmodel_linked_object, dict):
        if isinstance(self.viewmodel_linked_object, BaseModel):
            model = self.viewmodel_linked_object.copy(deep=True)
            rsetattr(model, key or "", value)
            try:
                new_model = model.__class__(**model.model_dump(warnings=False))
                for f, v in new_model:
                    setattr(self.viewmodel_linked_object, f, v)
            except Exception:
                pass
        elif isinstance(self.viewmodel_linked_object, dict):
            self.viewmodel_linked_object.update({key: value})
        elif is_callable(self.viewmodel_linked_object):
            self.viewmodel_linked_object(value)
+44 −6
Original line number Diff line number Diff line
@@ -2,11 +2,16 @@

import asyncio
import inspect
import json
import uuid
from typing import Any, Callable, Optional, Union, cast

from pydantic import BaseModel
from trame_server.state import State
from typing_extensions import override

from mvvm_lib import bindings_map

from ..interface import (
    BindingInterface,
    CallbackAfterUpdateType,
@@ -30,6 +35,18 @@ def is_callable(var: Any) -> bool:
    return inspect.isfunction(var) or inspect.ismethod(var)


def _get_nested_attributes(obj: Any, prefix: str = "") -> Any:
    attributes = []
    for k, v in obj.__dict__.items():
        if not k.startswith("_"):  # Ignore private attributes
            full_key = f"{prefix}.{k}" if prefix else k
            if hasattr(v, "__dict__"):  # Check if the value is another object with attributes
                attributes.extend(_get_nested_attributes(v, prefix=full_key))
            else:
                attributes.append(full_key)
    return attributes


class TrameCommunicator(Communicator):
    """Communicator implementation for Trame."""

@@ -41,6 +58,8 @@ class TrameCommunicator(Communicator):
        callback_after_update: CallbackAfterUpdateType = None,
    ) -> None:
        self.state = state
        self.id = str(uuid.uuid4())
        bindings_map[self.id] = self
        self.viewmodel_linked_object = viewmodel_linked_object
        self._set_linked_object_attributes(linked_object_attributes, viewmodel_linked_object)
        self.viewmodel_callback_after_update = callback_after_update
@@ -53,12 +72,11 @@ class TrameCommunicator(Communicator):
        if (
            viewmodel_linked_object
            and not isinstance(viewmodel_linked_object, dict)
            and not isinstance(viewmodel_linked_object, BaseModel)
            and not is_callable(viewmodel_linked_object)
        ):
            if not linked_object_attributes:
                self.linked_object_attributes = [
                    k for k in viewmodel_linked_object.__dict__.keys() if not k.startswith("_")
                ]
                self.linked_object_attributes = _get_nested_attributes(viewmodel_linked_object)
            else:
                self.linked_object_attributes = linked_object_attributes

@@ -85,7 +103,16 @@ class CallBackConnection:
        self.linked_object_attributes = communicator.linked_object_attributes

    def _update_viewmodel_callback(self, value: Any, key: Optional[str] = None) -> None:
        if isinstance(self.viewmodel_linked_object, dict):
        if isinstance(self.viewmodel_linked_object, BaseModel):
            model = self.viewmodel_linked_object.copy(deep=True)
            rsetattr(model, key or "", value)
            try:
                new_model = model.__class__(**model.model_dump(warnings=False))
                for f, v in new_model:
                    setattr(self.viewmodel_linked_object, f, v)
            except Exception:
                pass
        elif isinstance(self.viewmodel_linked_object, dict):
            if not key:
                self.viewmodel_linked_object.update(value)
            else:
@@ -165,16 +192,27 @@ class StateConnection:

                @self.state.change(state_variable_name)
                def update_viewmodel_callback(**kwargs: dict) -> None:
                    if isinstance(self.viewmodel_linked_object, dict):
                    success = True
                    if isinstance(self.viewmodel_linked_object, BaseModel):
                        json_str = json.dumps(kwargs[state_variable_name])
                        try:
                            model = self.viewmodel_linked_object.model_validate_json(json_str)
                            for field, value in model:
                                setattr(self.viewmodel_linked_object, field, value)
                        except Exception:
                            success = False
                    elif isinstance(self.viewmodel_linked_object, dict):
                        self.viewmodel_linked_object.update(kwargs[state_variable_name])
                    elif is_callable(self.viewmodel_linked_object):
                        cast(Callable, self.viewmodel_linked_object)(kwargs[state_variable_name])
                    else:
                        raise Exception("cannot update", self.viewmodel_linked_object)
                    if self.viewmodel_callback_after_update:
                    if self.viewmodel_callback_after_update and success:
                        self.viewmodel_callback_after_update(state_variable_name)

    def update_in_view(self, value: Any) -> None:
        if hasattr(value, "model_dump"):
            value = value.model_dump()
        if self.linked_object_attributes:
            for attribute_name in self.linked_object_attributes:
                name_in_state = self._get_name_in_state(attribute_name)
+26 −12
Original line number Diff line number Diff line
"""Common utilities."""

import functools
import re
from typing import Any


def rsetattr(obj: Any, attr: str, val: Any) -> None:
    """Set nested attribute of an object."""
    pre, _, post = attr.rpartition(".")
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)

def rgetattr(obj: Any, attr: str) -> Any:
    fields = attr.split(".")
    for field in fields:
        base = field.split("[")[0]
        obj = getattr(obj, base)
        indices = []
        indices = re.findall(r"\[(\d+)\]", field)
        indices = [int(num) for num in indices]
        for index in indices:
            obj = obj[index]
    return obj

def rgetattr(obj: Any, attr: str, *args: Any) -> Any:
    """Get nested attribute of an object."""

    def _getattr(obj: Any, attr: str) -> Any:
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split("."))
def rsetattr(obj: Any, attr: str, val: Any) -> Any:
    pre, _, post = attr.rpartition(".")
    if pre:
        obj = rgetattr(obj, pre)
    if "[" in post:
        indices = re.findall(r"\[(\d+)\]", post)
        indices = [int(num) for num in indices]
        for i, index in enumerate(indices):
            if i == len(indices) - 1:
                obj[index] = val
            else:
                obj = obj[index]
    else:
        setattr(obj, post, val)