Skip to content
Snippets Groups Projects
test_functions.py 14 KiB
Newer Older
# Mantid Repository : https://github.com/mantidproject/mantid
# Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI,
#   NScD Oak Ridge National Laboratory, European Spallation Source,
#   Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
# SPDX - License - Identifier: GPL - 3.0 +
#  This file is part of the mantid workbench.
from __future__ import absolute_import

# std imports
from unittest import TestCase, main

# third party imports
import matplotlib
matplotlib.use('AGG')  # noqa
import matplotlib.pyplot as plt
import numpy as np
# register mantid projection
import mantid.plots  # noqa
from mantid.api import AnalysisDataService, WorkspaceFactory
from mantid.simpleapi import CreateWorkspace
from mantid.kernel import config
from mantid.plots import MantidAxes
from mantid.py3compat import mock
from mantidqt.dialogs.spectraselectordialog import SpectraSelection
from mantidqt.plotting.functions import (can_overplot, current_figure_or_none, figure_title,
                                         manage_workspace_names, plot, plot_from_names,
                                         pcolormesh_from_names)


# Avoid importing the whole of mantid for a single mock of the workspace class
class FakeWorkspace(object):
    """Tiny stand-in for a workspace that only knows its own name.

    Used by the ``figure_title`` tests, which only ever call ``name()``.
    """

    def __init__(self, name):
        # Remembered verbatim and handed back by name().
        self._name = name

    def name(self):
        """Return exactly the value passed to the constructor."""
        stored = self._name
        return stored


@manage_workspace_names
def workspace_names_dummy_func(workspaces):
    """Hand back whatever ``manage_workspace_names`` passed in.

    Tests call this to observe what the decorator turned its argument into.
    """
    converted = workspaces
    return converted


