Commit 1649cdae authored by Duggan, John's avatar Duggan, John
Browse files

Improve MVVM usage and performance profiling

parent 12d6cab3
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
[tool.poetry]
name = "ctscan-viz"
version = "0.2.1"
version = "0.2.2"
description = "Template application"
authors = []
readme = "README.md"
+43 −0
Original line number Diff line number Diff line
"""Main model."""

import os
from time import time
from typing import Optional

import numpy as np
from natsort import natsorted
from pyvista import ImageData, get_reader


class MainModel:
@@ -11,6 +15,7 @@ class MainModel:
    def __init__(self) -> None:
        self.data_directory = ""
        self.file_list: list[str] = []
        self.volume: Optional[ImageData] = None

    def get_file_count(self) -> int:
        return len(self.file_list)
@@ -18,6 +23,9 @@ class MainModel:
    def get_file_list(self) -> list[str]:
        return self.file_list

    def get_volume(self) -> Optional[ImageData]:
        return self.volume

    def update_data_directory(self, value: str) -> None:
        self.data_directory = value
        self.update_file_list()
@@ -33,3 +41,38 @@ class MainModel:
            self.file_list = natsorted(self.file_list)
        except OSError:
            pass

    def update_volume(self) -> None:
        # Read the slices and validate their dimensions
        x_range = 0
        y_range = 0
        z_range = 0
        slices = []
        for path in self.file_list:
            start = time()
            slice = get_reader(path).read()
            slices.append(slice)
            z_range += 1

            if x_range == 0 and y_range == 0:
                x_range = slice.dimensions[0]
                y_range = slice.dimensions[1]
            else:
                if x_range != slice.dimensions[0] or y_range != slice.dimensions[1]:
                    raise ValueError(
                        f"All slices must have the same dimensions. Expected {(x_range, y_range, 1)} but found "
                        "{slice.dimensions}."
                    )
            print(f"Read slice {z_range}/{len(self.file_list)} in {time() - start:.2f}s", flush=True)

        # Create a numpy array from the slices
        start = time()
        scalars = np.concatenate([slice.active_scalars for slice in slices])
        print(f"Numpy array creation: {time() - start:.2f}s", flush=True)

        start = time()
        # Define the 3D volume's extent
        self.volume = ImageData(dimensions=(x_range, y_range, z_range), origin=(0, 0, 0), spacing=(1, 1, 1))
        # Define the 3D volume's scalars for the color/opacity transfer functions
        self.volume["TIFF Scalars"] = scalars
        print(f"PyVista volume creation: {time() - start:.2f}s", flush=True)
+32 −4
Original line number Diff line number Diff line
"""Main view model."""

from asyncio import create_task, sleep
from threading import Thread
from typing import Optional

from mvvm_lib.interface import BindingInterface
from pyvista import ImageData

from ctscan_viz.models.main import MainModel

@@ -12,14 +17,36 @@ class MainViewModel:
        self.model = model
        self.binding = binding

        self.loading = False

        self.data_directory_bind = self.binding.new_bind()
        self.file_count_bind = self.binding.new_bind()
        self.render_files_bind = self.binding.new_bind()
        self.rendering_bind = self.binding.new_bind()
        self.loading_bind = self.binding.new_bind()
        self.render_bind = self.binding.new_bind()

    def get_volume(self) -> Optional[ImageData]:
        return self.model.get_volume()

    async def monitor_loading(self) -> None:
        while self.loading:
            await sleep(0.1)

        self.render_bind.update_in_view(None)
        self.update_view()

    def load_in_background(self) -> None:
        self.model.update_volume()
        self.loading = False

    def render_files(self) -> None:
        self.render_files_bind.update_in_view(self.model.get_file_list())
        self.rendering_bind.update_in_view(False)
        self.loading = True
        self.update_view()

        # We run in the background to avoid blocking the main Trame thread if loading is slow.
        self.loading_thread = Thread(target=self.load_in_background)
        self.loading_thread.daemon = True
        self.loading_thread.start()
        create_task(self.monitor_loading())

    def update_data_directory(self, value: str) -> None:
        self.model.update_data_directory(value)
@@ -27,3 +54,4 @@ class MainViewModel:

    def update_view(self) -> None:
        self.file_count_bind.update_in_view(self.model.get_file_count())
        self.loading_bind.update_in_view(self.loading)
+13 −40
Original line number Diff line number Diff line
"""PyVista plotter for CT scans."""

import time
from time import time
from typing import Any, Optional

import numpy as np
from pyvista import ImageData, Plotter, get_reader, start_xvfb, themes
from pyvista import Plotter, start_xvfb, themes
from pyvista.trame.ui import get_viewer
from trame.widgets import html
from trame.widgets import vuetify3 as vuetify
@@ -21,8 +21,8 @@ class VisualizationPanel:
        self.server = server
        self.vm = vm
        self.vm.file_count_bind.connect("file_count")
        self.vm.render_files_bind.connect(self.render_files)
        self.vm.rendering_bind.connect("rendering")
        self.vm.loading_bind.connect("loading")
        self.vm.render_bind.connect(self.render)

        self.plotter = self.create_plotter()
        self.create_ui()
@@ -37,49 +37,22 @@ class VisualizationPanel:
    def create_ui(self) -> None:
        @self.server.controller.trigger("start_render")
        def _start_render() -> None:
            self.plotter.clear()
            self.vm.render_files()

        view = get_viewer(self.plotter)

        with vuetify.VBtn(
            classes="mb-2",
            disabled=("file_count < 1 || rendering",),
            click="rendering = true; trigger('start_render');",
            disabled=("file_count < 1 || loading",),
            click="trigger('start_render');",
        ):
            vuetify.VProgressCircular(indeterminate=True, v_if="rendering", size=24)
            vuetify.VProgressCircular(indeterminate=True, v_if="loading", size=24)
            html.Span("Render {{ file_count }} files", v_else=True)
        view.ui(mode="server", style="height: 66vh;")

    def render_files(self, file_list: list[str]) -> None:
        self.plotter.clear()

        start = time.time()
        # Merge the slices into one 3D array
        x_range = 0
        y_range = 0
        z_range = 0
        scalars = np.array([])
        for path in file_list:
            slice_start = time.time()
            slice = get_reader(path).read()
            if len(scalars) == 0:
                x_range = slice.dimensions[0]
                y_range = slice.dimensions[1]
                scalars = slice.active_scalars
            else:
                scalars = np.concatenate((scalars, slice.active_scalars))
            z_range += 1
            print(f"Processed slice {z_range}/{len(file_list)} in {time.time() - slice_start:.2f}s", flush=True)
        print(f"Numpy array creation: {time.time() - start:.2f}s", flush=True)

        start = time.time()
        # Define the 3D volume's extent
        volume = ImageData(dimensions=(x_range, y_range, z_range), origin=(0, 0, 0), spacing=(1, 1, 1))
        # Define the 3D volume's scalars for the color/opacity transfer functions
        volume["TIFF Scalars"] = scalars
        print(f"PyVista volume creation: {time.time() - start:.2f}s", flush=True)

        start = time.time()
        self.plotter.add_volume(volume, opacity="sigmoid")
    def render(self, _: Optional[Any] = None) -> None:
        start = time()
        self.plotter.add_volume(self.vm.get_volume(), opacity="sigmoid")
        self.plotter.view_isometric(self.plotter)
        print(f"PyVista volume rendering: {time.time() - start:.2f}s", flush=True)
        print(f"PyVista volume rendering: {time() - start:.2f}s", flush=True)