Commit d3870e28 authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

Merge branch '9-fix-problem-with-duplicate-field-names' into 'main'

fix a problem with same field names, refactor

Closes #9

See merge request ndip/public-packages/py-mvvm!9
parents 8c5fe371 f6851fd9
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
[tool.poetry]
name = "mvvm-lib"
version = "0.4.2"
version = "0.5.0"
description = "A Python Package for Model-View-ViewModel pattern"
authors = ["Yakubov, Sergey <yakubovs@ornl.gov>"]
readme = "README.md"
+3 −10
Original line number Diff line number Diff line
@@ -2,16 +2,9 @@

from typing import Any, Dict

from pydantic import BaseModel

from mvvm_lib.utils import rget_list_of_fields

bindings_map: Dict[str, Any] = {}


def update_bindings_map(source: Any, value: Any) -> None:
    #    if isinstance(source, BaseModel):
    if issubclass(type(source), BaseModel):
        fields = rget_list_of_fields(source)
        for field in fields:
            bindings_map[field] = value
def update_bindings_map(key: str | None, value: Any) -> None:
    if key:
        bindings_map[key] = value
+11 −8
Original line number Diff line number Diff line
@@ -26,26 +26,29 @@ def get_nested_pydantic_field(model: BaseModel, field_path: str) -> FieldInfo:
        if issubclass(type(getattr(current_model, field)), BaseModel):
            current_model = getattr(current_model, field)
        else:
            return current_model.__fields__[field]
            return current_model.model_fields[field]

    raise Exception(f"Cannot find field {field_path}")


def get_field_info(field_name: str) -> FieldInfo:
    binding = bindings_map.get(field_name, None)
    name = field_name.split(".")[0]
    field_name = field_name.removeprefix(f"{name}.")
    binding = bindings_map.get(name, None)
    if not binding:
        raise Exception(f"Cannot find field {field_name}")
        raise Exception(f"Cannot find binding for {name}")
    return get_nested_pydantic_field(binding.viewmodel_linked_object, field_name)


def validate_pydantic_parameter(name: str, value: Any, index: int) -> str | None:
    if name not in bindings_map:
        logger.warning(f"cannot find {name} in bindings_map")  # no error, just do not validate for now
def validate_pydantic_parameter(name: str, value: Any) -> str | None:
    object_name = name.split(".")[0]
    if object_name not in bindings_map:
        logger.warning(f"cannot find {object_name} in bindings_map")  # no error, just do not validate for now
        return None
    binding = bindings_map[name]
    binding = bindings_map[object_name]
    current_model = binding.viewmodel_linked_object
    # get list of nested fields (if any) and get the corresponding model
    fields = name.split(".")
    fields = name.split(".")[1:]
    for field in fields[:-1]:
        if "[" in field:
            base = field.split("[")[0]
+12 −8
Original line number Diff line number Diff line
"""Binding module for PyQt framework."""

import os
from typing import Any, Optional
from typing import Any, Callable, Optional

from pydantic import BaseModel

@@ -30,7 +30,7 @@ def is_callable(var: Any) -> bool:
    return inspect.isfunction(var) or inspect.ismethod(var)


class Communicator(QObject):
class PyQtCommunicator(QObject):
    """Communicator class, that provides methods required for binding to communicate between ViewModel and View."""

    signal = pyqtSignal(object)
@@ -45,10 +45,13 @@ class Communicator(QObject):
        self.viewmodel_linked_object = viewmodel_linked_object
        self.linked_object_attributes = linked_object_attributes
        self.callback_after_update = callback_after_update
        self.prefix = ""

    def _update_viewmodel_callback(self, key: Optional[str] = None, value: Any = None) -> None:
        if issubclass(type(self.viewmodel_linked_object), BaseModel):
            model = self.viewmodel_linked_object.copy(deep=True)
            if self.prefix and key:
                key = key.removeprefix(f"{self.prefix}.")
            rsetattr(model, key or "", value)
            try:
                new_model = model.__class__(**model.model_dump(warnings=False))
