Skip to content
Snippets Groups Projects
Commit b734114f authored by Gigg, Martyn Anthony's avatar Gigg, Martyn Anthony
Browse files

Extract workspace tracking code into class

Some plot functions return multiple artists and
tracking these via a single object is much easier.
Refs #24000
parent 459e072c
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,7 @@ def plot_decorator(func):
def wrapper(self, *args, **kwargs):
func_value = func(self, *args, **kwargs)
# Saves saving it on array objects
if mantid.plots.helperfunctions.validate_args(*args, **kwargs):
if helperfunctions.validate_args(*args, **kwargs):
# Fill out kwargs with the values of args
for index, arg in enumerate(args):
if index is 0:
......@@ -74,6 +74,45 @@ def plot_decorator(func):
return wrapper
class _WorkspaceArtists(object):
"""Captures information regarding an artist that has been plotted
from a workspace. It allows for removal and replacement of said artists
"""
def __init__(self, artists, data_replace_cb):
"""
Initialize an instance
:param artists: A reference to a list of artists "attached" to a workspace
:param data_replace_cb: A reference to a callable with signature (artists, workspace) -> new_artists
"""
self._artists = artists
self._data_replace_cb = data_replace_cb
def remove(self, axes):
"""
Remove the tracked artists from the given axes
:param axes: A reference to the axes instance the artists are attached to
"""
# delete the artists from the axes
for artist in self._artists:
artist.remove()
# Remove doesn't catch removing the container for errorbars etc
if isinstance(artist, Container):
try:
axes.containers.remove(artist)
except ValueError:
pass
if (not axes.is_empty()) and axes.legend_ is not None:
axes.legend()
def replace_data(self, workspace):
"""Replace or replot artists based on a new workspace
:param workspace: The new workspace containing the data
"""
self._artists = self._data_replace_cb(self._artists, workspace)
class MantidAxes(Axes):
"""
This class defines the **mantid** projection for 2d plotting. One chooses
......@@ -106,26 +145,28 @@ class MantidAxes(Axes):
super(MantidAxes, self).__init__(*args, **kwargs)
self.tracked_workspaces = dict()
def track_workspace_artist(self, name, artist, replace_handler=None):
def track_workspace_artist(self, name, artists, data_replace_cb=None):
"""
Add the given workspace name to the list of workspaces
displayed on this Axes instance
:param name: The name of the workspace. If empty then no tracking takes place
:param artists: A single artist or list/tuple of length 1 containing the data for the workspace
:param replace_handler: A function to call when the data is replaced to update
:param artists: A single artist or iterable of artists containing the data for the workspace
:param data_replace_cb: A function to call when the data is replaced to update
the artist (optional)
:returns: The artists variable as it was passed in.
"""
if name:
artist_info = self.tracked_workspaces.setdefault(name, [])
if isinstance(artist, Iterable) and not isinstance(artist, Container):
artist = artist[0]
if replace_handler is None:
def replace_handler(_, __):
if data_replace_cb is None:
def data_replace_cb(_, __):
logger.warning("Updating data on this plot type is not yet supported")
artist_info.append([artist, replace_handler])
artist_info = self.tracked_workspaces.setdefault(name, [])
if isinstance(artists, Container) or not isinstance(artists, Iterable):
artist_seq = [artists]
else:
artist_seq = artists
artist_info.append(_WorkspaceArtists(artist_seq, data_replace_cb))
return artist
return artists
def remove_workspace_artists(self, name):
"""
......@@ -140,21 +181,9 @@ class MantidAxes(Axes):
except KeyError:
return False
# delete the artists from the figure
for artist, _ in artist_info:
artist.remove()
# Remove doesn't catch removing the container for errorbars etc
if isinstance(artist, Container):
try:
self.containers.remove(artist)
except ValueError:
pass
axes_empty = self.is_empty()
if (not axes_empty) and self.legend_ is not None:
self.legend()
return axes_empty
for workspace_artist in artist_info:
workspace_artist.remove(self)
return self.is_empty()
def replace_workspace_artists(self, name, workspace):
"""
......@@ -169,28 +198,10 @@ class MantidAxes(Axes):
except KeyError:
return False
for artist, handler in artist_info:
handler(artist, workspace)
if self.legend_:
self.legend()
for workspace_artist in artist_info:
workspace_artist.replace_data(workspace)
return True
def _replace_tracking_info(self, name, artist_orig, artist_replaced):
"""
Replace a tracked artist for the given workspace by a new one.
:param name: The name of the workspace
:param artist_orig: A reference to the original artist
:param artist_replaced: A reference to the new artist
:raises: A KeyError if the named workspace has not been tracked already
"""
artist_info = self.tracked_workspaces[name]
artist_orig_idx = None
for index, (artist, _) in enumerate(artist_info):
if artist is artist_orig:
artist_orig_idx = index
if artist_orig_idx is not None:
artist_info[artist_orig_idx][0] = artist_replaced
def is_empty(self):
"""
Checks the known artist containers to see if anything exists within them
......@@ -224,9 +235,10 @@ class MantidAxes(Axes):
if helperfunctions.validate_args(*args):
logger.debug('using plotfunctions')
def _data_update(line2d, workspace):
def _data_update(artists, workspace):
# It's only possible to plot 1 line at a time from a workspace
x, y, _, __ = plotfunctions._plot_impl(self, workspace, args, kwargs)
line2d.set_data(x, y)
artists[0].set_data(x, y)
self.relim()
self.autoscale()
......@@ -282,20 +294,31 @@ class MantidAxes(Axes):
if helperfunctions.validate_args(*args):
logger.debug('using plotfunctions')
def _data_update(container_orig, workspace):
def _data_update(artists, workspace):
# errorbar with workspaces can only return a single container
container_orig = artists[0]
# It is not possible to simply reset the error bars so
# we just plot new lines
# we have to plot new lines but ensure we don't reorder them on the plot!
orig_idx = self.containers.index(container_orig)
container_orig.remove()
self.containers.remove(container_orig)
# The container does not remove itself from the containers list
# but protect this just in case matplotlib starts doing this
try:
self.containers.remove(container_orig)
except ValueError:
pass
# this gets pushed back onto the containers list
container_new = plotfunctions.errorbar(self, workspace, **kwargs)
self.containers.insert(orig_idx, container_new)
self.containers.pop()
# update line properties to match original
orig_flat, new_flat = cbook.flatten(container_orig), cbook.flatten(container_new)
for artist_orig, artist_new in zip(orig_flat, new_flat):
artist_new.update_from(artist_orig)
# ax.relim does not support collections...
self._update_line_limits(container_new[0])
self.autoscale()
self._replace_tracking_info(workspace.name(), container_orig, container_new)
return container_new
return self.track_workspace_artist(args[0].name(),
plotfunctions.errorbar(self, *args, **kwargs),
......@@ -403,9 +426,9 @@ class MantidAxes(Axes):
if helperfunctions.validate_args(*args):
logger.debug('using plotfunctions')
def _update_data(artist, workspace):
self._redraw_colorplot(plotfunctions.imshow,
artist, workspace, **kwargs)
def _update_data(artists, workspace):
return self._redraw_colorplot(plotfunctions.imshow,
artists, workspace, **kwargs)
return self.track_workspace_artist(args[0].name(),
plotfunctions.imshow(self, *args, **kwargs),
......@@ -413,22 +436,26 @@ class MantidAxes(Axes):
else:
return Axes.imshow(self, *args, **kwargs)
def _redraw_colorplot(self, colorfunc, artist_orig, workspace, **kwargs):
def _redraw_colorplot(self, colorfunc, artists_orig, workspace, **kwargs):
"""
Redraw a pcolor* or imshow type plot bsaed on a new workspace
:param colorfunc: The Axes function to use to draw the new artist
:param artist_orig: A reference to the existing Artist object
:param artists_orig: A reference to an iterable of existing artists
:param workspace: A reference to the workspace object
:param kwargs: Any kwargs passed to the original call
"""
artist_orig.remove()
if hasattr(artist_orig, 'colorbar_cid'):
artist_orig.callbacksSM.disconnect(artist_orig.colorbar_cid)
im = colorfunc(self, workspace, **kwargs)
plotfunctions.update_colorplot_datalimits(self, im)
if artist_orig.colorbar is not None:
self._attach_colorbar(im, artist_orig.colorbar)
self._replace_tracking_info(workspace.name(), artist_orig, im)
for artist_orig in artists_orig:
artist_orig.remove()
if hasattr(artist_orig, 'colorbar_cid'):
artist_orig.callbacksSM.disconnect(artist_orig.colorbar_cid)
artists_new = colorfunc(self, workspace, **kwargs)
if not isinstance(artists_new, Iterable):
artists_new = [artists_new]
plotfunctions.update_colorplot_datalimits(self, artists_new)
for artist_orig, artist_new in zip(artists_orig, artists_new):
if artist_orig.colorbar is not None:
self._attach_colorbar(artist_new, artist_orig.colorbar)
return artists_new
def contour(self, *args, **kwargs):
"""
......
......@@ -8,6 +8,7 @@
#
#
from __future__ import (absolute_import, division, print_function)
import sys
import numpy
from skimage.transform import resize
......@@ -20,6 +21,10 @@ import matplotlib.colors
import matplotlib.dates as mdates
import matplotlib.image as mimage
# Used for initializing searches of max, min values
_LARGEST, _SMALLEST = float(sys.maxsize), -sys.maxsize
# ================================================
# Private 2D Helper functions
# ================================================
......@@ -263,7 +268,7 @@ def _pcolorpieces(axes, workspace, distribution, *args, **kwargs):
:param pcolortype: this keyword allows the plotting to be one of pcolormesh or
pcolorfast if there is "mesh" or "fast" in the value of the keyword, or
pcolor by default
Note: the return is the pcolor, pcolormesh, or pcolorfast of the last spectrum
:return: A list of the pcolor pieces created
'''
(x, y, z) = get_uneven_data(workspace, distribution)
mini = numpy.min([numpy.min(i) for i in z])
......@@ -289,11 +294,12 @@ def _pcolorpieces(axes, workspace, distribution, *args, **kwargs):
else:
pcolor = axes.pcolor
pieces = []
for xi, yi, zi in zip(x, y, z):
XX, YY = numpy.meshgrid(xi, yi, indexing='ij')
cm = pcolor(XX, YY, zi.reshape(-1, 1), **kwargs)
pieces.append(pcolor(XX, YY, zi.reshape(-1, 1), **kwargs))
return cm
return pieces
def pcolor(axes, workspace, *args, **kwargs):
......@@ -638,26 +644,31 @@ def tricontourf(axes, workspace, *args, **kwargs):
return axes.tricontourf(x, y, z, *args, **kwargs)
def update_colorplot_datalimits(axes, mappable):
def update_colorplot_datalimits(axes, mappables):
"""
For an colorplot (imshow, pcolor*) plots update the data limits on the axes
to circumvent bugs in matplotlib
:param mappable: A new mappable for this axes
:param mappables: An iterable of mappable for this axes
"""
# ax.relim in matplotlib < 2.2 doesn't take into account of images
# and it doesn't support collections at all as of verison 3 so we'll take
# over
if isinstance(mappable, mimage.AxesImage):
xmin, xmax, ymin, ymax = mappable.get_extent()
elif isinstance(mappable, mcoll.QuadMesh):
# coordinates are vertices of the grid
coords = mappable._coordinates
xmin, ymin = coords[0][0]
xmax, ymax = coords[-1][-1]
elif isinstance(mappable, mcoll.PolyCollection):
xmin, ymin = mappable._paths[0].get_extents().min
xmax, ymax = mappable._paths[-1].get_extents().max
else:
raise ValueError("Unknown mappable type '{}'".format(type(mappable)))
axes.update_datalim(((xmin, ymin), (xmax, ymax)))
xmin_all, xmax_all, ymin_all, ymax_all = _LARGEST, _SMALLEST, _LARGEST, _SMALLEST
for mappable in mappables:
if isinstance(mappable, mimage.AxesImage):
xmin, xmax, ymin, ymax = mappable.get_extent()
elif isinstance(mappable, mcoll.QuadMesh):
# coordinates are vertices of the grid
coords = mappable._coordinates
xmin, ymin = coords[0][0]
xmax, ymax = coords[-1][-1]
elif isinstance(mappable, mcoll.PolyCollection):
xmin, ymin = mappable._paths[0].get_extents().min
xmax, ymax = mappable._paths[-1].get_extents().max
else:
raise ValueError("Unknown mappable type '{}'".format(type(mappable)))
xmin_all, xmax_all = min(xmin_all, xmin), max(xmax_all, xmax)
ymin_all, ymax_all = min(ymin_all, ymin), max(ymax_all, ymax)
axes.update_datalim(((xmin_all, ymin_all), (xmax_all, ymax_all)))
axes.autoscale()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment