Unverified Commit 960f7dcd authored by Gagik Vardanyan's avatar Gagik Vardanyan Committed by GitHub
Browse files

Merge pull request #30350 from mantidproject/30315_fix_colourfill_project_save

Fix error when saving plots with normalisation
parents fdad2bce 4b17a287
......@@ -157,6 +157,10 @@ def _plot_impl(axes, workspace, args, kwargs):
kwargs['drawstyle'] = 'steps-post'
else:
normalize_by_bin_width, kwargs = get_normalize_by_bin_width(workspace, axes, **kwargs)
# the get... function returns kwargs without 'normalize_by_bin_width', but it is needed in _get_data_for_plot to
# avoid reverting to the default norm setting, in the event that this function is being called as part of a plot
# restoration
kwargs['normalize_by_bin_width'] = normalize_by_bin_width
x, y, _, _, indices, axis, kwargs = _get_data_for_plot(axes, workspace, kwargs)
if kwargs.pop('update_axes_labels', True):
_setLabels1D(axes,
......
......@@ -657,7 +657,7 @@ class MantidAxes(Axes):
with autoscale_on_update(self, autoscale_on):
artist = self.track_workspace_artist(workspace,
axesfunctions.plot(self, normalize_by_bin_width = is_normalized,
axesfunctions.plot(self, normalize_by_bin_width=is_normalized,
*args, **kwargs),
_data_update, spec_num, is_normalized,
MantidAxes.is_axis_of_type(MantidAxType.SPECTRUM, kwargs),
......
......@@ -209,12 +209,14 @@ def use_imshow(ws):
@manage_workspace_names
def pcolormesh(workspaces, fig=None):
def pcolormesh(workspaces, fig=None, normalize_by_bin_width=None):
"""
Create a figure containing pcolor subplots
:param workspaces: A list of workspace handles
:param fig: An optional figure to contain the new plots. Its current contents will be cleared
:param normalize_by_bin_width: Optional and only to be used in the event that the function is being called as part
of a plot restore
:returns: The figure containing the plots
"""
# check inputs
......@@ -232,7 +234,7 @@ def pcolormesh(workspaces, fig=None):
ax = axes[row_idx][col_idx]
if subplot_idx < workspaces_len:
ws = workspaces[subplot_idx]
pcm = pcolormesh_on_axis(ax, ws)
pcm = pcolormesh_on_axis(ax, ws, normalize_by_bin_width)
plots.append(pcm)
if col_idx < ncols - 1:
col_idx += 1
......@@ -271,11 +273,12 @@ def pcolormesh(workspaces, fig=None):
return fig
def pcolormesh_on_axis(ax, ws):
def pcolormesh_on_axis(ax, ws, normalize_by_bin_width=None):
"""
Plot a pcolormesh plot of the given workspace on the given axis
:param ax: A matplotlib axes instance
:param ws: A mantid workspace instance
:param normalize_by_bin_width: Optional keyword argument to pass to imshow in the event of a plot restoration
:return:
"""
ax.clear()
......@@ -283,7 +286,10 @@ def pcolormesh_on_axis(ax, ws):
scale = _get_colorbar_scale()
if use_imshow(ws):
pcm = ax.imshow(ws, cmap=ConfigService.getString("plots.images.Colormap"), aspect='auto', origin='lower',
norm=scale())
norm=scale(), normalize_by_bin_width=normalize_by_bin_width)
# remove normalize_by_bin_width from cargs if present so that this can be toggled in future
for cargs in pcm.axes.creation_args:
cargs.pop('normalize_by_bin_width')
else:
pcm = ax.pcolormesh(ws, cmap=ConfigService.getString("plots.images.Colormap"), norm=scale())
......
......@@ -31,7 +31,6 @@ class PlotsLoader(object):
def load_plots(self, plots_list):
if plots_list is None:
return
for plot_ in plots_list:
try:
self.make_fig(plot_)
......@@ -41,6 +40,10 @@ class PlotsLoader(object):
raise KeyboardInterrupt(str(e))
logger.warning("A plot was unable to be loaded from the save file. Error: " + str(e))
def restore_normalise_obj_from_dict(self, norm_dict):
norm = matplotlib.colors.Normalize(norm_dict['vmin'], norm_dict['vmax'], norm_dict['clip'])
return norm
def make_fig(self, plot_dict, create_plot=True):
"""
This method currently only considers single matplotlib.axes.Axes based figures as that is the most common case
......@@ -56,6 +59,10 @@ class PlotsLoader(object):
"The original plot title was: {}".format(plot_dict["label"]))
return
for sublist in creation_args:
for cargs_dict in sublist:
if 'norm' in cargs_dict and type(cargs_dict['norm']) is dict:
cargs_dict['norm'] = self.restore_normalise_obj_from_dict(cargs_dict['norm'])
fig, axes_matrix, _, _ = create_subplots(len(creation_args))
axes_list = axes_matrix.flatten().tolist()
for ax, cargs_list in zip(axes_list, creation_args):
......@@ -67,7 +74,8 @@ class PlotsLoader(object):
self.workspace_plot_func(workspace, ax, ax.figure, cargs)
elif "function" in cargs:
self.plot_func(ax, cargs)
for cargs in creation_args_copy:
cargs.pop('normalize_by_bin_width', None)
ax.creation_args = creation_args_copy
# Update the fig
......@@ -115,7 +123,7 @@ class PlotsLoader(object):
func = function_dict[function_to_call]
# Plotting is done via an Axes object unless a colorbar needs to be added
if function_to_call in ["imshow", "pcolormesh"]:
func([workspace], fig)
func([workspace], fig, normalize_by_bin_width=creation_arg['normalize_by_bin_width'])
self.color_bar_remade = True
else:
func(workspace, **creation_arg)
......
......@@ -6,6 +6,7 @@
# SPDX - License - Identifier: GPL - 3.0 +
# This file is part of the mantidqt package
#
from copy import deepcopy
import matplotlib.axis
from matplotlib import ticker
from matplotlib.image import AxesImage
......@@ -13,13 +14,7 @@ from matplotlib.image import AxesImage
from mantid import logger
from mantid.plots.legend import LegendProperties
try:
from matplotlib.colors import to_hex
except ImportError:
from matplotlib.colors import colorConverter, rgb2hex
def to_hex(color):
return rgb2hex(colorConverter.to_rgb(color))
from matplotlib.colors import to_hex, Normalize
class PlotsSaver(object):
......@@ -27,7 +22,7 @@ class PlotsSaver(object):
self.figure_creation_args = {}
def save_plots(self, plot_dict, is_project_recovery=False):
# if arguement is none return empty dictionary
# if argument is none return empty dictionary
if plot_dict is None:
return []
......@@ -48,18 +43,37 @@ class PlotsSaver(object):
logger.debug(error_string)
return plot_list
@staticmethod
def _convert_normalise_obj_to_dict(norm):
norm_dict = {'clip': norm.clip, 'vmin': norm.vmin, 'vmax': norm.vmax}
return norm_dict
@staticmethod
def _add_normalisation_kwargs(cargs_list, axes_list):
for ax_cargs, ax_dict in zip(cargs_list[0], axes_list):
is_norm = ax_dict.pop("_is_norm")
ax_cargs['normalize_by_bin_width'] = is_norm
def get_dict_from_fig(self, fig):
axes_list = []
create_list = []
for ax in fig.axes:
try:
create_list.append(ax.creation_args)
self.figure_creation_args = ax.creation_args
creation_args = deepcopy(ax.creation_args)
# convert the normalise object (if present) into a dict so that it can be json serialised
for args_dict in creation_args:
if 'norm' in args_dict.keys() and type(args_dict['norm']) is Normalize:
norm_dict = self._convert_normalise_obj_to_dict(args_dict['norm'])
args_dict['norm'] = norm_dict
create_list.append(creation_args)
self.figure_creation_args = creation_args
except AttributeError:
logger.debug("Axis had an axis without creation_args - Common with a Colorfill plot")
continue
axes_list.append(self.get_dict_for_axes(ax))
if create_list and axes_list:
self._add_normalisation_kwargs(create_list, axes_list)
fig_dict = {"creationArguments": create_list,
"axes": axes_list,
"label": fig._label,
......@@ -123,6 +137,11 @@ class PlotsSaver(object):
legend_dict["exists"] = False
ax_dict["legend"] = legend_dict
# add value to determine if ax has been normalised
ws_artists = [art for art in ax.tracked_workspaces.values()]
is_norm = all(art[0].is_normalized for art in ws_artists)
ax_dict["_is_norm"] = is_norm
return ax_dict
def get_dict_from_axes_properties(self, ax):
......
......@@ -135,7 +135,8 @@ class PlotsSaverTest(unittest.TestCase):
self.fig.axes[0].creation_args = [{u"specNum": 2, "function": "plot"}]
return_value = self.plot_saver.get_dict_from_fig(self.fig)
self.loader_plot_dict[u'creationArguments'] = [[{u"specNum": 2, "function": "plot"}]]
self.loader_plot_dict[u'creationArguments'] = [[{u"specNum": 2, "function": "plot", u"normalize_by_bin_width":
True}]]
self.maxDiff = None
self.assertDictEqual(return_value, self.loader_plot_dict)
......@@ -144,6 +145,7 @@ class PlotsSaverTest(unittest.TestCase):
self.plot_saver.figure_creation_args = [{"function": "plot"}]
return_value = self.plot_saver.get_dict_for_axes(self.fig.axes[0])
self.loader_plot_dict["axes"][0]['_is_norm'] = True
expected_value = self.loader_plot_dict["axes"][0]
self.maxDiff = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment