Skip to content
Snippets Groups Projects
Commit f2ab1419 authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

refactor, add pydantic utils

parent cecd298c
No related branches found
No related tags found
1 merge request!6Update library to work with Pydantic models
Pipeline #640794 waiting for manual action
import importlib.metadata
__version__ = importlib.metadata.version(__package__)
from .bindings_map import bindings_map
from typing import Any
__version__ = importlib.metadata.version(__package__)
bindings_map: Any = {}
__all__ = ["bindings_map"]
"""Module for bindings map ant it's utils."""
from typing import Any, Dict
from pydantic import BaseModel
from mvvm_lib.utils import rget_list_of_fields
bindings_map: Dict[str, Any] = {}
def update_bindings_map(source: Any, value: Any) -> None:
if isinstance(source, BaseModel):
fields = rget_list_of_fields(source)
for field in fields:
bindings_map[field] = value
"""Pydantic utils."""
import re
from typing import Any
from pydantic import BaseModel, ValidationError
from pydantic.fields import FieldInfo
from mvvm_lib import bindings_map
def get_nested_pydantic_field(model: BaseModel, field_path: str) -> FieldInfo:
"""
Retrieves the Pydantic ModelField object for a nested field in a Pydantic model using a dot-separated field path.
:param model: Pydantic model instance
:param field_path: Dot-separated path to the field (e.g., "config.nested.nested2")
:return: The Pydantic ModelField instance
"""
fields = field_path.split(".")
current_model: Any = model
for field in fields:
if "[" in field:
base = field.split("[")[0]
current_model = getattr(current_model, base)
for _ in range(field.count("[")):
current_model = current_model[0]
continue
if issubclass(type(getattr(current_model, field)), BaseModel):
current_model = getattr(current_model, field)
else:
return current_model.__fields__[field]
raise Exception(f"Cannot find field {field_path}")
def get_field_info(field_name: str) -> FieldInfo:
binding = bindings_map.get(field_name, None)
if not binding:
raise Exception(f"Cannot find field {field_name}")
return get_nested_pydantic_field(binding.viewmodel_linked_object, field_name)
def validate_pydantic_parameter(name: str, value: Any, index: int) -> str | None:
if name not in bindings_map:
print(f"cannot find {name} in bindings_map") # no error, just do not validate for now
return None
binding = bindings_map[name]
current_model = binding.viewmodel_linked_object
# get list of nested fields (if any) and get the corresponding model
fields = name.split(".")
for field in fields[:-1]:
if "[index]" in field:
field = field.removesuffix("[index]")
current_model = getattr(current_model, field)[index]
elif "[" in field:
base = field.split("[")[0]
indices = re.findall(r"\[(\d+)\]", field)
indices = [int(num) for num in indices]
for i in indices:
current_model = getattr(current_model, base)[i]
else:
current_model = getattr(current_model, field)
final_field = fields[-1]
# copy model so we do not modify the current one
model = current_model.copy(deep=True)
# force set field value
setattr(model, final_field, value)
# validate changed model
try:
model.__class__(**model.model_dump(warnings=False))
except ValidationError as e:
for error in e.errors():
if (len(error["loc"]) > 0 and final_field in str(error["loc"][0])) or (
len(error["loc"]) == 0 and e.title == current_model.__class__.__name__
):
return error["msg"]
return None
......@@ -5,8 +5,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from mvvm_lib import bindings_map
from ..bindings_map import update_bindings_map
from ..utils import rsetattr
try:
......@@ -37,7 +36,6 @@ class Communicator(QObject):
) -> None:
super().__init__()
self.id = str(uuid.uuid4())
bindings_map[self.id] = self
self.viewmodel_linked_object = viewmodel_linked_object
self.linked_object_attributes = linked_object_attributes
self.callback_after_update = callback_after_update
......@@ -74,9 +72,10 @@ class Communicator(QObject):
else:
return None
def update_in_view(self, *args: Any, **kwargs: Any) -> Any:
def update_in_view(self, value: Any, **kwargs: Any) -> Any:
"""Update a View (GUI) when called by a ViewModel."""
return self.signal.emit(*args, **kwargs)
update_bindings_map(value, self)
return self.signal.emit(value, **kwargs)
class PyQtBinding(BindingInterface):
......
......@@ -3,15 +3,13 @@
import asyncio
import inspect
import json
import uuid
from typing import Any, Callable, Optional, Union, cast
from pydantic import BaseModel
from trame_server.state import State
from typing_extensions import override
from mvvm_lib import bindings_map
from ..bindings_map import update_bindings_map
from ..interface import (
BindingInterface,
CallbackAfterUpdateType,
......@@ -20,7 +18,7 @@ from ..interface import (
LinkedObjectAttributesType,
LinkedObjectType,
)
from ..utils import rgetattr, rsetattr
from ..utils import normalize_field_name, rget_list_of_fields, rgetattr, rsetattr
def is_async() -> bool:
......@@ -35,18 +33,6 @@ def is_callable(var: Any) -> bool:
return inspect.isfunction(var) or inspect.ismethod(var)
def _get_nested_attributes(obj: Any, prefix: str = "") -> Any:
attributes = []
for k, v in obj.__dict__.items():
if not k.startswith("_"): # Ignore private attributes
full_key = f"{prefix}.{k}" if prefix else k
if hasattr(v, "__dict__"): # Check if the value is another object with attributes
attributes.extend(_get_nested_attributes(v, prefix=full_key))
else:
attributes.append(full_key)
return attributes
class TrameCommunicator(Communicator):
"""Communicator implementation for Trame."""
......@@ -58,8 +44,8 @@ class TrameCommunicator(Communicator):
callback_after_update: CallbackAfterUpdateType = None,
) -> None:
self.state = state
self.id = str(uuid.uuid4())
bindings_map[self.id] = self
update_bindings_map(viewmodel_linked_object, self)
self.viewmodel_linked_object = viewmodel_linked_object
self._set_linked_object_attributes(linked_object_attributes, viewmodel_linked_object)
self.viewmodel_callback_after_update = callback_after_update
......@@ -76,7 +62,7 @@ class TrameCommunicator(Communicator):
and not is_callable(viewmodel_linked_object)
):
if not linked_object_attributes:
self.linked_object_attributes = _get_nested_attributes(viewmodel_linked_object)
self.linked_object_attributes = rget_list_of_fields(viewmodel_linked_object)
else:
self.linked_object_attributes = linked_object_attributes
......@@ -89,6 +75,7 @@ class TrameCommunicator(Communicator):
return self.connection.get_callback()
def update_in_view(self, value: Any) -> None:
update_bindings_map(value, self)
self.connection.update_in_view(value)
......@@ -166,10 +153,9 @@ class StateConnection:
self.state.dirty(name_in_state)
def _get_name_in_state(self, attribute_name: str) -> str:
name_in_state = normalize_field_name(attribute_name)
if self.state_variable_name:
name_in_state = f"{self.state_variable_name}_{attribute_name.replace('.', '_')}"
else:
name_in_state = attribute_name.replace(".", "_")
name_in_state = f"{self.state_variable_name}_{name_in_state}"
return name_in_state
def _connect(self) -> None:
......
......@@ -4,12 +4,44 @@ import re
from typing import Any
def normalize_field_name(field: str) -> str:
return field.replace(".", "_").replace("[", "_").replace("]", "")
# .replace("_length", "n"))
def list_has_objects(v: list) -> bool:
for elem in v:
if isinstance(elem, list):
return list_has_objects(elem)
elif hasattr(elem, "__dict__"):
return True
return False
def rget_list_of_fields(obj: Any, prefix: str = "") -> Any:
if not hasattr(obj, "__dict__"):
return [prefix]
attributes = []
for k, v in obj.__dict__.items():
if not k.startswith("_"): # Ignore private attributes
full_key = f"{prefix}.{k}" if prefix else k
if isinstance(v, list) and list_has_objects(v):
for i, elem in enumerate(v):
attributes.extend(rget_list_of_fields(elem, prefix=f"{full_key}[{i}]"))
elif hasattr(v, "__dict__"): # Check if the value is another object with attributes
attributes.extend(rget_list_of_fields(v, prefix=full_key))
else:
attributes.append(full_key)
return attributes
def rgetattr(obj: Any, attr: str) -> Any:
fields = attr.split(".")
for field in fields:
base = field.split("[")[0]
obj = getattr(obj, base)
indices = []
indices = re.findall(r"\[(\d+)\]", field)
indices = [int(num) for num in indices]
for index in indices:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment