Skip to content
Snippets Groups Projects
Commit cecd298c authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

initial changes for pydantic

parent 7926a3c0
No related branches found
No related tags found
1 merge request!6Update library to work with Pydantic models
Pipeline #640321 waiting for manual action
import importlib.metadata
__version__ = importlib.metadata.version(__package__)
from typing import Any
bindings_map: Any = {}
"""Binding module for PyQt6 framework."""
import uuid
from typing import Any, Optional
from pydantic import BaseModel
from mvvm_lib import bindings_map
from ..utils import rsetattr
try:
......@@ -31,12 +36,23 @@ class Communicator(QObject):
callback_after_update: Any = None,
) -> None:
super().__init__()
self.id = str(uuid.uuid4())
bindings_map[self.id] = self
self.viewmodel_linked_object = viewmodel_linked_object
self.linked_object_attributes = linked_object_attributes
self.callback_after_update = callback_after_update
def _update_viewmodel_callback(self, key: Optional[str] = None, value: Any = None) -> None:
if isinstance(self.viewmodel_linked_object, dict):
if isinstance(self.viewmodel_linked_object, BaseModel):
model = self.viewmodel_linked_object.copy(deep=True)
rsetattr(model, key or "", value)
try:
new_model = model.__class__(**model.model_dump(warnings=False))
for f, v in new_model:
setattr(self.viewmodel_linked_object, f, v)
except Exception:
pass
elif isinstance(self.viewmodel_linked_object, dict):
self.viewmodel_linked_object.update({key: value})
elif is_callable(self.viewmodel_linked_object):
self.viewmodel_linked_object(value)
......
......@@ -2,11 +2,16 @@
import asyncio
import inspect
import json
import uuid
from typing import Any, Callable, Optional, Union, cast
from pydantic import BaseModel
from trame_server.state import State
from typing_extensions import override
from mvvm_lib import bindings_map
from ..interface import (
BindingInterface,
CallbackAfterUpdateType,
......@@ -30,6 +35,18 @@ def is_callable(var: Any) -> bool:
return inspect.isfunction(var) or inspect.ismethod(var)
def _get_nested_attributes(obj: Any, prefix: str = "") -> Any:
attributes = []
for k, v in obj.__dict__.items():
if not k.startswith("_"): # Ignore private attributes
full_key = f"{prefix}.{k}" if prefix else k
if hasattr(v, "__dict__"): # Check if the value is another object with attributes
attributes.extend(_get_nested_attributes(v, prefix=full_key))
else:
attributes.append(full_key)
return attributes
class TrameCommunicator(Communicator):
"""Communicator implementation for Trame."""
......@@ -41,6 +58,8 @@ class TrameCommunicator(Communicator):
callback_after_update: CallbackAfterUpdateType = None,
) -> None:
self.state = state
self.id = str(uuid.uuid4())
bindings_map[self.id] = self
self.viewmodel_linked_object = viewmodel_linked_object
self._set_linked_object_attributes(linked_object_attributes, viewmodel_linked_object)
self.viewmodel_callback_after_update = callback_after_update
......@@ -53,12 +72,11 @@ class TrameCommunicator(Communicator):
if (
viewmodel_linked_object
and not isinstance(viewmodel_linked_object, dict)
and not isinstance(viewmodel_linked_object, BaseModel)
and not is_callable(viewmodel_linked_object)
):
if not linked_object_attributes:
self.linked_object_attributes = [
k for k in viewmodel_linked_object.__dict__.keys() if not k.startswith("_")
]
self.linked_object_attributes = _get_nested_attributes(viewmodel_linked_object)
else:
self.linked_object_attributes = linked_object_attributes
......@@ -85,7 +103,16 @@ class CallBackConnection:
self.linked_object_attributes = communicator.linked_object_attributes
def _update_viewmodel_callback(self, value: Any, key: Optional[str] = None) -> None:
if isinstance(self.viewmodel_linked_object, dict):
if isinstance(self.viewmodel_linked_object, BaseModel):
model = self.viewmodel_linked_object.copy(deep=True)
rsetattr(model, key or "", value)
try:
new_model = model.__class__(**model.model_dump(warnings=False))
for f, v in new_model:
setattr(self.viewmodel_linked_object, f, v)
except Exception:
pass
elif isinstance(self.viewmodel_linked_object, dict):
if not key:
self.viewmodel_linked_object.update(value)
else:
......@@ -165,16 +192,27 @@ class StateConnection:
@self.state.change(state_variable_name)
def update_viewmodel_callback(**kwargs: dict) -> None:
if isinstance(self.viewmodel_linked_object, dict):
success = True
if isinstance(self.viewmodel_linked_object, BaseModel):
json_str = json.dumps(kwargs[state_variable_name])
try:
model = self.viewmodel_linked_object.model_validate_json(json_str)
for field, value in model:
setattr(self.viewmodel_linked_object, field, value)
except Exception:
success = False
elif isinstance(self.viewmodel_linked_object, dict):
self.viewmodel_linked_object.update(kwargs[state_variable_name])
elif is_callable(self.viewmodel_linked_object):
cast(Callable, self.viewmodel_linked_object)(kwargs[state_variable_name])
else:
raise Exception("cannot update", self.viewmodel_linked_object)
if self.viewmodel_callback_after_update:
if self.viewmodel_callback_after_update and success:
self.viewmodel_callback_after_update(state_variable_name)
def update_in_view(self, value: Any) -> None:
if hasattr(value, "model_dump"):
value = value.model_dump()
if self.linked_object_attributes:
for attribute_name in self.linked_object_attributes:
name_in_state = self._get_name_in_state(attribute_name)
......
"""Common utilities."""
import functools
import re
from typing import Any
def rsetattr(obj: Any, attr: str, val: Any) -> None:
"""Set nested attribute of an object."""
pre, _, post = attr.rpartition(".")
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
def rgetattr(obj: Any, attr: str) -> Any:
fields = attr.split(".")
for field in fields:
base = field.split("[")[0]
obj = getattr(obj, base)
indices = []
indices = re.findall(r"\[(\d+)\]", field)
indices = [int(num) for num in indices]
for index in indices:
obj = obj[index]
return obj
def rgetattr(obj: Any, attr: str, *args: Any) -> Any:
"""Get nested attribute of an object."""
def _getattr(obj: Any, attr: str) -> Any:
return getattr(obj, attr, *args)
return functools.reduce(_getattr, [obj] + attr.split("."))
def rsetattr(obj: Any, attr: str, val: Any) -> Any:
pre, _, post = attr.rpartition(".")
if pre:
obj = rgetattr(obj, pre)
if "[" in post:
indices = re.findall(r"\[(\d+)\]", post)
indices = [int(num) for num in indices]
for i, index in enumerate(indices):
if i == len(indices) - 1:
obj[index] = val
else:
obj = obj[index]
else:
setattr(obj, post, val)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment