From b734114f1a48b0493a7edd17f6d4c828c8c7f70e Mon Sep 17 00:00:00 2001
From: Martyn Gigg <martyn.gigg@stfc.ac.uk>
Date: Wed, 9 Jan 2019 16:58:24 +0000
Subject: [PATCH] Extract workspace tracking code into class

Some plot functions return multiple artists and
tracking these via a single object is much easier.
Refs #24000
---
 .../PythonInterface/mantid/plots/__init__.py  | 159 ++++++++++--------
 .../mantid/plots/plotfunctions.py             |  47 ++++--
 2 files changed, 122 insertions(+), 84 deletions(-)

diff --git a/Framework/PythonInterface/mantid/plots/__init__.py b/Framework/PythonInterface/mantid/plots/__init__.py
index 47f1ef0f2c1..f77326f43cf 100644
--- a/Framework/PythonInterface/mantid/plots/__init__.py
+++ b/Framework/PythonInterface/mantid/plots/__init__.py
@@ -50,7 +50,7 @@ def plot_decorator(func):
     def wrapper(self, *args, **kwargs):
         func_value = func(self, *args, **kwargs)
         # Saves saving it on array objects
-        if mantid.plots.helperfunctions.validate_args(*args, **kwargs):
+        if helperfunctions.validate_args(*args, **kwargs):
             # Fill out kwargs with the values of args
             for index, arg in enumerate(args):
                 if index is 0:
@@ -74,6 +74,45 @@ def plot_decorator(func):
     return wrapper
 
 
+class _WorkspaceArtists(object):
+    """Captures information regarding an artist that has been plotted
+    from a workspace. It allows for removal and replacement of said artists
+
+    """
+    def __init__(self, artists, data_replace_cb):
+        """
+        Initialize an instance
+        :param artists: A reference to a list of artists "attached" to a workspace
+        :param data_replace_cb: A reference to a callable with signature (artists, workspace) -> new_artists
+        """
+        self._artists = artists
+        self._data_replace_cb = data_replace_cb
+
+    def remove(self, axes):
+        """
+        Remove the tracked artists from the given axes
+        :param axes: A reference to the axes instance the artists are attached to
+        """
+        # delete the artists from the axes
+        for artist in self._artists:
+            artist.remove()
+            # Remove doesn't catch removing the container for errorbars etc
+            if isinstance(artist, Container):
+                try:
+                    axes.containers.remove(artist)
+                except ValueError:
+                    pass
+
+        if (not axes.is_empty()) and axes.legend_ is not None:
+            axes.legend()
+
+    def replace_data(self, workspace):
+        """Replace or replot artists based on a new workspace
+        :param workspace: The new workspace containing the data
+        """
+        self._artists = self._data_replace_cb(self._artists, workspace)
+
+
 class MantidAxes(Axes):
     """
     This class defines the **mantid** projection for 2d plotting. One chooses
@@ -106,26 +145,28 @@ class MantidAxes(Axes):
         super(MantidAxes, self).__init__(*args, **kwargs)
         self.tracked_workspaces = dict()
 
-    def track_workspace_artist(self, name, artist, replace_handler=None):
+    def track_workspace_artist(self, name, artists, data_replace_cb=None):
         """
         Add the given workspace name to the list of workspaces
         displayed on this Axes instance
         :param name: The name of the workspace. If empty then no tracking takes place