@@ -68,20 +71,21 @@ class Communicator(QObject):
        if self.callback_after_update:
            self.callback_after_update(key)

    def connect(self, *args: Any, **kwargs: Any) -> Any:
    def connect(self, name: str, callback: Callable) -> Any:
        # connect should be called from the View side to connect a
        # GUI element (via a function to change GUI element that is passed to the connect call)
        # and a linked_object (passed during bind creation from ViewModel side)
        self.signal.connect(*args, **kwargs)
        update_bindings_map(name, self)
        self.prefix = name
        self.signal.connect(callback)
        if self.viewmodel_linked_object:
            return self._update_viewmodel_callback
        else:
            return None

    def update_in_view(self, value: Any, **kwargs: Any) -> Any:
    def update_in_view(self, value: Any) -> Any:
        """Update a View (GUI) when called by a ViewModel."""
        update_bindings_map(value, self)
        return self.signal.emit(value, **kwargs)
        return self.signal.emit(value)


class PyQtBinding(BindingInterface):
@@ -95,4 +99,4 @@ class PyQtBinding(BindingInterface):
        For PyQt we use pyqtSignal to trigger GU
        I update and linked_object to trigger ViewModel/Model update
        """
        return Communicator(linked_object, linked_object_arguments, callback_after_update)
        return PyQtCommunicator(linked_object, linked_object_arguments, callback_after_update)
+15 −10
Original line number Diff line number Diff line
@@ -44,8 +44,6 @@ class TrameCommunicator(Communicator):
        callback_after_update: CallbackAfterUpdateType = None,
    ) -> None:
        self.state = state
        update_bindings_map(viewmodel_linked_object, self)

        self.viewmodel_linked_object = viewmodel_linked_object
        self._set_linked_object_attributes(linked_object_attributes, viewmodel_linked_object)
        self.viewmodel_callback_after_update = callback_after_update
@@ -71,11 +69,12 @@ class TrameCommunicator(Communicator):
        if is_callable(connector):
            self.connection = CallBackConnection(self, connector)
        else:
            self.connection = StateConnection(self, str(connector) if connector else None)
            connector = str(connector) if connector else None
            update_bindings_map(connector, self)
            self.connection = StateConnection(self, connector)
        return self.connection.get_callback()

    def update_in_view(self, value: Any) -> None:
        update_bindings_map(value, self)
        self.connection.update_in_view(value)


@@ -135,11 +134,17 @@ class StateConnection:
        self.linked_object_attributes = communicator.linked_object_attributes
        self._connect()

    async def _handle_callback(self, arg: str) -> None:
        if self.viewmodel_callback_after_update:
            if inspect.iscoroutinefunction(self.viewmodel_callback_after_update):
                await self.viewmodel_callback_after_update(arg)
            else:
                self.viewmodel_callback_after_update(arg)

    def _on_state_update(self, attribute_name: str, name_in_state: str) -> Callable:
        def update(**_kwargs: Any) -> None:
        async def update(**_kwargs: Any) -> None:
            rsetattr(self.viewmodel_linked_object, attribute_name, self.state[name_in_state])
            if self.viewmodel_callback_after_update:
                self.viewmodel_callback_after_update(attribute_name)
            await self._handle_callback(attribute_name)

        return update

@@ -177,7 +182,7 @@ class StateConnection:
            elif state_variable_name:

                @self.state.change(state_variable_name)
                def update_viewmodel_callback(**kwargs: dict) -> None:
                async def update_viewmodel_callback(**kwargs: dict) -> None:
                    updated = True
                    if self.viewmodel_linked_object and issubclass(type(self.viewmodel_linked_object), BaseModel):
                        json_str = json.dumps(kwargs[state_variable_name])
@@ -196,8 +201,8 @@ class StateConnection:
                        cast(Callable, self.viewmodel_linked_object)(kwargs[state_variable_name])
                    else:
                        raise Exception("cannot update", self.viewmodel_linked_object)
                    if self.viewmodel_callback_after_update and updated:
                        self.viewmodel_callback_after_update(state_variable_name)
                    if updated:
                        await self._handle_callback(state_variable_name)

    def update_in_view(self, value: Any) -> None:
        if issubclass(type(value), BaseModel):