class FunctionsTest(TestCase):
    """Tests for the helpers imported from ``mantidqt.plotting.functions``.

    A single ``Workspace2D`` (NVectors=2, XLength=YLength=5) is created lazily
    on the first ``setUp`` and cached on the class as ``_test_ws``, so every
    test shares it. Tests register it in the ADS under whatever name they need.
    ``tearDown`` clears the ADS and closes all matplotlib figures.
    """

    # Shared fixture workspace; populated by the first setUp call.
    _test_ws = None

    def setUp(self):
        if self._test_ws is None:
            # Store on the class (not the instance) so it is only built once.
            self.__class__._test_ws = WorkspaceFactory.Instance().create(
                "Workspace2D", NVectors=2, YLength=5, XLength=5)

    def tearDown(self):
        AnalysisDataService.Instance().clear()
        plt.close('all')

    def test_can_overplot_returns_false_with_no_active_plots(self):
        self.assertFalse(can_overplot()[0])

    def test_can_overplot_returns_true_for_active_line_plot(self):
        plt.plot([1, 2])
        self.assertTrue(can_overplot()[0])

    def test_can_overplot_returns_false_for_active_patch_plot(self):
        plt.pcolormesh(np.arange(9.).reshape(3, 3))
        allowed, msg = can_overplot()
        self.assertFalse(allowed)
        self.assertGreater(len(msg), 0)

    def test_current_figure_or_none_returns_none_if_no_figures_exist(self):
        self.assertIsNone(current_figure_or_none())

    def test_figure_title_with_single_string(self):
        self.assertEqual("test-1", figure_title("test", 1))

    def test_figure_title_with_list_of_strings(self):
        self.assertEqual("first-10", figure_title(["first", "second"], 10))

    def test_figure_title_with_single_workspace(self):
        self.assertEqual("fake-5", figure_title(FakeWorkspace("fake"), 5))

    def test_figure_title_with_workspace_list(self):
        self.assertEqual("fake-10", figure_title((FakeWorkspace("fake"),
                                                  FakeWorkspace("nextfake")), 10))

    def test_figure_title_with_empty_list_raises_assertion(self):
        with self.assertRaises(AssertionError):
            figure_title([], 5)

    def test_that_plot_can_accept_workspace_names(self):
        ws_name1 = "some_workspace"
        AnalysisDataService.Instance().addOrReplace(ws_name1, self._test_ws)

        try:
            result_workspaces = workspace_names_dummy_func([ws_name1])
        except ValueError:
            self.fail("Passing workspace names should not raise a value error.")
        else:
            # The list of workspace names we pass in should have been converted
            # to a list of workspaces
            self.assertNotEqual(result_workspaces, [ws_name1])

    # Stacked mock.patch decorators inject mocks bottom-up: the patch nearest
    # the function ('plot') arrives as the first mock argument.
    @mock.patch('mantidqt.plotting.functions.get_spectra_selection')
    @mock.patch('mantidqt.plotting.functions.plot')
    def test_plot_from_names_calls_plot(self, plot_mock, get_spectra_selection_mock):
        ws_name = 'test_plot_from_names_calls_plot-1'
        AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws)
        selection = SpectraSelection([self._test_ws])
        selection.wksp_indices = [0]
        get_spectra_selection_mock.return_value = selection
        plot_from_names([ws_name], errors=False, overplot=False)
        self.assertEqual(1, plot_mock.call_count)

    @mock.patch('mantidqt.plotting.functions.get_spectra_selection')
    def test_plot_from_names_produces_single_line_plot_for_valid_name(self, get_spectra_selection_mock):
        self._do_plot_from_names_test(get_spectra_selection_mock, expected_labels=["spec 1"], wksp_indices=[0],
                                      errors=False, overplot=False)

    @mock.patch('mantidqt.plotting.functions.get_spectra_selection')
    def test_plot_from_names_produces_single_error_plot_for_valid_name(self, get_spectra_selection_mock):
        fig = self._do_plot_from_names_test(get_spectra_selection_mock,
                                            # matplotlib does not set labels on the lines for error plots
                                            expected_labels=[],
                                            wksp_indices=[0], errors=True, overplot=False)
        self.assertEqual(1, len(fig.gca().containers))

    @mock.patch('mantidqt.plotting.functions.get_spectra_selection')
    def test_plot_from_names_produces_overplot_for_valid_name(self, get_spectra_selection_mock):
        # make first plot
        plot([self._test_ws], wksp_indices=[0])
        self._do_plot_from_names_test(get_spectra_selection_mock, expected_labels=["spec 1", "spec 2"],
                                      wksp_indices=[1], errors=False, overplot=True)

    @mock.patch('mantidqt.plotting.functions.get_spectra_selection')
    def test_plot_from_names_within_existing_figure(self, get_spectra_selection_mock):
        # make existing plot
        fig = plot([self._test_ws], wksp_indices=[0])
        # Passing target_fig makes the helper also assert that the existing
        # figure, not a new one, is returned.
        self._do_plot_from_names_test(get_spectra_selection_mock, expected_labels=["spec 1", "spec 2"],
                                      wksp_indices=[1], errors=False, overplot=True,
                                      target_fig=fig)

    @mock.patch('mantidqt.plotting.functions.pcolormesh')
    def test_pcolormesh_from_names_calls_pcolormesh(self, pcolormesh_mock):
        ws_name = 'test_pcolormesh_from_names_calls_pcolormesh-1'
        AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws)
        pcolormesh_from_names([ws_name])
        self.assertEqual(1, pcolormesh_mock.call_count)

    def test_scale_is_correct_on_pcolourmesh_of_ragged_workspace(self):
        ws = CreateWorkspace(DataX=[1, 2, 3, 4, 2, 4, 6, 8], DataY=[2] * 8, NSpec=2)
        fig = pcolormesh_from_names([ws])
        self.assertEqual((1.8, 2.2), fig.axes[0].images[0].get_clim())

    def test_pcolormesh_from_names(self):
        ws_name = 'test_pcolormesh_from_names-1'
        AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws)
        fig = pcolormesh_from_names([ws_name])
        self.assertEqual(1, len(fig.gca().images))

    def test_pcolormesh_from_names_using_existing_figure(self):
        ws_name = 'test_pcolormesh_from_names-1'
        AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws)
        target_fig = plt.figure()
        fig = pcolormesh_from_names([ws_name], fig=target_fig)
        self.assertEqual(fig, target_fig)
        self.assertEqual(1, len(fig.gca().images))

    def test_workspace_can_be_plotted_on_top_of_scripted_plots(self):
        fig = plt.figure()
        plt.plot([0, 1], [0, 1])
        ws = self._test_ws
        plot([ws], wksp_indices=[1], fig=fig, overplot=True)
        ax = plt.gca()
        self.assertEqual(len(ax.lines), 2)

    def test_title_preserved_when_workspace_plotted_on_scripted_plot(self):
        fig = plt.figure()
        plt.plot([0, 1], [0, 1])
        plt.title("My Title")
        ws = self._test_ws
        plot([ws], wksp_indices=[1], fig=fig, overplot=True)
        ax = plt.gca()
        self.assertEqual("My Title", ax.get_title())

    def test_different_line_colors_when_plotting_over_scripted_fig(self):
        fig = plt.figure()
        plt.plot([0, 1], [0, 1])
        ws = self._test_ws
        plot([ws], wksp_indices=[1], fig=fig, overplot=True)
        ax = plt.gca()
        line_colors = [line.get_color() for line in ax.get_lines()]
        self.assertNotEqual(line_colors[0], line_colors[1])

    def test_workspace_tracked_when_plotting_over_scripted_fig(self):
        fig = plt.figure()
        plt.plot([0, 1], [0, 1])
        ws = self._test_ws
        plot([ws], wksp_indices=[1], fig=fig, overplot=True)
        ax = plt.gca()
        self.assertIn(ws.name(), ax.tracked_workspaces)

    def test_from_mpl_axes_success_with_default_args(self):
        plt.figure()
        plt.plot([0, 1], [0, 1])
        plt.plot([0, 2], [0, 2])
        ax = plt.gca()
        mantid_ax = MantidAxes.from_mpl_axes(ax)
        self.assertEqual(len(mantid_ax.lines), 2)
        self.assertIsInstance(mantid_ax, MantidAxes)

    def test_that_plot_spectrum_has_same_y_label_with_and_without_errorbars(self):
        auto_dist = config['graph1d.autodistribution']
        try:
            config['graph1d.autodistribution'] = 'Off'
            self._compare_errorbar_labels_and_title()
        finally:
            # Restore the user's config so later tests are unaffected.
            config['graph1d.autodistribution'] = auto_dist

    def test_that_plot_spectrum_has_same_y_label_with_and_without_errorbars_normalize_by_bin_width(self):
        auto_dist = config['graph1d.autodistribution']
        try:
            config['graph1d.autodistribution'] = 'On'
            self._compare_errorbar_labels_and_title()
        finally:
            config['graph1d.autodistribution'] = auto_dist

    def test_setting_waterfall_to_true_makes_waterfall_plot(self):
        fig = plt.figure()
        ws = self._test_ws
        plot([ws], wksp_indices=[0, 1], fig=fig, waterfall=True)
        ax = plt.gca()

        self.assertTrue(ax.is_waterfall())

    def test_cannot_make_waterfall_plot_with_one_line(self):
        fig = plt.figure()
        ws = self._test_ws
        plot([ws], wksp_indices=[1], fig=fig, waterfall=True)
        ax = plt.gca()

        self.assertFalse(ax.is_waterfall())

    def test_overplotting_onto_waterfall_plot_maintains_waterfall(self):
        fig = plt.figure()
        ws = self._test_ws
        plot([ws], wksp_indices=[0, 1], fig=fig, waterfall=True)
        # Overplot one of the same lines.
        plot([ws], wksp_indices=[0], fig=fig, overplot=True)
        ax = plt.gca()

        # Check that the lines which would be the same in a non-waterfall plot are different.
        self.assertNotEqual(ax.get_lines()[0].get_xdata()[0], ax.get_lines()[2].get_xdata()[0])
        self.assertNotEqual(ax.get_lines()[0].get_ydata()[0], ax.get_lines()[2].get_ydata()[0])

    def test_overplotting_onto_waterfall_plot_with_filled_areas_adds_another_filled_area(self):
        fig = plt.figure()
        ws = self._test_ws
        plot([ws], wksp_indices=[0, 1], fig=fig, waterfall=True)
        ax = plt.gca()
        ax.set_waterfall_fill(True)
        plot([ws], wksp_indices=[0], fig=fig, overplot=True)

        fills = [collection for collection in ax.collections
                 if isinstance(collection, matplotlib.collections.PolyCollection)]

        self.assertEqual(len(fills), 3)

    # ------------- Failure tests -------------

    def test_plot_from_names_with_non_plottable_workspaces_returns_None(self):
        table = WorkspaceFactory.Instance().createTable()
        table_name = 'test_plot_from_names_with_non_plottable_workspaces_returns_None'
        AnalysisDataService.Instance().addOrReplace(table_name, table)
        result = plot_from_names([table_name], errors=False, overplot=False)
        self.assertIsNone(result)

    def test_pcolormesh_from_names_with_non_plottable_workspaces_returns_None(self):
        table = WorkspaceFactory.Instance().createTable()
        table_name = 'test_pcolormesh_from_names_with_non_plottable_workspaces_returns_None'
        AnalysisDataService.Instance().addOrReplace(table_name, table)
        result = pcolormesh_from_names([table_name])
        self.assertIsNone(result)

    def test_that_manage_workspace_names_raises_on_mix_of_workspaces_and_names(self):
        ws = ["some_workspace", self._test_ws]
        AnalysisDataService.Instance().addOrReplace("some_workspace", self._test_ws)
        # The call must happen inside the assertRaises context. The previous
        # form, assertRaises(TypeError, workspace_names_dummy_func(ws)),
        # evaluated the call eagerly and then tried to *call its return value*.
        # Calling a non-callable list raises TypeError, so it passed without
        # testing the decorator.
        # NOTE(review): confirm manage_workspace_names really raises TypeError
        # for mixed input; if it passes such lists through, this test now fails.
        with self.assertRaises(TypeError):
            workspace_names_dummy_func(ws)

    # ------------- Private -------------------
    def _do_plot_from_names_test(self, get_spectra_selection_mock, expected_labels,
                                 wksp_indices, errors, overplot, target_fig=None):
        """Run plot_from_names on the shared workspace and check the legend.

        The mocked spectra-selection dialog is made to return ``wksp_indices``.
        If ``target_fig`` is given, the returned figure must be that figure.
        Each entry of ``expected_labels`` must appear in the matching legend
        line's label (``None`` entries are skipped).

        :return: the figure produced by plot_from_names
        """
        ws_name = 'test_plot_from_names-1'
        AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws)

        selection = SpectraSelection([self._test_ws])
        selection.wksp_indices = wksp_indices
        get_spectra_selection_mock.return_value = selection
        fig = plot_from_names([ws_name], errors, overplot, target_fig)
        if target_fig is not None:
            self.assertEqual(target_fig, fig)

        plotted_lines = fig.gca().get_legend().get_lines()
        self.assertEqual(len(expected_labels), len(plotted_lines))
        for label_part, line in zip(expected_labels, plotted_lines):
            if label_part is not None:
                self.assertTrue(label_part in line.get_label(),
                                msg="Label fragment '{}' not found in line label".format(label_part))
        return fig

    def _compare_errorbar_labels_and_title(self):
        """Check plain and error-bar plots of spectrum index 1 share axis labels and title.

        Runs once with the distribution flag set and once with it cleared.
        NOTE: this mutates the shared ``_test_ws`` (Y unit label, X unit and
        distribution flag), and those changes persist for later tests.
        """
        ws = self._test_ws
        ws.setYUnitLabel("MyLabel")
        ws.getAxis(0).setUnit("TOF")
        for distribution_ws in [True, False]:
            ws.setDistribution(distribution_ws)
            ax = plot([ws], wksp_indices=[1]).get_axes()[0]
            err_ax = plot([ws], wksp_indices=[1], errors=True).get_axes()[0]
            # Compare y-labels
            self.assertEqual(ax.get_ylabel(), err_ax.get_ylabel())
            # Compare x-labels
            self.assertEqual(ax.get_xlabel(), err_ax.get_xlabel())
            # Compare title
            self.assertEqual(ax.get_title(), err_ax.get_title())

if __name__ == '__main__':