-        :param artists: A single artist or list/tuple of length 1 containing the data for the workspace
-        :param replace_handler: A function to call when the data is replaced to update
+        :param artists: A single artist or iterable of artists containing the data for the workspace
+        :param data_replace_cb: A function to call when the data is replaced to update
         the artist (optional)
         :returns: The artists variable as it was passed in.
         """
         if name:
-            artist_info = self.tracked_workspaces.setdefault(name, [])
-            if isinstance(artist, Iterable) and not isinstance(artist, Container):
-                artist = artist[0]
-            if replace_handler is None:
-                def replace_handler(_, __):
+            if data_replace_cb is None:
+                def data_replace_cb(_, __):
                     logger.warning("Updating data on this plot type is not yet supported")
-            artist_info.append([artist, replace_handler])
+            artist_info = self.tracked_workspaces.setdefault(name, [])
+            if isinstance(artists, Container) or not isinstance(artists, Iterable):
+                artist_seq = [artists]
+            else:
+                artist_seq = artists
+            artist_info.append(_WorkspaceArtists(artist_seq, data_replace_cb))
 
-        return artist
+        return artists
 
     def remove_workspace_artists(self, name):
         """
@@ -140,21 +181,9 @@ class MantidAxes(Axes):
         except KeyError:
             return False
 
-        # delete the artists from the figure
-        for artist, _ in artist_info:
-            artist.remove()
-            # Remove doesn't catch removing the container for errorbars etc
-            if isinstance(artist, Container):
-                try:
-                    self.containers.remove(artist)
-                except ValueError:
-                    pass
-
-        axes_empty = self.is_empty()
-        if (not axes_empty) and self.legend_ is not None:
-            self.legend()
-
-        return axes_empty
+        for workspace_artist in artist_info:
+            workspace_artist.remove(self)
+        return self.is_empty()
 
     def replace_workspace_artists(self, name, workspace):
         """
@@ -169,28 +198,10 @@ class MantidAxes(Axes):
         except KeyError:
             return False
 
-        for artist, handler in artist_info:
-            handler(artist, workspace)
-        if self.legend_:
-            self.legend()
+        for workspace_artist in artist_info:
+            workspace_artist.replace_data(workspace)
         return True
 
-    def _replace_tracking_info(self, name, artist_orig, artist_replaced):
-        """
-        Replace a tracked artist for the given workspace by a new one.
-        :param name: The name of the workspace
-        :param artist_orig: A reference to the original artist
-        :param artist_replaced: A reference to the new artist
-        :raises: A KeyError if the named workspace has not been tracked already
-        """
-        artist_info = self.tracked_workspaces[name]
-        artist_orig_idx = None
-        for index, (artist, _) in enumerate(artist_info):
-            if artist is artist_orig:
-                artist_orig_idx = index
-        if artist_orig_idx is not None:
-            artist_info[artist_orig_idx][0] = artist_replaced
-
     def is_empty(self):
         """
         Checks the known artist containers to see if anything exists within them
@@ -224,9 +235,10 @@ class MantidAxes(Axes):
         if helperfunctions.validate_args(*args):
             logger.debug('using plotfunctions')
 
-            def _data_update(line2d, workspace):
+            def _data_update(artists, workspace):
+                # It's only possible to plot 1 line at a time from a workspace
                 x, y, _, __ = plotfunctions._plot_impl(self, workspace, args, kwargs)
-                line2d.set_data(x, y)
+                artists[0].set_data(x, y)
                 self.relim()
                 self.autoscale()
 
@@ -282,20 +294,31 @@ class MantidAxes(Axes):
         if helperfunctions.validate_args(*args):
             logger.debug('using plotfunctions')
 
-            def _data_update(container_orig, workspace):
+            def _data_update(artists, workspace):
+                # errorbar with workspaces can only return a single container
+                container_orig = artists[0]
                 # It is not possible to simply reset the error bars so
-                # we just plot new lines
+                # we have to plot new lines but ensure we don't reorder them on the plot!
+                orig_idx = self.containers.index(container_orig)
                 container_orig.remove()
-                self.containers.remove(container_orig)
+                # The container does not remove itself from the containers list
+                # but protect this just in case matplotlib starts doing this
+                try:
+                    self.containers.remove(container_orig)
+                except ValueError:
+                    pass
+                # this gets pushed back onto the containers list
                 container_new = plotfunctions.errorbar(self, workspace, **kwargs)
