Skip to content
Snippets Groups Projects
Commit 62c3691a authored by WHITFIELDRE email's avatar WHITFIELDRE email
Browse files

Add support to SliceViewer for MatrixWorkspace

Closes #25599
parent c79a3f58
No related branches found
No related tags found
No related merge requests found
......@@ -9,17 +9,30 @@
#
from __future__ import (absolute_import, division, print_function)
from mantid.plots.helperfunctions import get_indices
from mantid.api import MatrixWorkspace, MultipleExperimentInfos
import numpy as np
from mantid.py3compat.enum import Enum
class WS_TYPE(Enum):
MDE = 0
MDH = 1
MATRIX = 2
class SliceViewerModel(object):
"""Store the workspace to be plotted. Can be MatrixWorkspace, MDEventWorkspace or MDHistoWorkspace"""
def __init__(self, ws):
if len(ws.getNonIntegratedDimensions()) < 2:
raise ValueError("workspace must have at least 2 non-integrated dimensions")
if not ws.isMDHistoWorkspace():
raise ValueError("currenly only works for MDHistoWorkspace")
if isinstance(ws, MatrixWorkspace):
if ws.getNumberHistograms() < 2:
raise ValueError("workspace must contain at least 2 spectrum")
if ws.blocksize() < 2:
raise ValueError("workspace must contain at least 2 bin")
elif ws.isMDHistoWorkspace():
if len(ws.getNonIntegratedDimensions()) < 2:
raise ValueError("workspace must have at least 2 non-integrated dimensions")
else:
raise ValueError("currenly only works for MatrixWorkspace and MDHistoWorkspace")
self._ws = ws
......@@ -50,3 +63,14 @@ class SliceViewerModel(object):
returns a list of dict for each dimension conainting dim_info
"""
return [self.get_dim_info(n) for n in range(self.get_ws().getNumDims())]
def get_ws_type(self):
if isinstance(self.get_ws(), MatrixWorkspace):
return WS_TYPE.MATRIX
elif isinstance(self.get_ws(), MultipleExperimentInfos):
if self.get_ws().isMDHistoWorkspace():
return WS_TYPE.MDH
else:
return WS_TYPE.MDE
else:
raise ValueError("Unsupported workspace type")
......@@ -8,7 +8,7 @@
#
#
from __future__ import (absolute_import, division, print_function)
from .model import SliceViewerModel
from .model import SliceViewerModel, WS_TYPE
from .view import SliceViewerView
......@@ -16,12 +16,21 @@ class SliceViewer(object):
def __init__(self, ws, parent=None, model=None, view=None):
# Create model and view, or accept mocked versions
self.model = model if model else SliceViewerModel(ws)
if self.model.get_ws_type() == WS_TYPE.MDH:
self.new_plot = self.new_plot_MDH
else:
self.new_plot = self.new_plot_matrix
self.view = view if view else SliceViewerView(self, self.model.get_dimensions_info(), parent)
self.new_plot()
def new_plot(self):
self.view.plot(self.model.get_ws(), slicepoint=self.view.dimensions.get_slicepoint())
def new_plot_MDH(self):
self.view.plot_MDH(self.model.get_ws(), slicepoint=self.view.dimensions.get_slicepoint())
def new_plot_matrix(self):
self.view.plot_matrix(self.model.get_ws())
def update_plot_data(self):
self.view.update_plot_data(self.model.get_data(self.view.dimensions.get_slicepoint(), self.view.dimensions.transpose))
......@@ -9,8 +9,8 @@
#
from __future__ import (absolute_import, division, print_function)
from mantid.simpleapi import CreateMDHistoWorkspace
from mantidqt.widgets.sliceviewer.model import SliceViewerModel
from mantid.simpleapi import CreateMDHistoWorkspace, CreateWorkspace
from mantidqt.widgets.sliceviewer.model import SliceViewerModel, WS_TYPE
from numpy.testing import assert_equal
import numpy as np
import unittest
......@@ -28,12 +28,22 @@ class SliceViewerModelTest(unittest.TestCase):
Names='Dim1,Dim2,Dim3',
Units='MomentumTransfer,EnergyTransfer,Angstrom',
OutputWorkspace='ws_MD_2d')
self.ws2d_histo = CreateWorkspace(DataX=[10, 20, 30, 10, 20, 30],
DataY=[2, 3, 4, 5],
DataE=[1, 2, 3, 4],
NSpec=2,
Distribution=True,
UnitX='Wavelength',
VerticalAxisUnit='DeltaE',
VerticalAxisValues=[4, 6, 8],
OutputWorkspace='ws2d_histo')
def test_model_MDH(self):
model = SliceViewerModel(self.ws_MD_3D)
self.assertEqual(model.get_ws(), self.ws_MD_3D)
self.assertEqual(model.get_ws_type(), WS_TYPE.MDH)
assert_equal(model.get_data((None, 2, 2)), range(90,95))
assert_equal(model.get_data((1, 2, None)), range(18,118,25))
......@@ -58,6 +68,32 @@ class SliceViewerModelTest(unittest.TestCase):
self.assertEqual(dim_info['name'], 'Dim3')
self.assertEqual(dim_info['units'], 'Angstrom')
def test_model_Histo(self):
model = SliceViewerModel(self.ws2d_histo)
self.assertEqual(model.get_ws(), self.ws2d_histo)
self.assertEqual(model.get_ws_type(), WS_TYPE.MATRIX)
dim_info = model.get_dim_info(0)
self.assertEqual(dim_info['minimum'], 10)
self.assertEqual(dim_info['maximum'], 30)
self.assertEqual(dim_info['number_of_bins'], 2)
self.assertAlmostEqual(dim_info['width'], 10)
self.assertEqual(dim_info['name'], 'Wavelength')
self.assertEqual(dim_info['units'], 'Angstrom')
dim_infos = model.get_dimensions_info()
self.assertEqual(len(dim_infos), 2)
dim_info = dim_infos[1]
self.assertEqual(dim_info['minimum'], 4)
self.assertEqual(dim_info['maximum'], 8)
self.assertEqual(dim_info['number_of_bins'], 2)
self.assertAlmostEqual(dim_info['width'], 2)
self.assertEqual(dim_info['name'], 'Energy transfer')
self.assertEqual(dim_info['units'], 'meV')
if __name__ == '__main__':
unittest.main()
......@@ -14,7 +14,7 @@ matplotlib.use('Agg') # noqa: E402
import unittest
from mantid.py3compat import mock
from mantidqt.widgets.sliceviewer.model import SliceViewerModel
from mantidqt.widgets.sliceviewer.model import SliceViewerModel, WS_TYPE
from mantidqt.widgets.sliceviewer.presenter import SliceViewer
from mantidqt.widgets.sliceviewer.view import SliceViewerView
......@@ -27,7 +27,9 @@ class SliceViewerTest(unittest.TestCase):
self.model = mock.Mock(spec=SliceViewerModel)
def test_sliceviewer(self):
def test_sliceviewer_MDH(self):
self.model.get_ws_type = mock.Mock(return_value=WS_TYPE.MDH)
presenter = SliceViewer(None, model=self.model, view=self.view)
......@@ -35,7 +37,7 @@ class SliceViewerTest(unittest.TestCase):
self.assertEqual(self.model.get_dimensions_info.call_count, 0)
self.assertEqual(self.model.get_ws.call_count, 1)
self.assertEqual(self.view.dimensions.get_slicepoint.call_count, 1)
self.assertEqual(self.view.plot.call_count, 1)
self.assertEqual(self.view.plot_MDH.call_count, 1)
# new_plot
self.model.reset_mock()
......@@ -43,7 +45,7 @@ class SliceViewerTest(unittest.TestCase):
presenter.new_plot()
self.assertEqual(self.model.get_ws.call_count, 1)
self.assertEqual(self.view.dimensions.get_slicepoint.call_count, 1)
self.assertEqual(self.view.plot.call_count, 1)
self.assertEqual(self.view.plot_MDH.call_count, 1)
# update_plot_data
self.model.reset_mock()
......@@ -53,6 +55,26 @@ class SliceViewerTest(unittest.TestCase):
self.assertEqual(self.view.dimensions.get_slicepoint.call_count, 1)
self.assertEqual(self.view.update_plot_data.call_count, 1)
def test_sliceviewer_matrix(self):
self.model.get_ws_type = mock.Mock(return_value=WS_TYPE.MATRIX)
presenter = SliceViewer(None, model=self.model, view=self.view)
# setup calls
self.assertEqual(self.model.get_dimensions_info.call_count, 0)
self.assertEqual(self.model.get_ws.call_count, 1)
self.assertEqual(self.view.dimensions.get_slicepoint.call_count, 0)
self.assertEqual(self.view.plot_matrix.call_count, 1)
# new_plot
self.model.reset_mock()
self.view.reset_mock()
presenter.new_plot()
self.assertEqual(self.model.get_ws.call_count, 1)
self.assertEqual(self.view.dimensions.get_slicepoint.call_count, 0)
self.assertEqual(self.view.plot_matrix.call_count, 1)
if __name__ == '__main__':
unittest.main()
......@@ -14,6 +14,7 @@ from mantidqt.MPLwidgets import FigureCanvas, NavigationToolbar2QT as Navigation
from matplotlib.figure import Figure
from .dimensionwidget import DimensionWidget
from mantidqt.widgets.colorbar.colorbar import ColorbarWidget
from mantidqt.plotting.functions import use_imshow
class SliceViewerView(QWidget):
......@@ -54,9 +55,9 @@ class SliceViewerView(QWidget):
self.show()
def plot(self, ws, **kwargs):
def plot_MDH(self, ws, **kwargs):
"""
clears the plot and creates a new one using the workspace
clears the plot and creates a new one using a MDHistoWorkspace
"""
self.ax.clear()
self.im = self.ax.imshow(ws, origin='lower', aspect='auto',
......@@ -67,6 +68,21 @@ class SliceViewerView(QWidget):
self.mpl_toolbar.update() # clear nav stack
self.canvas.draw_idle()
def plot_matrix(self, ws, **kwargs):
"""
clears the plot and creates a new one using a MatrixWorkspace
"""
self.ax.clear()
if use_imshow(ws):
self.im = self.ax.imshow(ws, origin='lower', aspect='auto',
norm=self.colorbar.get_norm(), **kwargs)
else:
self.im = self.ax.pcolormesh(ws, norm=self.colorbar.get_norm(), **kwargs)
self.ax.set_title('')
self.colorbar.set_mappable(self.im)
self.mpl_toolbar.update() # clear nav stack
self.canvas.draw_idle()
def update_plot_data(self, data):
"""
This just updates the plot data without creating a new plot
......
......@@ -113,6 +113,7 @@ void WorkspaceTreeWidgetSimple::popupContextMenu() {
matrixWS->getInstrument() &&
!matrixWS->getInstrument()->getName().empty());
menu->addAction(m_sampleLogs);
menu->addAction(m_sliceViewer);
} else if (boost::dynamic_pointer_cast<ITableWorkspace>(workspace)) {
menu->addAction(m_showData);
menu->addAction(m_showAlgorithmHistory);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment