Unverified Commit 11e1f6c6 authored by Nick Draper's avatar Nick Draper Committed by GitHub
Browse files

Merge pull request #28683 from martyngigg/28579_sliceviewer_lineplots_alignment

Fixes for sliceviewer line plots
parents 8e361d5e eac6fe72
......@@ -39,9 +39,9 @@ class DimensionWidget(QWidget):
def __init__(self, dims_info, parent=None):
super().__init__(parent)
self.layout = QVBoxLayout(self)
self.layout.setContentsMargins(0,0,0,0)
self.layout.setSpacing(0)
layout = QVBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
layout.setSpacing(0)
self.dims, self.qflags = [], []
for n, dim in enumerate(dims_info):
self.qflags.append(dim['qdim'])
......@@ -55,7 +55,7 @@ class DimensionWidget(QWidget):
widget.valueChanged.connect(self.valueChanged)
if hasattr(widget, 'binningChanged'):
widget.binningChanged.connect(self.dimensionsChanged)
self.layout.addWidget(widget)
layout.addWidget(widget)
self.set_initial_states()
......@@ -165,7 +165,7 @@ class Dimension(QWidget):
self.number = number
self.layout = QHBoxLayout(self)
self.layout.setContentsMargins(0,0,0,0)
self.layout.setContentsMargins(0, 2, 0, 0)
self.name = QLabel(dim_info['name'])
self.units = QLabel(dim_info['units'])
......@@ -191,8 +191,8 @@ class Dimension(QWidget):
self.spinbox.editingFinished.connect(self.spinbox_changed)
self.layout.addWidget(self.name)
self.button_layout = QHBoxLayout(self)
self.button_layout.setContentsMargins(0,0,0,0)
self.button_layout = QHBoxLayout()
self.button_layout.setContentsMargins(0, 0, 0, 0)
self.button_layout.setSpacing(0)
self.button_layout.addWidget(self.x)
self.button_layout.addWidget(self.y)
......
......@@ -125,6 +125,28 @@ class SliceViewerModel:
return np.ma.masked_invalid(
self.get_ws_MDE(slicepoint, bin_params).getSignalArray().squeeze())
def get_dim_limits(self, slicepoint, transpose):
"""
Return a xlim, ylim) for the display dimensions where xlim, ylim are tuples
:param slicepoint: Sequence containing either a float or None where None indicates a display dimension
:param transpose: A boolean flag indicating if the display dimensions are transposed
"""
workspace = self._get_ws()
assert len(slicepoint) == workspace.getNumDims(
), "Expected len(slicepoint) to match number of workspace dimensions"
limits = []
for index, pt in enumerate(slicepoint):
if pt is None:
dimension = workspace.getDimension(index)
limits.append((dimension.getMinimum(), dimension.getMaximum()))
assert len(
limits) == 2, f"There should be exactly 2 display dimensions, found {len(limits)}"
xlim, ylim = limits
if transpose:
ylim, xlim = xlim, ylim
return xlim, ylim
def get_dim_info(self, n: int) -> dict:
"""
returns dict of (minimum, maximun, number_of_bins, width, name, units) for dimension n
......
......@@ -45,8 +45,8 @@ class SliceViewer(object):
self.view = view if view else SliceViewerView(self, self.model.get_dimensions_info(),
self.model.can_normalize_workspace(), parent)
if self.model.can_normalize_workspace():
self.view.data_view.norm_opts.currentTextChanged.connect(self.normalization_changed)
self.view.data_view.set_normalization(ws)
self.view.data_view.norm_opts.currentTextChanged.connect(self.normalization_changed)
if not self.model.can_support_peaks_overlays():
self.view.data_view.disable_peaks_button()
if not self.model.can_support_nonorthogonal_axes():
......@@ -111,13 +111,19 @@ class SliceViewer(object):
data_view.disable_nonorthogonal_axes_button()
self.new_plot()
self._peaks_view_presenter.notify(PeaksViewerPresenter.Event.OverlayPeaks)
self._call_peaks_presenter_if_created("notify", PeaksViewerPresenter.Event.OverlayPeaks)
def slicepoint_changed(self):
"""Indicates the slicepoint has been updated"""
self._peaks_view_presenter.notify(PeaksViewerPresenter.Event.SlicePointChanged)
self._call_peaks_presenter_if_created("notify",
PeaksViewerPresenter.Event.SlicePointChanged)
self.update_plot_data()
def show_all_data_requested(self):
"""Instructs the view to show all data"""
self.view.data_view.set_axes_limits(*self.model.get_dim_limits(
self.get_slicepoint(), self.view.data_view.dimensions.transpose))
def update_plot_data_MDH(self):
"""
Update the view to display an updated MDHistoWorkspace slice/cut
......@@ -191,19 +197,28 @@ class SliceViewer(object):
# cancelled
return
if names_to_overlay or names_overlayed:
self._peaks_view_presenter.overlay_peaksworkspaces(names_to_overlay)
self._create_peaks_presenter_if_necessary().overlay_peaksworkspaces(names_to_overlay)
else:
self.view.peaks_view.hide()
# private api
@property
def _peaks_view_presenter(self):
def _create_peaks_presenter_if_necessary(self):
if self._peaks_presenter is None:
self._peaks_presenter = \
PeaksViewerCollectionPresenter(self.view.peaks_view)
return self._peaks_presenter
def _call_peaks_presenter_if_created(self, attr, *args, **kwargs):
"""
Call a method on the peaks presenter if it has been created
:param attr: The attribute to call
:param *args: Positional-arguments to pass to call
:param **kwargs Keyword-arguments to pass to call
"""
if self._peaks_presenter is not None:
getattr(self._peaks_presenter, attr)(*args, **kwargs)
def _overlayed_peaks_workspaces(self):
"""
:return: A list of names of the current PeaksWorkspaces overlayed
......
......@@ -12,7 +12,9 @@ import numpy as np
class SamplingImage(mimage.AxesImage):
def __init__(self, ax, workspace,
def __init__(self,
ax,
workspace,
transpose=False,
cmap=None,
norm=None,
......@@ -24,46 +26,71 @@ class SamplingImage(mimage.AxesImage):
resample=False,
normalize=mantid.api.MDNormalization.NoNormalization,
**kwargs):
super().__init__(
ax,
cmap=cmap,
norm=norm,
interpolation=interpolation,
origin=origin,
extent=extent,
filternorm=filternorm,
filterrad=filterrad,
resample=resample,
**kwargs)
super().__init__(ax,
cmap=cmap,
norm=norm,
interpolation=interpolation,
origin=origin,
extent=extent,
filternorm=filternorm,
filterrad=filterrad,
resample=resample,
**kwargs)
self.ws = workspace
self.transpose = transpose
self.normalization = normalize
self._resize_cid, self._xlim_cid, self._ylim_cid = None, None, None
self._resample_required = True
def connect_events(self):
axes = self.axes
self._resize_cid = axes.get_figure().canvas.mpl_connect('resize_event', self._resize)
self._xlim_cid = axes.callbacks.connect('xlim_changed', self._xlim_changed)
self._ylim_cid = axes.callbacks.connect('ylim_changed', self._ylim_changed)
def disconnect_events(self):
axes = self.axes
axes.get_figure().canvas.mpl_disconnect(self._resize_cid)
axes.callbacks.disconnect(self._xlim_cid)
axes.callbacks.disconnect(self._ylim_cid)
def draw(self, renderer, *args, **kwargs):
if self._resample_required:
self._resample_image()
self._resample_required = False
super().draw(renderer, *args, **kwargs)
def remove(self):
self.disconnect_events()
super().remove()
def _xlim_changed(self, ax):
if self._update_extent():
self._resample_image()
self._resample_required = True
def _ylim_changed(self, ax):
if self._update_extent():
self._resample_image()
self._resample_required = True
def _resize(self, canvas):
self._resample_image()
self._resample_required = True
def _resample_image(self, xbins=None, ybins=None):
extent = self.get_extent()
if xbins is None or ybins is None:
bbox=self.get_window_extent().transformed(self.axes.get_figure().dpi_scale_trans.inverted())
bbox = self.get_window_extent().transformed(
self.axes.get_figure().dpi_scale_trans.inverted())
dpi = self.axes.get_figure().dpi
xbins = np.ceil(bbox.width*dpi)
ybins = np.ceil(bbox.height*dpi)
x=np.linspace(extent[0], extent[1], int(xbins))
y=np.linspace(extent[2], extent[3], int(ybins))
X,Y = np.meshgrid(x,y)
xbins = np.ceil(bbox.width * dpi)
ybins = np.ceil(bbox.height * dpi)
x = np.linspace(extent[0], extent[1], int(xbins))
y = np.linspace(extent[2], extent[3], int(ybins))
X, Y = np.meshgrid(x, y)
if self.transpose:
xy = np.vstack((Y.ravel(),X.ravel())).T
xy = np.vstack((Y.ravel(), X.ravel())).T
else:
xy = np.vstack((X.ravel(),Y.ravel())).T
xy = np.vstack((X.ravel(), Y.ravel())).T
data = self.ws.getSignalAtCoord(xy, self.normalization).reshape(X.shape)
self.set_data(data)
......@@ -80,10 +107,24 @@ class SamplingImage(mimage.AxesImage):
return False
def imshow_sampling(axes, workspace, cmap=None, norm=None, aspect=None,
interpolation=None, alpha=None, vmin=None, vmax=None,
origin=None, extent=None, shape=None, filternorm=1,
filterrad=4.0, imlim=None, resample=None, url=None, **kwargs):
def imshow_sampling(axes,
workspace,
cmap=None,
norm=None,
aspect=None,
interpolation=None,
alpha=None,
vmin=None,
vmax=None,
origin=None,
extent=None,
shape=None,
filternorm=1,
filterrad=4.0,
imlim=None,
resample=None,
url=None,
**kwargs):
"""Copy of imshow but replaced AxesImage with SamplingImage and added
callbacks and Mantid Workspace stuff.
......@@ -100,28 +141,33 @@ def imshow_sampling(axes, workspace, cmap=None, norm=None, aspect=None,
_setLabels2D(axes, workspace, transpose=transpose, xscale='linear')
if not extent:
extent = (workspace.getDimension(0).getMinimum(),
workspace.getDimension(0).getMaximum(),
workspace.getDimension(1).getMinimum(),
workspace.getDimension(1).getMaximum())
if transpose:
e1, e2, e3, e4 = extent
extent = e3, e4, e1, e2
extent = (workspace.getDimension(0).getMinimum(), workspace.getDimension(0).getMaximum(),
workspace.getDimension(1).getMinimum(), workspace.getDimension(1).getMaximum())
if transpose:
e1, e2, e3, e4 = extent
extent = e3, e4, e1, e2
# from matplotlib.axes.Axes.imshow
if norm is not None and not isinstance(norm, matplotlib.colors.Normalize):
raise ValueError(
"'norm' must be an instance of 'mcolors.Normalize'")
raise ValueError("'norm' must be an instance of 'mcolors.Normalize'")
if aspect is None:
aspect = matplotlib.rcParams['image.aspect']
axes.set_aspect(aspect)
im = SamplingImage(axes, workspace, transpose, cmap, norm, interpolation, origin, extent,
filternorm=filternorm, filterrad=filterrad,
resample=resample, **kwargs)
im.set_extent(im.get_extent())
im = SamplingImage(axes,
workspace,
transpose,
cmap,
norm,
interpolation,
origin,
extent,
filternorm=filternorm,
filterrad=filterrad,
resample=resample,
**kwargs)
im._resample_image(100, 100)
im.set_alpha(alpha)
im.set_url(url)
if im.get_clip_path() is None:
# image does not already have clipping set, clip to axes patch
im.set_clip_path(axes.patch)
......@@ -129,16 +175,15 @@ def imshow_sampling(axes, workspace, cmap=None, norm=None, aspect=None,
im.set_clim(vmin, vmax)
else:
im.autoscale_None()
im.set_url(url)
# update ax.dataLim, and, if autoscaling, set viewLim
# to tightly fit the image, regardless of dataLim.
im.set_extent(im.get_extent())
axes.add_image(im)
if extent:
axes.set_xlim(extent[0], extent[1])
axes.set_ylim(extent[2], extent[3])
axes.get_figure().canvas.mpl_connect('resize_event', im._resize)
axes.callbacks.connect('xlim_changed', im._xlim_changed)
axes.callbacks.connect('ylim_changed', im._ylim_changed)
im.connect_events()
return im
......@@ -7,13 +7,16 @@
# This file is part of the mantid workbench.
#
#
import matplotlib
import sys
import unittest
from unittest import mock
import matplotlib
matplotlib.use('Agg') # noqa: E402
import unittest
# Mock out simpleapi to import expensive import of something we don't use anyway
sys.modules['mantid.simpleapi'] = mock.MagicMock() # noqa: E402
import mantid.api
from unittest import mock
from mantidqt.widgets.sliceviewer.model import SliceViewerModel, WS_TYPE
from mantidqt.widgets.sliceviewer.presenter import SliceViewer
from mantidqt.widgets.sliceviewer.view import SliceViewerView, SliceViewerDataView
......@@ -193,6 +196,17 @@ class SliceViewerTest(unittest.TestCase):
data_view_mock.enable_lineplots_button.assert_called_once()
data_view_mock.enable_peaks_button.assert_called_once()
def test_request_to_show_all_data_sets_correct_limits_on_view(self):
presenter = SliceViewer(None, model=self.model, view=self.view)
self.model.get_dim_limits.return_value = ((-1, 1), (-2, 2))
presenter.show_all_data_requested()
data_view = self.view.data_view
self.model.get_dim_limits.assert_called_once_with([None, None, 0.5],
data_view.dimensions.transpose)
data_view.set_axes_limits.assert_called_once_with((-1, 1), (-2, 2))
@mock.patch("mantidqt.widgets.sliceviewer.presenter.SliceInfo")
def test_changing_dimensions_in_nonortho_mode_switches_to_ortho_when_dim_not_Q(
self, mock_sliceinfo_cls):
......
......@@ -27,13 +27,14 @@ class ToolItemText:
class SliceViewerNavigationToolbar(NavigationToolbar2QT):
gridClicked = Signal(bool)
homeClicked = Signal()
linePlotsClicked = Signal(bool)
nonOrthogonalClicked = Signal(bool)
peaksOverlayClicked = Signal(bool)
plotOptionsChanged = Signal()
toolitems = (
(ToolItemText.HOME, 'Reset original view', 'mdi.home', 'home', None),
(ToolItemText.HOME, 'Reset original view', 'mdi.home', 'homeClicked', None),
(ToolItemText.PAN, 'Pan axes with left mouse, zoom with right', 'mdi.arrow-all', 'pan',
False),
(ToolItemText.ZOOM, 'Zoom to rectangle', 'mdi.magnify', 'zoom', False),
......
......@@ -11,6 +11,7 @@ import mantid.api
from mantid.plots.axesfunctions import _pcolormesh_nonortho as pcolormesh_nonorthogonal
from mantid.plots.datafunctions import get_normalize_by_bin_width
from matplotlib import gridspec
from matplotlib.artist import setp as set_artist_property
from matplotlib.figure import Figure
from matplotlib.transforms import Bbox, BboxTransform
from mpl_toolkits.axisartist import Subplot as CurveLinearSubPlot
......@@ -45,11 +46,9 @@ class SliceViewerDataView(QWidget):
self.nonortho_tr = None
# Dimension widget
self.dimensions_layout = QHBoxLayout()
self.dimensions = DimensionWidget(dims_info, parent=self)
self.dimensions.dimensionsChanged.connect(self.presenter.dimensions_changed)
self.dimensions.valueChanged.connect(self.presenter.slicepoint_changed)
self.dimensions_layout.addWidget(self.dimensions)
self.colorbar_layout = QVBoxLayout()
self.colorbar_layout.setContentsMargins(0,0,0,0)
......@@ -69,39 +68,42 @@ class SliceViewerDataView(QWidget):
self.colorbar_layout.addLayout(self.norm_layout)
# MPL figure + colorbar
self.mpl_layout = QHBoxLayout()
self.mpl_layout.setContentsMargins(0,0,0,0)
self.mpl_layout.setSpacing(0)
mpl_layout = QHBoxLayout()
mpl_layout.setContentsMargins(0,0,0,0)
mpl_layout.setSpacing(0)
self.fig = Figure()
self.fig.set_tight_layout(True)
self.ax = None
self.axx, self.axy = None, None
self.image = None
self._grid_on = False
self.fig.set_facecolor(self.palette().window().color().getRgbF())
self.canvas = FigureCanvas(self.fig)
self.canvas.mpl_connect('motion_notify_event', self.mouse_move)
self.canvas.mpl_connect('axes_leave_event', self.mouse_outside_image)
self.create_axes_orthogonal()
self.mpl_layout.addWidget(self.canvas)
mpl_layout.addWidget(self.canvas)
self.colorbar_label = QLabel("Colormap")
self.colorbar_layout.addWidget(self.colorbar_label)
self.colorbar = ColorbarWidget(self)
self.colorbar_layout.addWidget(self.colorbar)
self.colorbar.colorbarChanged.connect(self.update_data_clim)
self.colorbar.colorbarChanged.connect(self.update_line_plot_limits)
self.mpl_layout.addLayout(self.colorbar_layout)
mpl_layout.addLayout(self.colorbar_layout)
# MPL toolbar
self.mpl_toolbar = SliceViewerNavigationToolbar(self.canvas, self)
self.mpl_toolbar.gridClicked.connect(self.toggle_grid)
self.mpl_toolbar.linePlotsClicked.connect(self.on_line_plots_toggle)
self.mpl_toolbar.homeClicked.connect(self.on_home_clicked)
self.mpl_toolbar.plotOptionsChanged.connect(self.colorbar.mappable_changed)
self.mpl_toolbar.nonOrthogonalClicked.connect(self.on_non_orthogonal_axes_toggle)
# layout
self.layout = QGridLayout(self)
self.layout.setSpacing(1)
self.layout.addLayout(self.dimensions_layout, 0, 0)
self.layout.addWidget(self.mpl_toolbar, 1, 0)
self.layout.addLayout(self.mpl_layout, 2, 0)
layout = QGridLayout(self)
layout.setSpacing(1)
layout.addWidget(self.dimensions, 0, 0)
layout.addWidget(self.mpl_toolbar, 1, 0)
layout.addLayout(mpl_layout, 2, 0)
@property
def grid_on(self):
......@@ -119,6 +121,7 @@ class SliceViewerDataView(QWidget):
self.ax.grid(self.grid_on)
if self.line_plots:
self.add_line_plots()
self.plot_MDH = self.plot_MDH_orthogonal
self.canvas.draw_idle()
......@@ -156,13 +159,13 @@ class SliceViewerDataView(QWidget):
wspace=0.0,
hspace=0.0)
image_axes.set_position(gs[1].get_position(self.fig))
image_axes.xaxis.set_visible(False)
image_axes.yaxis.set_visible(False)
set_artist_property(image_axes.get_xticklabels(), visible=False)
set_artist_property(image_axes.get_yticklabels(), visible=False)
self.axx = self.fig.add_subplot(gs[3], sharex=image_axes)
self.axx.yaxis.tick_right()
self.axy = self.fig.add_subplot(gs[0], sharey=image_axes)
self.axy.xaxis.tick_top()
self.update_line_plot_labels()
self.mpl_toolbar.update() # sync list of axes in navstack
self.canvas.draw_idle()
......@@ -175,7 +178,7 @@ class SliceViewerDataView(QWidget):
if image_axes is None:
return
self.clear_line_plots()
self.delete_line_plot_lines()
all_axes = self.fig.axes
# The order is defined by the order of the add_subplot calls so we always want to remove
# the last two Axes. Do it backwards to cope with the container size change
......@@ -184,8 +187,8 @@ class SliceViewerDataView(QWidget):
gs = gridspec.GridSpec(1, 1)
image_axes.set_position(gs[0].get_position(self.fig))
image_axes.xaxis.set_visible(True)
image_axes.yaxis.set_visible(True)
set_artist_property(image_axes.get_xticklabels(), visible=True)
set_artist_property(image_axes.get_yticklabels(), visible=True)
self.axx, self.axy = None, None
self.mpl_toolbar.update() # sync list of axes in navstack
......@@ -202,6 +205,10 @@ class SliceViewerDataView(QWidget):
transpose=self.dimensions.transpose,
norm=self.colorbar.get_norm(),
**kwargs)
extent = self.image.get_extent()
self.ax.set_xlim(extent[0], extent[1])
self.ax.set_ylim(extent[2], extent[3])
self.draw_plot()
def plot_MDH_nonorthogonal(self, ws, **kwargs):
......@@ -219,8 +226,16 @@ class SliceViewerDataView(QWidget):
def plot_matrix(self, ws, **kwargs):
"""
clears the plot and creates a new one using a MatrixWorkspace
clears the plot and creates a new one using a MatrixWorkspace keeping
the axes limits that have already been set
"""
old_extent = None
if self.image is not None:
old_extent = self.image.get_extent()
if self.image.transpose != self.dimensions.transpose:
e1, e2, e3, e4 = old_extent
old_extent = e3, e4, e1, e2
self.clear_image()
self.image = imshow_sampling(self.ax,
ws,
......@@ -229,29 +244,41 @@ class SliceViewerDataView(QWidget):
interpolation='none',
transpose=self.dimensions.transpose,
norm=self.colorbar.get_norm(),
extent=old_extent,
**kwargs)
self.image._resample_image()
self.draw_plot()
def clear_image(self):
"""Removes any image from the axes"""
if self.image is not None:
if self.line_plots:
self.delete_line_plot_lines()
self.image.remove()
self.image = None
def clear_figure(self):
"""Removes everything from the figure"""
if self.line_plots:
self.delete_line_plot_lines()
self.axx, self.axy = None, None
self.image = None
self.fig.clf()
self.ax = None
def draw_plot(self):
self.ax.set_title('')
self.colorbar.set_mappable(self.image)
self.colorbar.update_clim()
self.mpl_toolbar.update() # clear nav stack
self.clear_line_plots()
self.delete_line_plot_lines()
self.update_line_plot_labels()
self.canvas.draw_idle()
def on_home_clicked(self):
"""Reset the view to encompass all of the data"""
self.presenter.show_all_data_requested()
def update_plot_data(self, data):
"""
This just updates the plot data without creating a new plot
......@@ -308,22 +335,28 @@ class SliceViewerDataView(QWidget):