+                self.containers.insert(orig_idx, container_new)
+                self.containers.pop()
+                # update line properties to match original
                 orig_flat, new_flat = cbook.flatten(container_orig), cbook.flatten(container_new)
                 for artist_orig, artist_new in zip(orig_flat, new_flat):
                      artist_new.update_from(artist_orig)
-
                 # ax.relim does not support collections...
                 self._update_line_limits(container_new[0])
                 self.autoscale()
-                self._replace_tracking_info(workspace.name(), container_orig, container_new)
+                return container_new
 
             return self.track_workspace_artist(args[0].name(),
                                                plotfunctions.errorbar(self, *args, **kwargs),
@@ -403,9 +426,9 @@ class MantidAxes(Axes):
         if helperfunctions.validate_args(*args):
             logger.debug('using plotfunctions')
 
-            def _update_data(artist, workspace):
-                self._redraw_colorplot(plotfunctions.imshow,
-                                       artist, workspace, **kwargs)
+            def _update_data(artists, workspace):
+                return self._redraw_colorplot(plotfunctions.imshow,
+                                              artists, workspace, **kwargs)
 
             return self.track_workspace_artist(args[0].name(),
                                                plotfunctions.imshow(self, *args, **kwargs),
@@ -413,22 +436,26 @@ class MantidAxes(Axes):
         else:
             return Axes.imshow(self, *args, **kwargs)
 
-    def _redraw_colorplot(self, colorfunc, artist_orig, workspace, **kwargs):
+    def _redraw_colorplot(self, colorfunc, artists_orig, workspace, **kwargs):
         """
         Redraw a pcolor* or imshow type plot bsaed on a new workspace
         :param colorfunc: The Axes function to use to draw the new artist
-        :param artist_orig: A reference to the existing Artist object
+        :param artists_orig: A reference to an iterable of existing artists
         :param workspace: A reference to the workspace object
         :param kwargs: Any kwargs passed to the original call
         """
-        artist_orig.remove()
-        if hasattr(artist_orig, 'colorbar_cid'):
-            artist_orig.callbacksSM.disconnect(artist_orig.colorbar_cid)
-        im = colorfunc(self, workspace, **kwargs)
-        plotfunctions.update_colorplot_datalimits(self, im)
-        if artist_orig.colorbar is not None:
-            self._attach_colorbar(im, artist_orig.colorbar)
-        self._replace_tracking_info(workspace.name(), artist_orig, im)
+        for artist_orig in artists_orig:
+            artist_orig.remove()
+            if hasattr(artist_orig, 'colorbar_cid'):
+                artist_orig.callbacksSM.disconnect(artist_orig.colorbar_cid)
+        artists_new = colorfunc(self, workspace, **kwargs)
+        if not isinstance(artists_new, Iterable):
+            artists_new = [artists_new]
+        plotfunctions.update_colorplot_datalimits(self, artists_new)
+        for artist_orig, artist_new in zip(artists_orig, artists_new):
+            if artist_orig.colorbar is not None:
+                self._attach_colorbar(artist_new, artist_orig.colorbar)
+        return artists_new
 
     def contour(self, *args, **kwargs):
         """
diff --git a/Framework/PythonInterface/mantid/plots/plotfunctions.py b/Framework/PythonInterface/mantid/plots/plotfunctions.py
index b598a087641..97a17cfd300 100644
--- a/Framework/PythonInterface/mantid/plots/plotfunctions.py
+++ b/Framework/PythonInterface/mantid/plots/plotfunctions.py
@@ -8,6 +8,7 @@
 #
 #
 from __future__ import (absolute_import, division, print_function)
+import sys
 
 import numpy
 from skimage.transform import resize
@@ -20,6 +21,10 @@ import matplotlib.colors
 import matplotlib.dates as mdates
 import matplotlib.image as mimage
 
+
+# Used for initializing searches of max, min values
+_LARGEST, _SMALLEST = float(sys.maxsize), -sys.maxsize
+
 # ================================================
 # Private 2D Helper functions
 # ================================================
@@ -263,7 +268,7 @@ def _pcolorpieces(axes, workspace, distribution, *args, **kwargs):
     :param pcolortype: this keyword allows the plotting to be one of pcolormesh or
         pcolorfast if there is "mesh" or "fast" in the value of the keyword, or
         pcolor by default
-    Note: the return is the pcolor, pcolormesh, or pcolorfast of the last spectrum
+    :return: A list of the pcolor pieces created
     '''
     (x, y, z) = get_uneven_data(workspace, distribution)
     mini = numpy.min([numpy.min(i) for i in z])
@@ -289,11 +294,12 @@ def _pcolorpieces(axes, workspace, distribution, *args, **kwargs):
     else:
         pcolor = axes.pcolor
 
+    pieces = []
     for xi, yi, zi in zip(x, y, z):
         XX, YY = numpy.meshgrid(xi, yi, indexing='ij')
-        cm = pcolor(XX, YY, zi.reshape(-1, 1), **kwargs)
+        pieces.append(pcolor(XX, YY, zi.reshape(-1, 1), **kwargs))
 
-    return cm
+    return pieces
 
 
 def pcolor(axes, workspace, *args, **kwargs):
@@ -638,26 +644,31 @@ def tricontourf(axes, workspace, *args, **kwargs):
     return axes.tricontourf(x, y, z, *args, **kwargs)
 
 
-def update_colorplot_datalimits(axes, mappable):
+def update_colorplot_datalimits(axes, mappables):
     """
     For an colorplot (imshow, pcolor*) plots update the data limits on the axes
     to circumvent bugs in matplotlib
-    :param mappable: A new mappable for this axes
+    :param mappables: An iterable of mappable for this axes
     """
     # ax.relim in matplotlib < 2.2 doesn't take into account of images
     # and it doesn't support collections at all as of verison 3 so we'll take
     # over
-    if isinstance(mappable, mimage.AxesImage):
-        xmin, xmax, ymin, ymax = mappable.get_extent()
-    elif isinstance(mappable, mcoll.QuadMesh):
-        # coordinates are vertices of the grid
-        coords = mappable._coordinates
-        xmin, ymin = coords[0][0]
-        xmax, ymax = coords[-1][-1]
-    elif isinstance(mappable, mcoll.PolyCollection):
-        xmin, ymin = mappable._paths[0].get_extents().min
-        xmax, ymax = mappable._paths[-1].get_extents().max
-    else:
-        raise ValueError("Unknown mappable type '{}'".format(type(mappable)))
-    axes.update_datalim(((xmin, ymin), (xmax, ymax)))
+    xmin_all, xmax_all, ymin_all, ymax_all = _LARGEST, _SMALLEST, _LARGEST, _SMALLEST
+    for mappable in mappables:
+        if isinstance(mappable, mimage.AxesImage):
+            xmin, xmax, ymin, ymax = mappable.get_extent()
+        elif isinstance(mappable, mcoll.QuadMesh):
+            # coordinates are vertices of the grid
+            coords = mappable._coordinates
+            xmin, ymin = coords[0][0]
+            xmax, ymax = coords[-1][-1]
+        elif isinstance(mappable, mcoll.PolyCollection):
+            xmin, ymin = mappable._paths[0].get_extents().min
+            xmax, ymax = mappable._paths[-1].get_extents().max
+        else:
+            raise ValueError("Unknown mappable type '{}'".format(type(mappable)))
+        xmin_all, xmax_all = min(xmin_all, xmin), max(xmax_all, xmax)
+        ymin_all, ymax_all = min(ymin_all, ymin), max(ymax_all, ymax)
+
+    axes.update_datalim(((xmin_all, ymin_all), (xmax_all, ymax_all)))
     axes.autoscale()
-- 
GitLab