From b734114f1a48b0493a7edd17f6d4c828c8c7f70e Mon Sep 17 00:00:00 2001 From: Martyn Gigg <martyn.gigg@stfc.ac.uk> Date: Wed, 9 Jan 2019 16:58:24 +0000 Subject: [PATCH] Extract workspace tracking code into class Some plot functions return multiple artists and tracking these via a single object is much easier. Refs #24000 --- .../PythonInterface/mantid/plots/__init__.py | 159 ++++++++++-------- .../mantid/plots/plotfunctions.py | 47 ++++-- 2 files changed, 122 insertions(+), 84 deletions(-) diff --git a/Framework/PythonInterface/mantid/plots/__init__.py b/Framework/PythonInterface/mantid/plots/__init__.py index 47f1ef0f2c1..f77326f43cf 100644 --- a/Framework/PythonInterface/mantid/plots/__init__.py +++ b/Framework/PythonInterface/mantid/plots/__init__.py @@ -50,7 +50,7 @@ def plot_decorator(func): def wrapper(self, *args, **kwargs): func_value = func(self, *args, **kwargs) # Saves saving it on array objects - if mantid.plots.helperfunctions.validate_args(*args, **kwargs): + if helperfunctions.validate_args(*args, **kwargs): # Fill out kwargs with the values of args for index, arg in enumerate(args): if index is 0: @@ -74,6 +74,45 @@ def plot_decorator(func): return wrapper +class _WorkspaceArtists(object): + """Captures information regarding an artist that has been plotted + from a workspace. It allows for removal and replacement of said artists + + """ + def __init__(self, artists, data_replace_cb): + """ + Initialize an instance + :param artists: A reference to a list of artists "attached" to a workspace + :param data_replace_cb: A reference to a callable with signature (artists, workspace) -> new_artists + """ + self._artists = artists + self._data_replace_cb = data_replace_cb + + def remove(self, axes): + """ + Remove the tracked artists from the given axes + :param axes: A reference to the axes instance the artists are attached to + """ + # delete the artists from the axes + for artist in self._artists: + artist.remove() + # Remove doesn't catch removing the container for errorbars etc + if isinstance(artist, Container): + try: + axes.containers.remove(artist) + except ValueError: + pass + + if (not axes.is_empty()) and axes.legend_ is not None: + axes.legend() + + def replace_data(self, workspace): + """Replace or replot artists based on a new workspace + :param workspace: The new workspace containing the data + """ + self._artists = self._data_replace_cb(self._artists, workspace) + + class MantidAxes(Axes): """ This class defines the **mantid** projection for 2d plotting. One chooses @@ -106,26 +145,28 @@ class MantidAxes(Axes): super(MantidAxes, self).__init__(*args, **kwargs) self.tracked_workspaces = dict() - def track_workspace_artist(self, name, artist, replace_handler=None): + def track_workspace_artist(self, name, artists, data_replace_cb=None): """ Add the given workspace name to the list of workspaces displayed on this Axes instance :param name: The name of the workspace. If empty then no tracking takes place - :param artists: A single artist or list/tuple of length 1 containing the data for the workspace - :param replace_handler: A function to call when the data is replaced to update + :param artists: A single artist or iterable of artists containing the data for the workspace + :param data_replace_cb: A function to call when the data is replaced to update the artist (optional) :returns: The artists variable as it was passed in. """ if name: - artist_info = self.tracked_workspaces.setdefault(name, []) - if isinstance(artist, Iterable) and not isinstance(artist, Container): - artist = artist[0] - if replace_handler is None: - def replace_handler(_, __): + if data_replace_cb is None: + def data_replace_cb(_, __): logger.warning("Updating data on this plot type is not yet supported") - artist_info.append([artist, replace_handler]) + artist_info = self.tracked_workspaces.setdefault(name, []) + if isinstance(artists, Container) or not isinstance(artists, Iterable): + artist_seq = [artists] + else: + artist_seq = artists + artist_info.append(_WorkspaceArtists(artist_seq, data_replace_cb)) - return artist + return artists def remove_workspace_artists(self, name): """ @@ -140,21 +181,9 @@ class MantidAxes(Axes): except KeyError: return False - # delete the artists from the figure - for artist, _ in artist_info: - artist.remove() - # Remove doesn't catch removing the container for errorbars etc - if isinstance(artist, Container): - try: - self.containers.remove(artist) - except ValueError: - pass - - axes_empty = self.is_empty() - if (not axes_empty) and self.legend_ is not None: - self.legend() - - return axes_empty + for workspace_artist in artist_info: + workspace_artist.remove(self) + return self.is_empty() def replace_workspace_artists(self, name, workspace): """ @@ -169,28 +198,10 @@ class MantidAxes(Axes): except KeyError: return False - for artist, handler in artist_info: - handler(artist, workspace) - if self.legend_: - self.legend() + for workspace_artist in artist_info: + workspace_artist.replace_data(workspace) return True - def _replace_tracking_info(self, name, artist_orig, artist_replaced): - """ - Replace a tracked artist for the given workspace by a new one. - :param name: The name of the workspace - :param artist_orig: A reference to the original artist - :param artist_replaced: A reference to the new artist - :raises: A KeyError if the named workspace has not been tracked already - """ - artist_info = self.tracked_workspaces[name] - artist_orig_idx = None - for index, (artist, _) in enumerate(artist_info): - if artist is artist_orig: - artist_orig_idx = index - if artist_orig_idx is not None: - artist_info[artist_orig_idx][0] = artist_replaced - def is_empty(self): """ Checks the known artist containers to see if anything exists within them @@ -224,9 +235,10 @@ class MantidAxes(Axes): if helperfunctions.validate_args(*args): logger.debug('using plotfunctions') - def _data_update(line2d, workspace): + def _data_update(artists, workspace): + # It's only possible to plot 1 line at a time from a workspace x, y, _, __ = plotfunctions._plot_impl(self, workspace, args, kwargs) - line2d.set_data(x, y) + artists[0].set_data(x, y) self.relim() self.autoscale() @@ -282,20 +294,31 @@ class MantidAxes(Axes): if helperfunctions.validate_args(*args): logger.debug('using plotfunctions') - def _data_update(container_orig, workspace): + def _data_update(artists, workspace): + # errorbar with workspaces can only return a single container + container_orig = artists[0] # It is not possible to simply reset the error bars so - # we just plot new lines + # we have to plot new lines but ensure we don't reorder them on the plot! + orig_idx = self.containers.index(container_orig) container_orig.remove() - self.containers.remove(container_orig) + # The container does not remove itself from the containers list + # but protect this just in case matplotlib starts doing this + try: + self.containers.remove(container_orig) + except ValueError: + pass + # this gets pushed back onto the containers list container_new = plotfunctions.errorbar(self, workspace, **kwargs) + self.containers.insert(orig_idx, container_new) + self.containers.pop() + # update line properties to match original orig_flat, new_flat = cbook.flatten(container_orig), cbook.flatten(container_new) for artist_orig, artist_new in zip(orig_flat, new_flat): artist_new.update_from(artist_orig) - # ax.relim does not support collections... self._update_line_limits(container_new[0]) self.autoscale() - self._replace_tracking_info(workspace.name(), container_orig, container_new) + return container_new return self.track_workspace_artist(args[0].name(), plotfunctions.errorbar(self, *args, **kwargs), @@ -403,9 +426,9 @@ class MantidAxes(Axes): if helperfunctions.validate_args(*args): logger.debug('using plotfunctions') - def _update_data(artist, workspace): - self._redraw_colorplot(plotfunctions.imshow, - artist, workspace, **kwargs) + def _update_data(artists, workspace): + return self._redraw_colorplot(plotfunctions.imshow, + artists, workspace, **kwargs) return self.track_workspace_artist(args[0].name(), plotfunctions.imshow(self, *args, **kwargs), @@ -413,22 +436,26 @@ class MantidAxes(Axes): else: return Axes.imshow(self, *args, **kwargs) - def _redraw_colorplot(self, colorfunc, artist_orig, workspace, **kwargs): + def _redraw_colorplot(self, colorfunc, artists_orig, workspace, **kwargs): """ Redraw a pcolor* or imshow type plot bsaed on a new workspace :param colorfunc: The Axes function to use to draw the new artist - :param artist_orig: A reference to the existing Artist object + :param artists_orig: A reference to an iterable of existing artists :param workspace: A reference to the workspace object :param kwargs: Any kwargs passed to the original call """ - artist_orig.remove() - if hasattr(artist_orig, 'colorbar_cid'): - artist_orig.callbacksSM.disconnect(artist_orig.colorbar_cid) - im = colorfunc(self, workspace, **kwargs) - plotfunctions.update_colorplot_datalimits(self, im) - if artist_orig.colorbar is not None: - self._attach_colorbar(im, artist_orig.colorbar) - self._replace_tracking_info(workspace.name(), artist_orig, im) + for artist_orig in artists_orig: + artist_orig.remove() + if hasattr(artist_orig, 'colorbar_cid'): + artist_orig.callbacksSM.disconnect(artist_orig.colorbar_cid) + artists_new = colorfunc(self, workspace, **kwargs) + if not isinstance(artists_new, Iterable): + artists_new = [artists_new] + plotfunctions.update_colorplot_datalimits(self, artists_new) + for artist_orig, artist_new in zip(artists_orig, artists_new): + if artist_orig.colorbar is not None: + self._attach_colorbar(artist_new, artist_orig.colorbar) + return artists_new def contour(self, *args, **kwargs): """ diff --git a/Framework/PythonInterface/mantid/plots/plotfunctions.py b/Framework/PythonInterface/mantid/plots/plotfunctions.py index b598a087641..97a17cfd300 100644 --- a/Framework/PythonInterface/mantid/plots/plotfunctions.py +++ b/Framework/PythonInterface/mantid/plots/plotfunctions.py @@ -8,6 +8,7 @@ # # from __future__ import (absolute_import, division, print_function) +import sys import numpy from skimage.transform import resize @@ -20,6 +21,10 @@ import matplotlib.colors import matplotlib.dates as mdates import matplotlib.image as mimage + +# Used for initializing searches of max, min values +_LARGEST, _SMALLEST = float(sys.maxsize), -sys.maxsize + # ================================================ # Private 2D Helper functions # ================================================ @@ -263,7 +268,7 @@ def _pcolorpieces(axes, workspace, distribution, *args, **kwargs): :param pcolortype: this keyword allows the plotting to be one of pcolormesh or pcolorfast if there is "mesh" or "fast" in the value of the keyword, or pcolor by default - Note: the return is the pcolor, pcolormesh, or pcolorfast of the last spectrum + :return: A list of the pcolor pieces created ''' (x, y, z) = get_uneven_data(workspace, distribution) mini = numpy.min([numpy.min(i) for i in z]) @@ -289,11 +294,12 @@ def _pcolorpieces(axes, workspace, distribution, *args, **kwargs): else: pcolor = axes.pcolor + pieces = [] for xi, yi, zi in zip(x, y, z): XX, YY = numpy.meshgrid(xi, yi, indexing='ij') - cm = pcolor(XX, YY, zi.reshape(-1, 1), **kwargs) + pieces.append(pcolor(XX, YY, zi.reshape(-1, 1), **kwargs)) - return cm + return pieces def pcolor(axes, workspace, *args, **kwargs): @@ -638,26 +644,31 @@ def tricontourf(axes, workspace, *args, **kwargs): return axes.tricontourf(x, y, z, *args, **kwargs) -def update_colorplot_datalimits(axes, mappable): +def update_colorplot_datalimits(axes, mappables): """ For an colorplot (imshow, pcolor*) plots update the data limits on the axes to circumvent bugs in matplotlib - :param mappable: A new mappable for this axes + :param mappables: An iterable of mappable for this axes """ # ax.relim in matplotlib < 2.2 doesn't take into account of images # and it doesn't support collections at all as of verison 3 so we'll take # over - if isinstance(mappable, mimage.AxesImage): - xmin, xmax, ymin, ymax = mappable.get_extent() - elif isinstance(mappable, mcoll.QuadMesh): - # coordinates are vertices of the grid - coords = mappable._coordinates - xmin, ymin = coords[0][0] - xmax, ymax = coords[-1][-1] - elif isinstance(mappable, mcoll.PolyCollection): - xmin, ymin = mappable._paths[0].get_extents().min - xmax, ymax = mappable._paths[-1].get_extents().max - else: - raise ValueError("Unknown mappable type '{}'".format(type(mappable))) - axes.update_datalim(((xmin, ymin), (xmax, ymax))) + xmin_all, xmax_all, ymin_all, ymax_all = _LARGEST, _SMALLEST, _LARGEST, _SMALLEST + for mappable in mappables: + if isinstance(mappable, mimage.AxesImage): + xmin, xmax, ymin, ymax = mappable.get_extent() + elif isinstance(mappable, mcoll.QuadMesh): + # coordinates are vertices of the grid + coords = mappable._coordinates + xmin, ymin = coords[0][0] + xmax, ymax = coords[-1][-1] + elif isinstance(mappable, mcoll.PolyCollection): + xmin, ymin = mappable._paths[0].get_extents().min + xmax, ymax = mappable._paths[-1].get_extents().max + else: + raise ValueError("Unknown mappable type '{}'".format(type(mappable))) + xmin_all, xmax_all = min(xmin_all, xmin), max(xmax_all, xmax) + ymin_all, ymax_all = min(ymin_all, ymin), max(ymax_all, ymax) + + axes.update_datalim(((xmin_all, ymin_all), (xmax_all, ymax_all))) axes.autoscale() -- GitLab