Unverified Commit ad7334dd authored by Nick Draper's avatar Nick Draper Committed by GitHub
Browse files

Merge pull request #28229 from mantidproject/27915_FixLinksBetweenErrorAndYColsInTableWorkspaces

Fix links between error and y cols in table workspaces
parents d78ab2e4 bb58c28e
......@@ -149,6 +149,10 @@ public:
/// Set plot type where
void setPlotType(int t);
int getLinkedYCol() const { return m_linkedYCol; }
void setLinkedYCol(const int yCol);
/**
* Fills a std vector with values from the column if the types are compatible.
* @param maxSize :: Set size to less than the full column.
......@@ -201,6 +205,9 @@ protected:
/// X = 1, Y = 2, Z = 3, xErr = 4, yErr = 5, Label = 6
int m_plotType;
/// For error columns - the index of the related data column
int m_linkedYCol = -1;
/// Column read-only flag
bool m_isReadOnly;
......
......@@ -36,6 +36,8 @@ void Column::setPlotType(int t) {
}
}
void Column::setLinkedYCol(const int yCol) { m_linkedYCol = yCol; }
/**
* No implementation by default.
*/
......
......@@ -11,6 +11,7 @@
#include "MantidAPI/WorkspaceFactory.h"
#include "MantidAPI/WorkspaceProperty.h"
#include "MantidKernel/V3D.h"
#include "MantidKernel/WarningSuppressions.h"
#include "MantidPythonInterface/core/Converters/CloneToNDArray.h"
#include "MantidPythonInterface/core/Converters/NDArrayToVector.h"
#include "MantidPythonInterface/core/Converters/PySequenceToVector.h"
......@@ -27,6 +28,7 @@
#include <boost/python/dict.hpp>
#include <boost/python/list.hpp>
#include <boost/python/make_constructor.hpp>
#include <boost/python/overloads.hpp>
#include <cstring>
#include <vector>
......@@ -229,8 +231,11 @@ int getPlotType(ITableWorkspace &self, const object &column) {
* @param self Reference to TableWorkspace this is called on
* @param column Name or index of column
* @param ptype PlotType: 0=None, 1=X, 2=Y, 3=Z, 4=xErr, 5=yErr, 6=Label
* @param linkedCol Index of the column that the column parameter is linked to
* (typically used for an error column)
*/
void setPlotType(ITableWorkspace &self, const object &column, int ptype) {
void setPlotType(ITableWorkspace &self, const object &column, int ptype,
int linkedCol = -1) {
// Find the column
Mantid::API::Column_sptr colptr;
if (STR_CHECK(column.ptr())) {
......@@ -239,9 +244,38 @@ void setPlotType(ITableWorkspace &self, const object &column, int ptype) {
colptr = self.getColumn(extract<int>(column)());
}
colptr->setPlotType(ptype);
if (linkedCol >= 0) {
colptr->setLinkedYCol(linkedCol);
}
self.modified();
}
GNU_DIAG_OFF("unused-local-typedef")
// Ignore -Wconversion warnings coming from boost::python
// Seen with GCC 7.1.1 and Boost 1.63.0
GNU_DIAG_OFF("conversion")
BOOST_PYTHON_FUNCTION_OVERLOADS(setPlotType_overloads, setPlotType, 3, 4)
GNU_DIAG_ON("conversion")
GNU_DIAG_ON("unused-local-typedef")
/**
* Get the data column associated with a Y error column
* @param self Reference to TableWorkspace this is called on
* @param column Name or index of column
* @return index of the associated Y column
*/
int getLinkedYCol(ITableWorkspace &self, const object &column) {
// Find the column
Mantid::API::Column_const_sptr colptr;
if (STR_CHECK(column.ptr())) {
colptr = self.getColumn(extract<std::string>(column)());
} else {
colptr = self.getColumn(extract<int>(column)());
}
return colptr->getLinkedYCol();
}
/**
* Access a cell and return a corresponding Python type
* @param self A reference to the TableWorkspace python object that we were
......@@ -613,11 +647,18 @@ void export_ITableWorkspace() {
"Accepts column name or index. \nPossible return values: "
"(0 = None, 1 = X, 2 = Y, 3 = Z, 4 = xErr, 5 = yErr, 6 = Label).")
.def("setPlotType", &setPlotType,
(arg("self"), arg("column"), arg("ptype")),
"Set the plot type of given column. "
"Accepts column name or index. \nPossible type values: "
"(0 = None, 1 = X, 2 = Y, 3 = Z, 4 = xErr, 5 = yErr, 6 = Label).")
.def(
"setPlotType", setPlotType,
setPlotType_overloads(
(arg("self"), arg("column"), arg("ptype"), arg("linkedCol") = -1),
"Set the plot type of given column. "
"Accepts column name or index. \nPossible type values: "
"(0 = None, 1 = X, 2 = Y, 3 = Z, 4 = xErr, 5 = yErr, 6 = "
"Label)."))
.def("getLinkedYCol", &getLinkedYCol, (arg("self"), arg("column")),
"Get the plot type of given column as an integer. "
"Accepts column name or index. ")
.def("removeColumn", &ITableWorkspace::removeColumn,
(arg("self"), arg("name")), "Remove the named column.")
......
......@@ -114,5 +114,6 @@ Bugfixes
- Fixed bug that caused an error if a MDHistoWorkspace was plotted and a user attempted to open a context menu.
- Fixed a bug which caused graphic scaling issues when the double-click menu was used to set an axis as log-scaled.
- Fixed a bug where the colorbar in the instrument view would sometimes have no markers if the scale was set to SymmetricLog10.
- Fixed a bug where setting columns to Y error in table workspaces wasn't working. The links between the Y error and Y columns weren't being set up properly
:ref:`Release 5.0.0 <v5.0.0>`
......@@ -114,3 +114,5 @@ class MockWorkspace:
self.emit_repaint = StrictMock()
self.getPlotType = StrictMock()
self.getLinkedYCol = StrictMock()
......@@ -12,14 +12,12 @@ class ErrorColumn:
CANNOT_SET_Y_TO_BE_OWN_YERR_MESSAGE = "Cannot set Y column to be its own YErr"
UNHANDLED_COMPARISON_LOGIC_MESSAGE = "Unhandled comparison logic with type {}"
def __init__(self, column, related_y_column, label_index):
def __init__(self, column, related_y_column):
self.column = column
self.related_y_column = related_y_column
if self.column == self.related_y_column:
raise ValueError(self.CANNOT_SET_Y_TO_BE_OWN_YERR_MESSAGE)
self.label_index = label_index
def __eq__(self, other):
if isinstance(other, ErrorColumn):
return self.related_y_column == other.related_y_column or self.column == other.column
......@@ -35,3 +33,6 @@ class ErrorColumn:
return self.column == other
else:
raise RuntimeError(self.UNHANDLED_COMPARISON_LOGIC_MESSAGE.format(type(other)))
def __int__(self):
return self.column
......@@ -30,8 +30,7 @@ class TableWorkspaceDisplayEncoder(TableWorkspaceDisplayAttributes):
def _encode_marked_columns(marked_columns):
as_y_err = []
for y_err in marked_columns.as_y_err:
as_y_err.append({"column": y_err.column, "relatedY": y_err.related_y_column,
"labelIndex": y_err.label_index})
as_y_err.append({"column": y_err.column, "relatedY": y_err.related_y_column})
return {"as_x": marked_columns.as_x, "as_y": marked_columns.as_y, "as_y_err": as_y_err}
......@@ -54,7 +53,7 @@ class TableWorkspaceDisplayDecoder(TableWorkspaceDisplayAttributes):
error_columns = []
for y_err in obj_dic["markedColumns"]["as_y_err"]:
error_columns.append(ErrorColumn(column=y_err["column"], label_index=y_err["labelIndex"],
error_columns.append(ErrorColumn(column=y_err["column"],
related_y_column=y_err["relatedY"]))
pres.model.marked_columns.as_y_err = error_columns
......
......@@ -18,14 +18,6 @@ class MarkedColumns:
self.as_y = []
self.as_y_err = []
def _add(self, col_index, add_to, remove_from):
assert all(
add_to is not remove for remove in remove_from), "Can't add and remove from the same list at the same time!"
self._remove(col_index, remove_from)
if col_index not in add_to:
add_to.append(col_index)
def _remove(self, col_index, remove_from):
"""
Remove the column index from all lists
......@@ -34,21 +26,37 @@ class MarkedColumns:
:param remove_from: List of lists from which the column index will be removed
:return:
"""
removed_cols=[]
for list in remove_from:
try:
list.remove(col_index)
# remove all (for error cols there could be more than one match)
for _ in range(list.count(col_index)):
list_index = list.index(col_index)
col_to_remove = list.pop(list_index)
if col_to_remove not in removed_cols:
removed_cols.append(col_to_remove)
except ValueError:
# column not in this list, but might be in another one so we continue the loop
continue
return removed_cols
def add_x(self, col_index):
removed_items = self._remove(col_index, [self.as_y, self.as_y_err])
# if the column previously had a Y Err associated with it -> this will remove it from the YErr list
self._remove_associated_yerr_columns(col_index)
self._remove_associated_yerr_columns(col_index, removed_items)
def add_x(self, col_index):
self._add(col_index, self.as_x, [self.as_y, self.as_y_err])
if col_index not in self.as_x:
self.as_x.append(col_index)
return removed_items
def add_y(self, col_index):
self._add(col_index, self.as_y, [self.as_x, self.as_y_err])
removed_items = self._remove(col_index, [self.as_x, self.as_y_err])
if col_index not in self.as_y:
self.as_y.append(col_index)
return removed_items
def add_y_err(self, err_column):
if err_column.related_y_column in self.as_x:
......@@ -56,26 +64,28 @@ class MarkedColumns:
elif err_column.related_y_column in self.as_y_err:
raise ValueError("Trying to add YErr for column marked as YErr.")
# remove all labels for the column index
len_before_remove = len(self.as_y)
self._remove(err_column, [self.as_x, self.as_y, self.as_y_err])
# Check if the length of the list with columns marked Y has shrunk
# -> This means that columns have been removed, and the label_index is now _wrong_
# and has to be decremented to match the new label index correctly
len_after_remove = len(self.as_y)
if err_column.related_y_column > err_column.column and len_after_remove < len_before_remove:
err_column.label_index -= (len_before_remove - len_after_remove)
removed_items = self._remove(err_column, [self.as_x, self.as_y, self.as_y_err])
# if the column previously had a Y Err associated with it -> this will remove it from the YErr list
# this case isn't handled by the __eq__ and __comp__ functions on the ErrorColumn class
self._remove_associated_yerr_columns(err_column, removed_items)
self.as_y_err.append(err_column)
return removed_items
def remove(self, col_index):
self._remove(col_index, [self.as_x, self.as_y, self.as_y_err])
removed_cols = self._remove(col_index, [self.as_x, self.as_y, self.as_y_err])
# if the column previously had a Y Err associated with it -> this will remove it from the YErr list
self._remove_associated_yerr_columns(col_index, removed_cols)
def _remove_associated_yerr_columns(self, col_index):
def _remove_associated_yerr_columns(self, col_index, removed_cols):
# we can only have 1 Y Err for Y, so iterating and removing's iterator invalidation is not an
# issue as the code will exit immediately after the removal
for col in self.as_y_err:
if col.related_y_column == col_index:
self.as_y_err.remove(col)
if col not in removed_cols:
removed_cols.append(col)
break
def _make_labels(self, list, label):
......@@ -85,8 +95,9 @@ class MarkedColumns:
extra_labels = []
extra_labels.extend(self._make_labels(self.as_x, self.X_LABEL))
extra_labels.extend(self._make_labels(self.as_y, self.Y_LABEL))
err_labels = [(err_col.column, self.Y_ERR_LABEL.format(err_col.label_index),) for index, err_col in
enumerate(self.as_y_err)]
err_labels = [(err_col.column, self.Y_ERR_LABEL.format(self.as_y.index(err_col.related_y_column)),) for
index, err_col in
enumerate(self.as_y_err) if self.as_y.count(err_col.related_y_column)>0]
extra_labels.extend(err_labels)
return extra_labels
......
......@@ -69,16 +69,9 @@ class TableWorkspaceDisplayModel:
elif plot_type == TableWorkspaceColumnTypeMapping.Y:
self.marked_columns.add_y(col)
elif plot_type == TableWorkspaceColumnTypeMapping.YERR:
# mark YErrs only if there are any columns that have been marked as Y
# if there are none then do not mark anything as YErr
if len(self.marked_columns.as_y) > len(self.marked_columns.as_y_err):
# Assume all the YErrs are associated with the first available (no other YErr has it) Y column.
# There isn't a way to know the correct Y column, as that information is not stored
# in the table workspace - the original table workspace does not associate Y errors
# columns with specific Y columns
err_for_column = self.marked_columns.as_y[len(self.marked_columns.as_y_err)]
label = str(len(self.marked_columns.as_y_err))
self.marked_columns.add_y_err(ErrorColumn(col, err_for_column, label))
err_for_column = self.ws.getLinkedYCol(col)
if err_for_column >= 0:
self.marked_columns.add_y_err(ErrorColumn(col, err_for_column))
def _get_v3d_from_str(self, string):
if '[' in string and ']' in string:
......@@ -147,5 +140,5 @@ class TableWorkspaceDisplayModel:
SortTableWorkspace(InputWorkspace=self.ws, OutputWorkspace=self.ws, Columns=column_name,
Ascending=sort_ascending)
def set_column_type(self, col, type):
self.ws.setPlotType(col, type)
def set_column_type(self, col, type, linked_col_index=-1):
self.ws.setPlotType(col, type, linked_col_index)
......@@ -237,7 +237,7 @@ class TableWorkspaceDisplay(ObservingPresenter, DataCopier):
def action_set_as_y(self):
self._action_set_as(self.model.marked_columns.add_y, 2)
def action_set_as_y_err(self, related_y_column, label_index):
def action_set_as_y_err(self, related_y_column):
"""
:param related_y_column: The real index of the column for which the error is being marked
......@@ -250,13 +250,18 @@ class TableWorkspaceDisplay(ObservingPresenter, DataCopier):
return
try:
err_column = ErrorColumn(selected_column, related_y_column, label_index)
err_column = ErrorColumn(selected_column, related_y_column)
except ValueError as e:
self.view.show_warning(str(e))
return
self.model.marked_columns.add_y_err(err_column)
self.model.set_column_type(selected_column, 5)
removed_items = self.model.marked_columns.add_y_err(err_column)
# if a column other than the one the user has just picked as a y err column has been affected,
# reset it's type to None
for col in removed_items:
if col != selected_column:
self.model.set_column_type(int(col),0)
self.model.set_column_type(selected_column, 5, related_y_column)
self.update_column_headers()
def action_set_as_none(self):
......
......@@ -6,28 +6,32 @@ from mantidqt.widgets.workspacedisplay.table.error_column import ErrorColumn
class ErrorColumnTest(unittest.TestCase):
def test_correct_init(self):
ErrorColumn(0, 1, 0)
ErrorColumn(0, 1)
def test_raises_for_same_y_and_yerr(self):
self.assertRaises(ValueError, lambda: ErrorColumn(2, 2, 3))
self.assertRaises(ValueError, lambda: ErrorColumn(2, 2))
def test_eq_versus_ErrorColumn(self):
ec1 = ErrorColumn(0, 1, 0)
ec2 = ErrorColumn(0, 1, 0)
ec1 = ErrorColumn(0, 1)
ec2 = ErrorColumn(0, 1)
self.assertEqual(ec1, ec2)
ec1 = ErrorColumn(0, 3, 0)
ec2 = ErrorColumn(0, 1, 0)
ec1 = ErrorColumn(0, 3)
ec2 = ErrorColumn(0, 1)
self.assertEqual(ec1, ec2)
ec1 = ErrorColumn(2, 3, 0)
ec2 = ErrorColumn(0, 3, 0)
ec1 = ErrorColumn(2, 3)
ec2 = ErrorColumn(0, 3)
self.assertEqual(ec1, ec2)
def test_eq_versus_same_int(self):
ec = ErrorColumn(150, 1, 0)
ec = ErrorColumn(150, 1)
self.assertEqual(ec, 150)
def test_eq_unsupported_type(self):
ec = ErrorColumn(150, 1, 0)
ec = ErrorColumn(150, 1)
self.assertRaises(RuntimeError, lambda: ec == "awd")
def test_int(self):
ec = ErrorColumn(1, 2)
self.assertEqual(int(ec), 1)
......@@ -17,7 +17,7 @@ from mantidqt.widgets.workspacedisplay.table import StatusBarView
TABLEWORKSPACEDISPLAY_DICT = {"markedColumns": {"as_y": [2], "as_x": [1],
"as_y_err": [{"column": 3, "relatedY": 2, "labelIndex": 0}]},
"as_y_err": [{"column": 3, "relatedY": 2}]},
"workspace": "ws", "windowName": "ws"}
......@@ -52,9 +52,6 @@ class TableWorkspaceDisplayDecoderTest(unittest.TestCase):
self.assertEqual(
TABLEWORKSPACEDISPLAY_DICT["markedColumns"]["as_y_err"][0]["relatedY"],
view.presenter.model.marked_columns.as_y_err[0].related_y_column)
self.assertEqual(
TABLEWORKSPACEDISPLAY_DICT["markedColumns"]["as_y_err"][0]["labelIndex"],
view.presenter.model.marked_columns.as_y_err[0].label_index)
self.assertEqual(1, len(view.presenter.model.marked_columns.as_y_err))
......
......@@ -45,13 +45,13 @@ class MarkedColumnsTest(unittest.TestCase):
Test adding YErr columns that do not overlap in any way
"""
mc = MarkedColumns()
ec = ErrorColumn(2, 4, 0)
ec = ErrorColumn(2, 4)
mc.add_y_err(ec)
self.assertEqual(1, len(mc.as_y_err))
ec = ErrorColumn(3, 5, 0)
ec = ErrorColumn(3, 5)
mc.add_y_err(ec)
self.assertEqual(2, len(mc.as_y_err))
ec = ErrorColumn(1, 6, 0)
ec = ErrorColumn(1, 6)
mc.add_y_err(ec)
self.assertEqual(3, len(mc.as_y_err))
......@@ -75,14 +75,14 @@ class MarkedColumnsTest(unittest.TestCase):
def test_add_y_err_duplicate_column(self):
mc = MarkedColumns()
ec = ErrorColumn(2, 4, 0)
ec = ErrorColumn(2, 4)
mc.add_y_err(ec)
self.assertEqual(1, len(mc.as_y_err))
mc.add_y_err(ec)
self.assertEqual(1, len(mc.as_y_err))
ec2 = ErrorColumn(3, 5, 0)
ec2 = ErrorColumn(3, 5)
mc.add_y_err(ec2)
self.assertEqual(2, len(mc.as_y_err))
mc.add_y_err(ec2)
......@@ -123,7 +123,7 @@ class MarkedColumnsTest(unittest.TestCase):
-> The new YErr must replace the old one
"""
mc = MarkedColumns()
ec = ErrorColumn(column=2, related_y_column=4, label_index=0)
ec = ErrorColumn(column=2, related_y_column=4)
mc.add_y_err(ec)
self.assertEqual(1, len(mc.as_y_err))
self.assertEqual(2, mc.as_y_err[0].column)
......@@ -131,7 +131,7 @@ class MarkedColumnsTest(unittest.TestCase):
# different source column but contains error for the same column
# adding this one should replace the first one
ec2 = ErrorColumn(column=2, related_y_column=5, label_index=0)
ec2 = ErrorColumn(column=2, related_y_column=5)
mc.add_y_err(ec2)
self.assertEqual(1, len(mc.as_y_err))
self.assertEqual(2, mc.as_y_err[0].column)
......@@ -143,7 +143,7 @@ class MarkedColumnsTest(unittest.TestCase):
-> The new YErr must replace the old one
"""
mc = MarkedColumns()
ec = ErrorColumn(column=2, related_y_column=4, label_index=0)
ec = ErrorColumn(column=2, related_y_column=4)
mc.add_y_err(ec)
self.assertEqual(1, len(mc.as_y_err))
self.assertEqual(2, mc.as_y_err[0].column)
......@@ -151,7 +151,7 @@ class MarkedColumnsTest(unittest.TestCase):
# different source column but contains error for the same column
# adding this one should replace the first one
ec2 = ErrorColumn(column=3, related_y_column=4, label_index=0)
ec2 = ErrorColumn(column=3, related_y_column=4)
mc.add_y_err(ec2)
self.assertEqual(1, len(mc.as_y_err))
self.assertEqual(3, mc.as_y_err[0].column)
......@@ -164,7 +164,7 @@ class MarkedColumnsTest(unittest.TestCase):
"""
mc = MarkedColumns()
mc.add_y(4)
ec = ErrorColumn(column=2, related_y_column=4, label_index=0)
ec = ErrorColumn(column=2, related_y_column=4)
mc.add_y_err(ec)
# check that we have both a Y col and an associated YErr
......@@ -177,6 +177,11 @@ class MarkedColumnsTest(unittest.TestCase):
self.assertEqual(0, len(mc.as_y))
self.assertEqual(0, len(mc.as_y_err))
# check setting the column back to Y does not automatically reinstate the error column
mc.add_y(4)
self.assertEqual(1, len(mc.as_y))
self.assertEqual(0, len(mc.as_y_err))
def test_changing_y_to_none_removes_associated_yerr_columns(self):
"""
Test to check if a first column is marked as Y, a second column YErr is associated with it, but then
......@@ -184,7 +189,7 @@ class MarkedColumnsTest(unittest.TestCase):
"""
mc = MarkedColumns()
mc.add_y(4)
ec = ErrorColumn(column=2, related_y_column=4, label_index=0)
ec = ErrorColumn(column=2, related_y_column=4)
mc.add_y_err(ec)
# check that we have both a Y col and an associated YErr
......@@ -197,11 +202,16 @@ class MarkedColumnsTest(unittest.TestCase):
self.assertEqual(0, len(mc.as_y))
self.assertEqual(0, len(mc.as_y_err))
# check adding the Y column back in does not automatically reinstate the error column
mc.add_y(4)
self.assertEqual(1, len(mc.as_y))
self.assertEqual(0, len(mc.as_y_err))
def test_remove_column(self):
mc = MarkedColumns()
mc.add_y(4)
mc.add_x(3)
ec = ErrorColumn(column=2, related_y_column=6, label_index=0)
ec = ErrorColumn(column=2, related_y_column=6)
mc.add_y_err(ec)
self.assertEqual(1, len(mc.as_x))
......@@ -249,12 +259,12 @@ class MarkedColumnsTest(unittest.TestCase):
mc.add_y(2)
# change one of the columns to YErr
mc.add_y_err(ErrorColumn(1, 0, 0))
mc.add_y_err(ErrorColumn(1, 0))
expected = [(0, '[Y0]'), (2, '[Y1]'), (1, '[Y0_YErr]')]
self.assertEqual(expected, mc.build_labels())
# change the last Y column to YErr
mc.add_y_err(ErrorColumn(2, 0, 0))
mc.add_y_err(ErrorColumn(2, 0))
expected = [(0, '[Y0]'), (2, '[Y0_YErr]')]
self.assertEqual(expected, mc.build_labels())
......@@ -265,13 +275,13 @@ class MarkedColumnsTest(unittest.TestCase):
mc.add_y(2)
# change one of the columns to YErr
mc.add_y_err(ErrorColumn(0, 1, 1))
mc.add_y_err(ErrorColumn(0, 1))
# note: the first column is being set -> this decreases the label index of all columns to its right by 1
expected = [(1, '[Y0]'), (2, '[Y1]'), (0, '[Y0_YErr]')]
self.assertEqual(expected, mc.build_labels())
# change the last Y column to YErr
mc.add_y_err(ErrorColumn(2, 1, 0))
mc.add_y_err(ErrorColumn(2, 1))
expected = [(1, '[Y0]'), (2, '[Y0_YErr]')]
self.assertEqual(expected, mc.build_labels())
......@@ -282,7 +292,7 @@ class MarkedColumnsTest(unittest.TestCase):
mc.add_y(2)
mc.add_y(3)
mc.add_y_err(ErrorColumn(1, 0, 0))
mc.add_y_err(ErrorColumn(1, 0))
expected = [(0, '[Y0]'), (2, '[Y1]'), (3, '[Y2]'), (1, '[Y0_YErr]')]
self.assertEqual(expected, mc.build_labels())
......@@ -291,7 +301,16 @@ class MarkedColumnsTest(unittest.TestCase):
self.assertEqual(expected, mc.build_labels())
expected = [(1, '[X0]'), (2, '[Y0]'), (3, '[Y1]'), (0, '[Y1_YErr]')]
mc.add_y_err(ErrorColumn(0, 3, 2))
mc.add_y_err(ErrorColumn(0, 3))