Commit a0fe7bd4 authored by Joseph Torsney's avatar Joseph Torsney Committed by Zhou, Wenduo
Browse files

Sliceviewer fix BinMD parameters when transposing re conflict.

parent 76370556
......@@ -10,10 +10,10 @@
from qtpy.QtWidgets import (QWidget, QHBoxLayout, QVBoxLayout, QLabel, QPushButton, QSlider,
QDoubleSpinBox, QSpinBox)
from qtpy.QtCore import Qt, Signal
from enum import Enum
from enum import IntEnum
class State(Enum):
class State(IntEnum):
X = 0
Y = 1
NONE = 2
......@@ -64,6 +64,11 @@ class DimensionWidget(QWidget):
self.transpose = False
# Store the current and previous valid dimensions state, i.e. after self.change_dims has executed.
# Initial values for both is the current state
self.previous_dimensions_state = self._get_states()
self.dimensions_state = self._get_states()
def change_dims(self, number):
states = [d.get_state() for n, d in enumerate(self.dims)]
......@@ -84,6 +89,10 @@ class DimensionWidget(QWidget):
if n != number and d.get_state() == State.Y:
d.set_state(State.NONE)
# Store the previous dimensions state and reset the current states.
self.previous_dimensions_state = self.dimensions_state
self.dimensions_state = self._get_states()
self.check_transpose()
self.dimensionsChanged.emit()
......@@ -141,6 +150,23 @@ class DimensionWidget(QWidget):
if value is not None:
self.dims[index].set_value(value)
def get_states(self):
return self._get_axis_indices_from_states(self.dimensions_state)
def get_previous_states(self):
return self._get_axis_indices_from_states(self.previous_dimensions_state)
def _get_states(self):
return [d.get_state() for d in self.dims]
@staticmethod
def _get_axis_indices_from_states(states):
"""
:return: a list where the value (0, 1, None) at index i
represents the axis that dimension i is to be displayed on.
"""
return list(map(lambda i: i if i <= 1 else None, [int(state) for state in states]))
class Dimension(QWidget):
stateChanged = Signal(int)
......
......@@ -145,7 +145,8 @@ class SliceViewerModel:
def get_ws_MDE(self,
slicepoint: Sequence[Optional[float]],
bin_params: Optional[Sequence[float]],
limits: Optional[tuple] = None):
limits: Optional[tuple] = None,
dimension_indices: Optional[tuple] = None):
"""
:param slicepoint: ND sequence of either None or float. A float defines the point
in that dimension for the slice.
......@@ -155,9 +156,10 @@ class SliceViewerModel:
not provided the full extent of each dimension is used
"""
workspace = self._get_ws()
params, _, __ = _roi_binmd_parameters(workspace, slicepoint, bin_params, limits)
params, _, __ = _roi_binmd_parameters(workspace, slicepoint, bin_params, limits, dimension_indices)
params['EnableLogging'] = LOG_GET_WS_MDE_ALGORITHM_CALLS
return BinMD(InputWorkspace=workspace, OutputWorkspace=self._rebinned_name, **params)
binned = BinMD(InputWorkspace=workspace, OutputWorkspace=self._rebinned_name, **params)
return binned
def get_data_MDH(self, slicepoint, transpose=False):
indices, _ = get_indices(self.get_ws(), slicepoint=slicepoint)
......@@ -166,7 +168,7 @@ class SliceViewerModel:
else:
return np.ma.masked_invalid(self.get_ws().getSignalArray()[indices])
def get_data_MDE(self, slicepoint, bin_params, limits=None, transpose=False):
def get_data_MDE(self, slicepoint, bin_params, dimension_indices, limits=None, transpose=False):
"""
:param slicepoint: ND sequence of either None or float. A float defines the point
in that dimension for the slice.
......@@ -178,10 +180,10 @@ class SliceViewerModel:
"""
if transpose:
return np.ma.masked_invalid(
self.get_ws_MDE(slicepoint, bin_params, limits).getSignalArray().squeeze()).T
self.get_ws_MDE(slicepoint, bin_params, limits, dimension_indices).getSignalArray().squeeze()).T
else:
return np.ma.masked_invalid(
self.get_ws_MDE(slicepoint, bin_params, limits).getSignalArray().squeeze())
self.get_ws_MDE(slicepoint, bin_params, limits, dimension_indices).getSignalArray().squeeze())
def get_dim_limits(self, slicepoint, transpose):
"""
......@@ -527,7 +529,8 @@ class SliceViewerModel:
# private functions
def _roi_binmd_parameters(workspace, slicepoint: Sequence[Optional[float]],
bin_params: Optional[Sequence[float]],
limits: tuple) -> Tuple[dict, int, int]:
limits: tuple,
dimension_indices: tuple) -> Tuple[dict, int, int]:
"""
Return a sequence of 2-tuples defining the limits for MDEventWorkspace binning
:param workspace: MDEventWorkspace that is to be binned
......@@ -539,7 +542,7 @@ def _roi_binmd_parameters(workspace, slicepoint: Sequence[Optional[float]],
:return: 3-tuple (binmd parameters, index of X dimension, index of Y dimension)
"""
xindex, yindex = _display_indices(slicepoint)
dim_limits = _dimension_limits(workspace, slicepoint, limits)
dim_limits = _dimension_limits(workspace, dimension_indices, limits)
ndims = workspace.getNumDims()
ws_basis = np.eye(ndims)
output_extents, output_bins = [], []
......@@ -567,7 +570,7 @@ def _roi_binmd_parameters(workspace, slicepoint: Sequence[Optional[float]],
def _dimension_limits(workspace,
slicepoint: Sequence[Optional[float]],
dimension_indices: Optional[tuple],
limits: Optional[Sequence[tuple]] = None) -> Sequence[tuple]:
"""
Return a sequence of 2-tuples defining the limits for MDEventWorkspace binning
......@@ -579,10 +582,11 @@ def _dimension_limits(workspace,
"""
dim_limits = [(dim.getMinimum(), dim.getMaximum())
for dim in [workspace.getDimension(i) for i in range(workspace.getNumDims())]]
xindex, yindex = _display_indices(slicepoint)
if limits is not None:
dim_limits[xindex] = limits[0]
dim_limits[yindex] = limits[1]
# Match the view limits to the dimension they're for.
for dim, axis in enumerate(dimension_indices):
if axis is not None:
dim_limits[dim] = limits[axis]
return dim_limits
......
......@@ -72,7 +72,7 @@ class SliceViewer(ObservingPresenter):
self.ads_observer = SliceViewerADSObserver(self.replace_workspace, self.rename_workspace,
self.ADS_cleared, self.delete_workspace)
def new_plot_MDH(self):
def new_plot_MDH(self, dimensions_transposing=False, dimensions_changing=False):
"""
Tell the view to display a new plot of an MDHistoWorkspace
"""
......@@ -83,9 +83,9 @@ class SliceViewer(ObservingPresenter):
data_view.plot_MDH(self.model.get_ws(), slicepoint=self.get_slicepoint())
self._call_peaks_presenter_if_created("notify", PeaksViewerPresenter.Event.OverlayPeaks)
else:
self.new_plot_MDE()
self.new_plot_MDE(dimensions_transposing, dimensions_changing)
def new_plot_MDE(self):
def new_plot_MDE(self, dimensions_transposing=False, dimensions_changing=False):
"""
Tell the view to display a new plot of an MDEventWorkspace
"""
......@@ -93,24 +93,34 @@ class SliceViewer(ObservingPresenter):
limits = data_view.get_axes_limits()
if limits is not None:
xlim, ylim = limits
# view limits are in orthogonal frame. transform to nonorthogonal
# model frame
if data_view.nonorthogonal_mode:
xlim, ylim = limits
inv_tr = data_view.nonortho_transform.inv_tr
# viewing axis y not aligned with plot axis
xmin_p, ymax_p = inv_tr(xlim[0], ylim[1])
xmax_p, ymin_p = inv_tr(xlim[1], ylim[0])
xlim, ylim = (xmin_p, xmax_p), (ymin_p, ymax_p)
if data_view.dimensions.transpose:
limits = ylim, xlim
else:
limits = xlim, ylim
limits = [xlim, ylim]
# The value at the i'th index of this tells us that the axis with that value (0 or 1) will display dimension i
dimension_indices = self.view.dimensions.get_states()
if dimensions_transposing:
# Since the dimensions are transposing, the limits we have from the view are the wrong way around
# with respect to the axes the dimensions are about to be displayed, so get the previous dimension states.
dimension_indices = self.view.dimensions.get_previous_states()
elif dimensions_changing:
# If we are changing which dimensions are to be displayed, the limits we got from the view are stale
# as they refer to the previous two dimensions that were displayed.
limits = None
data_view.plot_MDH(
self.model.get_ws_MDE(slicepoint=self.get_slicepoint(),
bin_params=data_view.dimensions.get_bin_params(),
limits=limits))
limits=limits,
dimension_indices=dimension_indices))
self._call_peaks_presenter_if_created("notify", PeaksViewerPresenter.Event.OverlayPeaks)
def new_plot_matrix(self):
......@@ -123,8 +133,7 @@ class SliceViewer(ObservingPresenter):
"""
self.view.data_view.update_plot_data(
self.model.get_data(self.get_slicepoint(),
transpose=self.view.data_view.dimensions.transpose),
self.view.data_view.dimensions.transpose)
transpose=self.view.data_view.dimensions.transpose))
def update_plot_data_MDE(self):
"""
......@@ -134,9 +143,9 @@ class SliceViewer(ObservingPresenter):
data_view.update_plot_data(
self.model.get_data(self.get_slicepoint(),
bin_params=data_view.dimensions.get_bin_params(),
dimension_indices=data_view.dimensions.get_states(),
limits=data_view.get_axes_limits(),
transpose=self.view.data_view.dimensions.transpose),
self.view.data_view.dimensions.transpose)
transpose=self.view.data_view.dimensions.transpose))
def update_plot_data_matrix(self):
# should never be called, since this workspace type is only 2D the plot dimensions never change
......@@ -181,7 +190,15 @@ class SliceViewer(ObservingPresenter):
else:
data_view.disable_tool_button(ToolItemText.NONORTHOGONAL_AXES)
self.new_plot()
ws_type = self.model.get_ws_type()
if ws_type == WS_TYPE.MDH or ws_type == WS_TYPE.MDE:
if sliceinfo.slicepoint[data_view.dimensions.get_previous_states().index(None)] is None:
# The dimension of the slicepoint has changed
self.new_plot(dimensions_changing=True)
else:
self.new_plot(dimensions_transposing=True)
else:
self.new_plot()
def slicepoint_changed(self):
"""Indicates the slicepoint has been updated"""
......
......@@ -355,7 +355,7 @@ class SliceViewerDataView(QWidget):
"""
self.presenter.export_region(limits, cut)
def update_plot_data(self, data, transposed=False):
def update_plot_data(self, data):
"""
This just updates the plot data without creating a new plot. The extents
can change if the data has been rebinned.
......@@ -363,14 +363,7 @@ class SliceViewerDataView(QWidget):
if self.nonortho_transform:
self.image.set_array(data.T.ravel())
else:
# need to update extent and limits of orthog axes when transposed (non orthog limits reset anyway)
extent = self.image.get_extent()
self.image.set_data(data.T)
if transposed:
extent = (extent[2], extent[3], extent[0], extent[1])
self.image.set_extent(extent)
self.ax.set_xlim((extent[0], extent[1]))
self.ax.set_ylim((extent[2], extent[3]))
self.colorbar.update_clim()
def track_cursor_checked(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment