diff --git a/Framework/API/inc/MantidAPI/AnalysisDataService.h b/Framework/API/inc/MantidAPI/AnalysisDataService.h index f428b8e7bc631d93e16bc7b2c49684219c43ebbe..b50402509a2393a9c2996f6d6ed36a6e7b8f5295 100644 --- a/Framework/API/inc/MantidAPI/AnalysisDataService.h +++ b/Framework/API/inc/MantidAPI/AnalysisDataService.h @@ -23,9 +23,8 @@ class WorkspaceGroup; /** The Analysis data service stores instances of the Workspace objects and anything that derives from template class - DynamicFactory<Mantid::Kernel::IAlgorithm>. - This is the primary data service that - the users will interact with either through writing scripts or directly + DynamicFactory<Mantid::Kernel::IAlgorithm>. This is the primary data service + that the users will interact with either through writing scripts or directly through the API. It is implemented as a singleton class. This is the manager/owner of Workspace* when registered. @@ -35,8 +34,8 @@ class WorkspaceGroup; @author L C Chapon, ISIS, Rutherford Appleton Laboratory Modified to inherit from DataService - Copyright © 2007-9 ISIS Rutherford Appleton Laboratory, NScD Oak Ridge - National Laboratory & European Spallation Source + Copyright © 2007-9 ISIS Rutherford Appleton Laboratory, NScD Oak + Ridge National Laboratory & European Spallation Source This file is part of Mantid. @@ -53,7 +52,8 @@ class WorkspaceGroup; You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. - File change history is stored at: <https://github.com/mantidproject/mantid>. + File change history is stored at: + <https://github.com/mantidproject/mantid>. Code Documentation is available at: <http://doxygen.mantidproject.org> */ class DLLExport AnalysisDataServiceImpl final @@ -78,8 +78,7 @@ public: }; /// UnGroupingWorkspace notification is sent from UnGroupWorkspace algorithm - /// before the WorkspaceGroup is removed from the - /// DataService + /// before the WorkspaceGroup is removed from the DataService class UnGroupingWorkspaceNotification : public DataServiceNotification { public: /// Constructor @@ -88,8 +87,8 @@ public: : DataServiceNotification(name, obj) {} }; - /// GroupWorkspaces notification is send when a group is updated by adding or - /// removing members. + /// GroupWorkspaces notification is send when a group is updated by adding + /// or removing members. /// Disable observing the ADS by a group /// (WorkspaceGroup::observeADSNotifications(false)) /// to prevent sending this notification. @@ -110,12 +109,12 @@ public: void setIllegalCharacterList(const std::string &); /// Is the given name a valid name for an object in the ADS const std::string isValid(const std::string &name) const; - /// Overridden add member to attach the name to the workspace when a workspace - /// object is added to the service + /// Overridden add member to attach the name to the workspace when a + /// workspace object is added to the service void add(const std::string &name, const boost::shared_ptr<API::Workspace> &workspace) override; - /// Overridden addOrReplace member to attach the name to the workspace when a - /// workspace object is added to the service + /// Overridden addOrReplace member to attach the name to the workspace when + /// a workspace object is added to the service void addOrReplace(const std::string &name, const boost::shared_ptr<API::Workspace> &workspace) override; @@ -128,26 +127,27 @@ public: /** Retrieve a workspace and cast it to the given WSTYPE * * @param name :: name of the workspace - * @tparam WSTYPE :: type of workspace to cast to. Should sub-class Workspace + * @tparam WSTYPE :: type of workspace to cast to. Should sub-class + * Workspace * @return a shared pointer of WSTYPE */ template <typename WSTYPE> boost::shared_ptr<WSTYPE> retrieveWS(const std::string &name) const { // Get as a bare workspace try { - boost::shared_ptr<Mantid::API::Workspace> workspace = - Kernel::DataService<API::Workspace>::retrieve(name); // Cast to the desired type and return that. - return boost::dynamic_pointer_cast<WSTYPE>(workspace); + return boost::dynamic_pointer_cast<WSTYPE>( + Kernel::DataService<API::Workspace>::retrieve(name)); } catch (Kernel::Exception::NotFoundError &) { - throw Kernel::Exception::NotFoundError( - "Unable to find workspace type with name '" + name + - "': data service ", - name); + throw; } } + std::vector<Workspace_sptr> + retrieveWorkspaces(const std::vector<std::string> &names, + bool unrollGroups = false) const; + /** @name Methods to work with workspace groups */ //@{ void sortGroupByName(const std::string &groupName); diff --git a/Framework/API/src/AnalysisDataService.cpp b/Framework/API/src/AnalysisDataService.cpp index 529ee2dbc80b8fd6abbd1b1bc5830610c72d9c99..6b0d242cedca100eef6adec679de26dd1ca58468 100644 --- a/Framework/API/src/AnalysisDataService.cpp +++ b/Framework/API/src/AnalysisDataService.cpp @@ -1,5 +1,6 @@ #include "MantidAPI/AnalysisDataService.h" #include "MantidAPI/WorkspaceGroup.h" +#include <iterator> #include <sstream> namespace Mantid { @@ -164,6 +165,46 @@ void AnalysisDataServiceImpl::remove(const std::string &name) { } } +/** + * @brief Given a list of names retrieve the corresponding workspace handles + * @param names A list of names of workspaces, if any does not exist then + * a Kernel::Exception::NotFoundError is thrown. + * @param unrollGroups If true flatten groups into the list of members. + * @return A vector of pointers to Workspaces + * @throws std::invalid_argument if no names are provided + * @throws Mantid::Kernel::Exception::NotFoundError if a workspace does not + * exist within the ADS + */ +std::vector<Workspace_sptr> AnalysisDataServiceImpl::retrieveWorkspaces( + const std::vector<std::string> &names, bool unrollGroups) const { + using WorkspacesVector = std::vector<Workspace_sptr>; + WorkspacesVector workspaces; + workspaces.reserve(names.size()); + std::transform( + std::begin(names), std::end(names), std::back_inserter(workspaces), + [this](const std::string &name) { return this->retrieve(name); }); + assert(names.size() == workspaces.size()); + if (unrollGroups) { + using IteratorDifference = + std::iterator_traits<WorkspacesVector::iterator>::difference_type; + for (size_t i = 0; i < workspaces.size(); ++i) { + if (auto group = + boost::dynamic_pointer_cast<WorkspaceGroup>(workspaces.at(i))) { + const auto groupLength(group->size()); + workspaces.erase(std::next(std::begin(workspaces), + static_cast<IteratorDifference>(i))); + for (size_t j = 0; j < groupLength; ++j) { + workspaces.insert(std::next(std::begin(workspaces), + static_cast<IteratorDifference>(i + j)), + group->getItem(j)); + } + i += groupLength; + } + } + } + return workspaces; +} + /** * Sort members by Workspace name. The group must be in the ADS. * @param groupName :: A group name. diff --git a/Framework/API/test/AnalysisDataServiceTest.h b/Framework/API/test/AnalysisDataServiceTest.h index e3e7dbcf66afc2758229448b0fc018420e2a1289..cbb2cbfce64e98b527d2e5372a7408560cefbf0e 100644 --- a/Framework/API/test/AnalysisDataServiceTest.h +++ b/Framework/API/test/AnalysisDataServiceTest.h @@ -5,6 +5,7 @@ #include "MantidAPI/AnalysisDataService.h" #include "MantidAPI/WorkspaceGroup.h" +#include <boost/make_shared.hpp> using namespace Mantid::Kernel; using namespace Mantid::API; @@ -75,6 +76,61 @@ public: TS_ASSERT_THROWS(ads.retrieve("z"), Exception::NotFoundError); } + void test_retrieveWorkspaces_with_empty_list_returns_empty_list() { + std::vector<Workspace_sptr> empty; + TS_ASSERT_EQUALS(empty, ads.retrieveWorkspaces({})); + } + + void test_retrieveWorkspaces_with_all_missing_items_throws_exception() { + TS_ASSERT_THROWS(ads.retrieveWorkspaces({"a"}), Exception::NotFoundError); + TS_ASSERT_THROWS(ads.retrieveWorkspaces({"a", "b"}), + Exception::NotFoundError); + } + + void test_retrieveWorkspaces_with_some_missing_items_throws_exception() { + const std::string name("test_some_missing_items"); + addToADS(name); + TS_ASSERT_THROWS(ads.retrieveWorkspaces({"a", "b"}), + Exception::NotFoundError); + ads.remove(name); + } + + void test_retrieveWorkspaces_with_all_items_present_and_no_group_unrolling() { + const std::vector<std::string> names{"test_all_items_present_1", + "test_all_items_present_2"}; + std::vector<Workspace_sptr> expected; + for (const auto &name : names) { + expected.push_back(addToADS(name)); + } + std::vector<Workspace_sptr> items; + TS_ASSERT_THROWS_NOTHING(items = ads.retrieveWorkspaces(names)); + TS_ASSERT_EQUALS(expected, expected); + + for (const auto &name : names) { + ads.remove(name); + } + } + + void test_retrieveWorkspaces_with_group_unrolling() { + const std::vector<std::string> names{"test_all_items_present_unroll_1", + "test_all_items_present_unroll_2"}; + std::vector<Workspace_sptr> expected; + expected.push_back(addToADS(names[0])); + const size_t nitems{4u}; + WorkspaceGroup_sptr groupWS{addGroupToADS(names[1], nitems)}; + for (auto i = 0u; i < nitems; ++i) { + expected.push_back(groupWS->getItem(i)); + } + std::vector<Workspace_sptr> items; + TS_ASSERT_THROWS_NOTHING(items = ads.retrieveWorkspaces(names, true)); + TS_ASSERT_EQUALS(expected.size(), items.size()); + TS_ASSERT_EQUALS(expected, items); + + for (const auto &name : names) { + ads.remove(name); + } + } + void test_Add_With_Name_That_Has_No_Special_Chars_Is_Accpeted() { const std::string name = "MySpace"; TS_ASSERT_THROWS_NOTHING(addToADS(name)); @@ -479,11 +535,13 @@ private: return space; } - /// Add a group with 2 simple workspaces to the ADS - WorkspaceGroup_sptr addGroupToADS(const std::string &name) { - WorkspaceGroup_sptr group(new WorkspaceGroup); - group->addWorkspace(MockWorkspace_sptr(new MockWorkspace)); - group->addWorkspace(MockWorkspace_sptr(new MockWorkspace)); + /// Add a group with N simple workspaces to the ADS + WorkspaceGroup_sptr addGroupToADS(const std::string &name, + const size_t nitems = 2) { + auto group(boost::make_shared<WorkspaceGroup>()); + for (auto i = 0u; i < nitems; ++i) { + group->addWorkspace(boost::make_shared<MockWorkspace>()); + } ads.add(name, group); return group; } diff --git a/Framework/PythonInterface/inc/MantidPythonInterface/kernel/Converters/ToPyList.h b/Framework/PythonInterface/inc/MantidPythonInterface/kernel/Converters/ToPyList.h new file mode 100644 index 0000000000000000000000000000000000000000..b5e0b0e8e7e010d4468debd54770ca25b3b3d8df --- /dev/null +++ b/Framework/PythonInterface/inc/MantidPythonInterface/kernel/Converters/ToPyList.h @@ -0,0 +1,57 @@ +#ifndef MANTID_PYTHONINTERFACE_TOPYLIST_H_ +#define MANTID_PYTHONINTERFACE_TOPYLIST_H_ +/** + Copyright © 2012 ISIS Rutherford Appleton Laboratory, NScD Oak Ridge + National Laboratory & European Spallation Source + + This file is part of Mantid. + + Mantid is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. + + Mantid is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + + File change history is stored at: <https://github.com/mantidproject/mantid> + Code Documentation is available at: <http://doxygen.mantidproject.org> + */ +#include <boost/python/list.hpp> +#include <vector> + +namespace Mantid { +namespace PythonInterface { +namespace Converters { +//----------------------------------------------------------------------- +// Converter implementation +//----------------------------------------------------------------------- +/** + * Converter that takes a std::vector and converts it into a python list. + * It is able to convert anything for which a converter is already registered + */ +template <typename ElementType> struct ToPyList { + /** + * Converts a cvector to a numpy array + * @param cdata :: A const reference to a vector + * @returns A new python list object + */ + inline boost::python::list + operator()(const std::vector<ElementType> &cdata) const { + boost::python::list result; + for (const auto &item : cdata) { + result.append(item); + } + return result; + } +}; +} // namespace Converters +} // namespace PythonInterface +} // namespace Mantid + +#endif /* MANTID_PYTHONINTERFACE_TOPYLIST_H_ */ diff --git a/Framework/PythonInterface/inc/MantidPythonInterface/kernel/Converters/VectorToNDArray.h b/Framework/PythonInterface/inc/MantidPythonInterface/kernel/Converters/VectorToNDArray.h index d8a5d366dd5fedc926d4319e774d96b2b513006d..43b3d4089d31d78011536c46b722b77060a0a57b 100644 --- a/Framework/PythonInterface/inc/MantidPythonInterface/kernel/Converters/VectorToNDArray.h +++ b/Framework/PythonInterface/inc/MantidPythonInterface/kernel/Converters/VectorToNDArray.h @@ -37,7 +37,7 @@ namespace Converters { * Converter that takes a std::vector and converts it into a flat numpy array. * * The type of conversion is specified by another struct/class that - * contains a static member create. + * contains a static member create1D. */ template <typename ElementType, typename ConversionPolicy> struct VectorToNDArray { diff --git a/Framework/PythonInterface/mantid/api/src/Exports/AnalysisDataService.cpp b/Framework/PythonInterface/mantid/api/src/Exports/AnalysisDataService.cpp index cd0d53f6d1aeb541d0b2be56be91c9f6597699a0..008c2544fa19b17b54e46fe7c85b672ab6847682 100644 --- a/Framework/PythonInterface/mantid/api/src/Exports/AnalysisDataService.cpp +++ b/Framework/PythonInterface/mantid/api/src/Exports/AnalysisDataService.cpp @@ -1,7 +1,14 @@ +#include "MantidAPI/AnalysisDataService.h" +#include "MantidKernel/WarningSuppressions.h" +#include "MantidPythonInterface/kernel/Converters/PySequenceToVector.h" +#include "MantidPythonInterface/kernel/Converters/ToPyList.h" #include "MantidPythonInterface/kernel/DataServiceExporter.h" #include "MantidPythonInterface/kernel/GetPointer.h" -#include "MantidAPI/AnalysisDataService.h" +#include <boost/python/enum.hpp> +#include <boost/python/list.hpp> +#include <boost/python/overloads.hpp> +#include <boost/python/return_value_policy.hpp> using namespace Mantid::API; using namespace Mantid::Kernel; @@ -10,6 +17,23 @@ using namespace boost::python; GET_POINTER_SPECIALIZATION(AnalysisDataServiceImpl) +namespace { +list retrieveWorkspaces(AnalysisDataServiceImpl &self, const list &names, + bool unrollGroups = false) { + return Converters::ToPyList<Workspace_sptr>()(self.retrieveWorkspaces( + Converters::PySequenceToVector<std::string>(names)(), unrollGroups)); +} + +GNU_DIAG_OFF("unused-local-typedef") +// Ignore -Wconversion warnings coming from boost::python +// Seen with GCC 7.1.1 and Boost 1.63.0 +GNU_DIAG_OFF("conversion") +BOOST_PYTHON_FUNCTION_OVERLOADS(AdsRetrieveWorkspacesOverloads, + retrieveWorkspaces, 2, 3) +GNU_DIAG_ON("conversion") +GNU_DIAG_ON("unused-local-typedef") +} // namespace + void export_AnalysisDataService() { using ADSExporter = DataServiceExporter<AnalysisDataServiceImpl, Workspace_sptr>; @@ -18,5 +42,9 @@ void export_AnalysisDataService() { .def("Instance", &AnalysisDataService::Instance, return_value_policy<reference_existing_object>(), "Return a reference to the singleton instance") - .staticmethod("Instance"); + .staticmethod("Instance") + .def("retrieveWorkspaces", retrieveWorkspaces, + AdsRetrieveWorkspacesOverloads( + "Retrieve a list of workspaces by name", + (arg("self"), arg("names"), arg("unrollGroups") = false))); } diff --git a/Framework/PythonInterface/mantid/api/src/Exports/Workspace.cpp b/Framework/PythonInterface/mantid/api/src/Exports/Workspace.cpp index 71dad6096533d0adfef3ae75303586093dd490fc..bf77888bd387422a9b709c37ee88891c46169bfa 100644 --- a/Framework/PythonInterface/mantid/api/src/Exports/Workspace.cpp +++ b/Framework/PythonInterface/mantid/api/src/Exports/Workspace.cpp @@ -31,10 +31,8 @@ GNU_DIAG_ON("unused-local-typedef") ///@endcond } // namespace -//-------------------------------------------------------------------------------------- -// Deprecated function -//-------------------------------------------------------------------------------------- /** + * DEPRECATED. Use DataItem.name() * @param self Reference to the calling object * @return name of the workspace. */ diff --git a/Framework/PythonInterface/test/cpp/CMakeLists.txt b/Framework/PythonInterface/test/cpp/CMakeLists.txt index 4217fafedc9fa796a7aa94e911e71c7ffec481ef..86c0bb1b595b3b63fa4b9a4a89f7a7c64885da7e 100644 --- a/Framework/PythonInterface/test/cpp/CMakeLists.txt +++ b/Framework/PythonInterface/test/cpp/CMakeLists.txt @@ -8,6 +8,7 @@ set ( TEST_FILES PythonAlgorithmInstantiatorTest.h PySequenceToVectorTest.h RunPythonScriptTest.h + ToPyListTest.h ) if ( CXXTEST_FOUND ) diff --git a/Framework/PythonInterface/test/cpp/PySequenceToVectorTest.h b/Framework/PythonInterface/test/cpp/PySequenceToVectorTest.h index 960d64327e1cd9c8064f13ab3c92065c4ccac304..f342d8d5b46222d6e366b8d354c1cd7a225ac69d 100644 --- a/Framework/PythonInterface/test/cpp/PySequenceToVectorTest.h +++ b/Framework/PythonInterface/test/cpp/PySequenceToVectorTest.h @@ -20,6 +20,8 @@ private: using PySequenceToVectorDouble = PySequenceToVector<double>; public: + void tearDown() override { PyErr_Clear(); } + void test_construction_succeeds_with_a_valid_sequence_type() { boost::python::list testList; TS_ASSERT_THROWS_NOTHING(PySequenceToVectorDouble converter(testList)); diff --git a/Framework/PythonInterface/test/cpp/ToPyListTest.h b/Framework/PythonInterface/test/cpp/ToPyListTest.h new file mode 100644 index 0000000000000000000000000000000000000000..bbf707a0c677b7388529c12fca70d4e4f4073a05 --- /dev/null +++ b/Framework/PythonInterface/test/cpp/ToPyListTest.h @@ -0,0 +1,34 @@ +#ifndef MANTID_PYTHONINTERFACE_TOPYLISTTEST_H +#define MANTID_PYTHONINTERFACE_TOPYLISTTEST_H + +#include "MantidPythonInterface/kernel/Converters/ToPyList.h" +#include <boost/python/errors.hpp> +#include <cxxtest/TestSuite.h> + +using namespace Mantid::PythonInterface::Converters; + +class ToPyListTest : public CxxTest::TestSuite { +public: + static ToPyListTest *createSuite() { return new ToPyListTest(); } + static void destroySuite(ToPyListTest *suite) { delete suite; } + + using ToPyListVectorDouble = ToPyList<double>; + + void test_empty_vector_returns_empty_list() { + std::vector<double> empty; + boost::python::list result; + TS_ASSERT_THROWS_NOTHING(result = ToPyListVectorDouble()(empty)); + TS_ASSERT_EQUALS(0, boost::python::len(result)); + } + + void test_unregistered_element_type_throws_runtime_error() { + std::vector<UnregisteredType> unknownElements{UnregisteredType()}; + TS_ASSERT_THROWS(ToPyList<UnregisteredType>()(unknownElements), + boost::python::error_already_set); + } + +private: + struct UnregisteredType {}; +}; + +#endif // MANTID_PYTHONINTERFACE_TOPYLISTTEST_H diff --git a/Framework/PythonInterface/test/python/mantid/api/AnalysisDataServiceTest.py b/Framework/PythonInterface/test/python/mantid/api/AnalysisDataServiceTest.py index 5512b949bd0e3c61a6d1e51fa46f2360e8ec5d2d..a53af1a6bf4f2f3f02c24726048109c52c0f750c 100644 --- a/Framework/PythonInterface/test/python/mantid/api/AnalysisDataServiceTest.py +++ b/Framework/PythonInterface/test/python/mantid/api/AnalysisDataServiceTest.py @@ -7,6 +7,9 @@ from mantid import mtd class AnalysisDataServiceTest(unittest.TestCase): + def tearDown(self): + AnalysisDataService.Instance().clear() + def test_len_returns_correct_value(self): self.assertEquals(len(AnalysisDataService), 0) @@ -34,8 +37,6 @@ class AnalysisDataServiceTest(unittest.TestCase): current_len = len(AnalysisDataService) self._run_createws(wsname) self.assertEquals(len(AnalysisDataService), current_len + 1) - # Remove to clean the test up - AnalysisDataService.remove(wsname) def test_len_decreases_when_item_removed(self): wsname = 'ADSTest_test_len_decreases_when_item_removed' @@ -52,7 +53,6 @@ class AnalysisDataServiceTest(unittest.TestCase): ws = alg.getProperty("OutputWorkspace").value AnalysisDataService.addOrReplace(name, ws) self.assertRaises(RuntimeError, AnalysisDataService.add, name, ws) - AnalysisDataService.remove(name) def test_addOrReplace_replaces_workspace_with_existing_name(self): data = [1.0,2.0,3.0] @@ -64,7 +64,6 @@ class AnalysisDataServiceTest(unittest.TestCase): AnalysisDataService.addOrReplace(name, ws) len_after = len(AnalysisDataService) self.assertEquals(len_after, len_before) - AnalysisDataService.remove(name) def do_check_for_matrix_workspace_type(self, workspace): self.assertTrue(isinstance(workspace, MatrixWorkspace)) @@ -72,12 +71,10 @@ class AnalysisDataServiceTest(unittest.TestCase): self.assertTrue(hasattr(workspace, 'getNumberHistograms')) self.assertTrue(hasattr(workspace, 'getMemorySize')) - def test_retrieve_gives_back_derived_type_not_DataItem(self): wsname = 'ADSTest_test_retrieve_gives_back_derived_type_not_DataItem' self._run_createws(wsname) self.do_check_for_matrix_workspace_type(AnalysisDataService.retrieve(wsname)) - AnalysisDataService.remove(wsname) def test_key_operator_does_same_as_retrieve(self): wsname = 'ADSTest_test_key_operator_does_same_as_retrieve' @@ -91,8 +88,25 @@ class AnalysisDataServiceTest(unittest.TestCase): self.assertEquals(ws_from_op.name(), ws_from_method.name()) self.assertEquals(ws_from_op.getMemorySize(), ws_from_method.getMemorySize()) - # Remove to clean the test up - AnalysisDataService.remove(wsname) + def test_retrieve_workspaces_respects_default_not_unrolling_groups(self): + ws_names = ["test_retrieve_workspaces_1", "test_retrieve_workspaces_2"] + for name in ws_names: + self._run_createws(name) + workspaces = AnalysisDataService.retrieveWorkspaces(ws_names) + self.assertEquals(2, len(workspaces)) + + def test_retrieve_workspaces_accepts_unrolling_groups_argument(self): + ws_names = ["test_retrieve_workspaces_1", "test_retrieve_workspaces_2"] + for name in ws_names: + self._run_createws(name) + group_name = 'group1' + alg = run_algorithm('GroupWorkspaces', InputWorkspaces=ws_names, + OutputWorkspace=group_name) + + workspaces = AnalysisDataService.retrieveWorkspaces([group_name], True) + self.assertEquals(2, len(workspaces)) + self.assertTrue(isinstance(workspaces[0], MatrixWorkspace)) + self.assertTrue(isinstance(workspaces[1], MatrixWorkspace)) def test_removing_item_invalidates_extracted_handles(self): # If a reference to a DataItem has been extracted from the ADS diff --git a/Framework/PythonInterface/test/python/mantid/plots/plotfunctionsTest.py b/Framework/PythonInterface/test/python/mantid/plots/plotfunctionsTest.py index 75bbf1d5e059e2c2e0dad1c0f265e337187e91ea..cff911f8752e3f43034ef312318d4f41fb25f5da 100644 --- a/Framework/PythonInterface/test/python/mantid/plots/plotfunctionsTest.py +++ b/Framework/PythonInterface/test/python/mantid/plots/plotfunctionsTest.py @@ -9,8 +9,8 @@ import unittest import mantid.api import mantid.plots.plotfunctions as funcs from mantid.kernel import config -from mantid.simpleapi import CreateWorkspace, DeleteWorkspace, CreateMDHistoWorkspace,\ - ConjoinWorkspaces, AddTimeSeriesLog +from mantid.simpleapi import (CreateWorkspace, CreateEmptyTableWorkspace, DeleteWorkspace, + CreateMDHistoWorkspace, ConjoinWorkspaces, AddTimeSeriesLog) @@ -116,6 +116,19 @@ class PlotFunctionsTest(unittest.TestCase): funcs.pcolormesh(ax, self.ws_MD_2d) funcs.pcolorfast(ax, self.ws2d_point_uneven, vmin=-1) + def test_1d_plots_with_unplottable_type_raises_attributeerror(self): + table = CreateEmptyTableWorkspace() + _, ax = plt.subplots() + self.assertRaises(AttributeError, funcs.plot, ax, table, wkspIndex=0) + self.assertRaises(AttributeError, funcs.errorbar, ax, table, wkspIndex=0) + + def test_2d_plots_with_unplottable_type_raises_attributeerror(self): + table = CreateEmptyTableWorkspace() + _, ax = plt.subplots() + self.assertRaises(AttributeError, funcs.pcolor, ax, table) + self.assertRaises(AttributeError, funcs.pcolormesh, ax, table) + self.assertRaises(AttributeError, funcs.pcolorfast, ax, table) + if __name__ == '__main__': unittest.main() diff --git a/MantidPlot/src/Mantid/MantidUI.cpp b/MantidPlot/src/Mantid/MantidUI.cpp index 120e93633bed2ba2ced2a0eb66e115f5698881ff..c23e6a1074c23fb102533da74820cdf2759b1aa5 100644 --- a/MantidPlot/src/Mantid/MantidUI.cpp +++ b/MantidPlot/src/Mantid/MantidUI.cpp @@ -1563,16 +1563,10 @@ bool MantidUI::canAcceptDrop(QDragEnterEvent *e) { bool MantidUI::drop(QDropEvent *e) { QString name = e->mimeData()->objectName(); if (name == "MantidWorkspace") { - QString text = e->mimeData()->text(); - int endIndex = 0; - QStringList wsNames; - while (text.indexOf("[\"", endIndex) > -1) { - int startIndex = text.indexOf("[\"", endIndex) + 2; - endIndex = text.indexOf("\"]", startIndex); - wsNames.append(text.mid(startIndex, endIndex - startIndex)); + QStringList wsNames = e->mimeData()->text().split("\n"); + for (const auto &wsName : wsNames) { + importWorkspace(wsName, false); } - - foreach (const auto &wsName, wsNames) { importWorkspace(wsName, false); } return true; } else if (e->mimeData()->hasUrls()) { const auto pyFiles = DropEventHelper::extractPythonFiles(e); diff --git a/qt/applications/workbench/CMakeLists.txt b/qt/applications/workbench/CMakeLists.txt index 22e85645b70d41c8c1e61efd4cdc898f59caade4..ed59fab1bbabd5f22daf7d412cac8d5d50dcdecb 100644 --- a/qt/applications/workbench/CMakeLists.txt +++ b/qt/applications/workbench/CMakeLists.txt @@ -20,6 +20,8 @@ set ( TEST_FILES workbench/config/test/test_user.py workbench/test/test_import.py + workbench/plotting/test/test_functions.py + workbench/widgets/plotselector/test/test_plotselector_model.py workbench/widgets/plotselector/test/test_plotselector_presenter.py workbench/widgets/plotselector/test/test_plotselector_view.py diff --git a/qt/applications/workbench/workbench/app/mainwindow.py b/qt/applications/workbench/workbench/app/mainwindow.py index 24e2a836c4c908fe7ccc5c17bc05117d8221ad9b..c27da1a4b9f54ece7fe79809bc2e9bcd61b2ad15 100644 --- a/qt/applications/workbench/workbench/app/mainwindow.py +++ b/qt/applications/workbench/workbench/app/mainwindow.py @@ -305,7 +305,8 @@ class MainWindow(QMainWindow): # flatten list widgets = [item for column in widgets_layout for row in column for item in row] # show everything - map(lambda w: w.toggle_view(True), widgets) + for w in widgets: + w.toggle_view(True) # split everything on the horizontal for i in range(len(widgets) - 1): first, second = widgets[i], widgets[i+1] diff --git a/qt/applications/workbench/workbench/plotting/figuremanager.py b/qt/applications/workbench/workbench/plotting/figuremanager.py index 21e58e85ea8947692a0da3c17021846c19b104d1..434a19e30b9a8149fdaa65d8824c305c370751c5 100644 --- a/qt/applications/workbench/workbench/plotting/figuremanager.py +++ b/qt/applications/workbench/workbench/plotting/figuremanager.py @@ -21,37 +21,15 @@ import matplotlib from matplotlib.backend_bases import FigureManagerBase from matplotlib.backends.backend_qt5agg import (FigureCanvasQTAgg, backend_version, draw_if_interactive, show) # noqa from matplotlib._pylab_helpers import Gcf -from qtpy.QtCore import Qt, QEvent, QObject, Signal -from qtpy.QtWidgets import QApplication, QLabel, QMainWindow +from qtpy.QtCore import Qt, QObject +from qtpy.QtWidgets import QApplication, QLabel from six import text_type # local imports -from .propertiesdialog import LabelEditor, XAxisEditor, YAxisEditor -from .toolbar import WorkbenchNavigationToolbar -from .qappthreadcall import QAppThreadCall - - -class MainWindow(QMainWindow): - activated = Signal() - closing = Signal() - visibility_changed = Signal() - - def event(self, event): - if event.type() == QEvent.WindowActivate: - self.activated.emit() - return QMainWindow.event(self, event) - - def closeEvent(self, event): - self.closing.emit() - QMainWindow.closeEvent(self, event) - - def hideEvent(self, event): - self.visibility_changed.emit() - QMainWindow.hideEvent(self, event) - - def showEvent(self, event): - self.visibility_changed.emit() - QMainWindow.showEvent(self, event) +from workbench.plotting.figurewindow import FigureWindow +from workbench.plotting.propertiesdialog import LabelEditor, XAxisEditor, YAxisEditor +from workbench.plotting.toolbar import WorkbenchNavigationToolbar +from workbench.plotting.qappthreadcall import QAppThreadCall class FigureManagerWorkbench(FigureManagerBase, QObject): @@ -86,15 +64,14 @@ class FigureManagerWorkbench(FigureManagerBase, QObject): self.fig_visibility_changed_orig = self.fig_visibility_changed self.fig_visibility_changed = QAppThreadCall(self.fig_visibility_changed_orig) - self.canvas = canvas - self.window = MainWindow() + self.window = FigureWindow(canvas) self.window.activated.connect(self._window_activated) self.window.closing.connect(canvas.close_event) self.window.closing.connect(self._widgetclosed) self.window.visibility_changed.connect(self.fig_visibility_changed) self.window.setWindowTitle("Figure %d" % num) - self.canvas.figure.set_label("Figure %d" % num) + canvas.figure.set_label("Figure %d" % num) # Give the keyboard focus to the figure instead of the # manager; StrongFocus accepts both tab and click to focus and @@ -103,8 +80,8 @@ class FigureManagerWorkbench(FigureManagerBase, QObject): # clicked # on. http://qt-project.org/doc/qt-4.8/qt.html#FocusPolicy-enum or # http://doc.qt.digia.com/qt/qt.html#FocusPolicy-enum - self.canvas.setFocusPolicy(Qt.StrongFocus) - self.canvas.setFocus() + canvas.setFocusPolicy(Qt.StrongFocus) + canvas.setFocus() self.window._destroying = False @@ -112,7 +89,7 @@ class FigureManagerWorkbench(FigureManagerBase, QObject): self.statusbar_label = QLabel() self.window.statusBar().addWidget(self.statusbar_label) - self.toolbar = self._get_toolbar(self.canvas, self.window) + self.toolbar = self._get_toolbar(canvas, self.window) if self.toolbar is not None: self.window.addToolBar(self.toolbar) self.toolbar.message.connect(self.statusbar_label.setText) @@ -129,17 +106,17 @@ class FigureManagerWorkbench(FigureManagerBase, QObject): height = cs.height() + self._status_and_tool_height self.window.resize(cs.width(), height) - self.window.setCentralWidget(self.canvas) + self.window.setCentralWidget(canvas) if matplotlib.is_interactive(): self.window.show() - self.canvas.draw_idle() + canvas.draw_idle() def notify_axes_change(fig): # This will be called whenever the current axes is changed if self.toolbar is not None: self.toolbar.update() - self.canvas.figure.add_axobserver(notify_axes_change) + canvas.figure.add_axobserver(notify_axes_change) # Register canvas observers self._cids = [] @@ -160,7 +137,8 @@ class FigureManagerWorkbench(FigureManagerBase, QObject): if self.window._destroying: return self.window._destroying = True - map(self.canvas.mpl_disconnect, self._cids) + for id in self._cids: + self.canvas.mpl_disconnect(id) try: Gcf.destroy(self.num) except AttributeError: diff --git a/qt/applications/workbench/workbench/plotting/figuretype.py b/qt/applications/workbench/workbench/plotting/figuretype.py index ce0d801adacdf67abd3fc424d525c7e571e191b2..24f74e1cd143bb91c3d02c17ce42c6817371747d 100644 --- a/qt/applications/workbench/workbench/plotting/figuretype.py +++ b/qt/applications/workbench/workbench/plotting/figuretype.py @@ -19,7 +19,9 @@ Provides facilities to check plot types """ from __future__ import absolute_import +# third party from mantidqt.py3compat import Enum +from matplotlib.container import ErrorbarContainer class FigureType(Enum): @@ -32,7 +34,7 @@ class FigureType(Enum): Line = 1 # Line plot with error bars Errorbar = 2 - # An image plot from imshow, pcolor + # An image plot from imshow, pcolor, pcolormesh Image = 3 # Any other type of plot Other = 100 @@ -51,25 +53,25 @@ def axes_type(ax): axtype = FigureType.Other if nrows == 1 and ncols == 1: - if len(ax.lines) > 0: + # an errorbar plot also has len(lines) > 0 + if len(ax.containers) > 0 and isinstance(ax.containers[0], ErrorbarContainer): + axtype = FigureType.Errorbar + elif len(ax.lines) > 0: axtype = FigureType.Line - elif len(ax.images) > 0: + elif len(ax.images) > 0 or len(ax.collections) > 0: axtype = FigureType.Image return axtype def figure_type(fig): - """Return the type of the figure + """Return the type of the figure. It inspects the axes + return by fig.gca() :param fig: A matplotlib figure instance :return: An enumeration defining the plot type """ - all_axes = fig.axes - all_axes_length = len(all_axes) - if all_axes_length == 0: + if len(fig.get_axes()) == 0: return FigureType.Empty - elif all_axes_length == 1: - return axes_type(all_axes[0]) else: - return FigureType.Other + return axes_type(fig.gca()) diff --git a/qt/applications/workbench/workbench/plotting/figurewindow.py b/qt/applications/workbench/workbench/plotting/figurewindow.py new file mode 100644 index 0000000000000000000000000000000000000000..11952fb67465bd91c0f9603de5eb5a57bc0a928e --- /dev/null +++ b/qt/applications/workbench/workbench/plotting/figurewindow.py @@ -0,0 +1,106 @@ +# This file is part of the mantid workbench. +# +# Copyright (C) 2018 mantidproject +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +"""Provides the QMainWindow subclass for a plotting window""" +from __future__ import absolute_import + +# std imports +import weakref + +# 3rdparty imports +from qtpy.QtCore import QEvent, Signal +from qtpy.QtWidgets import QMainWindow + +# local imports +from .figuretype import figure_type, FigureType + + +class FigureWindow(QMainWindow): + """A MainWindow that will hold plots""" + activated = Signal() + closing = Signal() + visibility_changed = Signal() + + def __init__(self, canvas, parent=None): + QMainWindow.__init__(self, parent=parent) + # attributes + self._canvas = weakref.proxy(canvas) + + self.setAcceptDrops(True) + + def event(self, event): + if event.type() == QEvent.WindowActivate: + self.activated.emit() + return QMainWindow.event(self, event) + + def closeEvent(self, event): + self.closing.emit() + QMainWindow.closeEvent(self, event) + + def hideEvent(self, event): + self.visibility_changed.emit() + QMainWindow.hideEvent(self, event) + + def showEvent(self, event): + self.visibility_changed.emit() + QMainWindow.showEvent(self, event) + + def dragEnterEvent(self, event): + """ + Accepts drag events if the event contains text and no + urls. + + :param event: A QDragEnterEvent instance for the most + recent drag + """ + data = event.mimeData() + if data.hasText() and not data.hasUrls(): + event.acceptProposedAction() + + def dropEvent(self, event): + """ + If the event data contains a workspace reference then + request a plot of the workspace. + + :param event: A QDropEvent containing the MIME + data of the action + """ + self._plot_on_here(event.mimeData().text().split('\n')) + QMainWindow.dropEvent(self, event) + + # private api + + def _plot_on_here(self, names): + """ + Assume the list of strings refer to workspace names and they are to be plotted + on this figure. If the current figure contains an image plot then + a new image plot will replace the current image. If the current figure + contains a line plot then the user will be asked what should be plotted and this + will overplot onto the figure. If the first line of the plot + :param names: A list of workspace names + """ + if len(names) == 0: + return + # local import to avoid circular import with FigureManager + from workbench.plotting.functions import pcolormesh_from_names, plot_from_names + + fig = self._canvas.figure + fig_type = figure_type(fig) + if fig_type == FigureType.Image: + pcolormesh_from_names(names, fig=fig) + else: + plot_from_names(names, errors=(fig_type == FigureType.Errorbar), + overplot=True, fig=fig) diff --git a/qt/applications/workbench/workbench/plotting/functions.py b/qt/applications/workbench/workbench/plotting/functions.py index 9795fc0cd9737abcf20854de94033afd3900cd2d..b6e113a9c9d61b67b1fb1c9b9f83391498e1a8bf 100644 --- a/qt/applications/workbench/workbench/plotting/functions.py +++ b/qt/applications/workbench/workbench/plotting/functions.py @@ -23,51 +23,54 @@ import collections import math # 3rd party imports -from mantid.api import MatrixWorkspace +from mantid.api import AnalysisDataService, MatrixWorkspace +from mantid.kernel import Logger import matplotlib.pyplot as plt +try: + from matplotlib.cm import viridis as DEFAULT_CMAP +except ImportError: + from matplotlib.cm import jet as DEFAULT_CMAP from mantidqt.py3compat import is_text_string +from mantidqt.dialogs.spectraselectordialog import get_spectra_selection +from matplotlib.gridspec import GridSpec +import numpy as np # local imports +from .figuretype import figure_type, FigureType # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- PROJECTION = 'mantid' -DEFAULT_COLORMAP = 'viridis' # See https://matplotlib.org/api/_as_gen/matplotlib.figure.SubplotParams.html#matplotlib.figure.SubplotParams SUBPLOT_WSPACE = 0.5 SUBPLOT_HSPACE = 0.5 +LOGGER = Logger("workspace.plotting.functions") # ----------------------------------------------------------------------------- -# Functions +# 'Public' Functions # ----------------------------------------------------------------------------- -def raise_if_not_sequence(seq, seq_name): - accepted_types = [list, tuple] - if type(seq) not in accepted_types: - raise ValueError("{} should be a list or tuple".format(seq_name)) - - -def _validate_plot_inputs(workspaces, spectrum_nums, wksp_indices): - """Raises a ValueError if any arguments have the incorrect types""" - if spectrum_nums is not None and wksp_indices is not None: - raise ValueError("Both spectrum_nums and wksp_indices supplied. " - "Please supply only 1.") - - if not isinstance(workspaces, MatrixWorkspace): - raise_if_not_sequence(workspaces, 'Workspaces') - - if spectrum_nums is not None: - raise_if_not_sequence(spectrum_nums, 'spectrum_nums') - if wksp_indices is not None: - raise_if_not_sequence(wksp_indices, 'wksp_indices') +def can_overplot(): + """ + Checks if overplotting on the current figure can proceed + with the given options + :return: A 2-tuple of boolean indicating compatability and + a string containing an error message if the current figure is not + compatible. + """ + compatible = False + msg = "Unable to overplot on currently active plot type.\n" \ + "Please select another plot." + fig = current_figure_or_none() + if fig is not None: + figtype = figure_type(fig) + if figtype is FigureType.Line or figtype is FigureType.Errorbar: + compatible, msg = True, None -def _validate_pcolormesh_inputs(workspaces): - """Raises a ValueError if any arguments have the incorrect types""" - if not isinstance(workspaces, MatrixWorkspace): - raise_if_not_sequence(workspaces, 'Workspaces') + return compatible, msg def current_figure_or_none(): @@ -103,8 +106,35 @@ def figure_title(workspaces, fig_num): return wsname(first) + '-' + str(fig_num) +def plot_from_names(names, errors, overplot, fig=None): + """ + Given a list of names of workspaces, raise a dialog asking for the + a selection of what to plot and then plot it. + + :param names: A list of workspace names + :param errors: If true then error bars will be plotted on the points + :param overplot: If true then the add to the current figure if one + exists and it is a compatible figure + :param fig: If not None then use this figure object to plot + :return: The figure containing the plot or None if selection was cancelled + """ + workspaces = AnalysisDataService.Instance().retrieveWorkspaces(names, unrollGroups=True) + try: + selection = get_spectra_selection(workspaces) + except Exception as exc: + LOGGER.warning(format(str(exc))) + selection = None + + if selection is None: + return None + + return plot(selection.workspaces, spectrum_nums=selection.spectra, + wksp_indices=selection.wksp_indices, + errors=errors, overplot=overplot, fig=fig) + + def plot(workspaces, spectrum_nums=None, wksp_indices=None, errors=False, - overplot=False): + overplot=False, fig=None): """ Create a figure with a single subplot and for each workspace/index add a line plot to the new axes. show() is called before returning the figure instance. A legend @@ -115,7 +145,8 @@ def plot(workspaces, spectrum_nums=None, wksp_indices=None, errors=False, :param wksp_indices: A list of workspace indexes (starts from 0) :param errors: If true then error bars are added for each plot :param overplot: If true then overplot over the current figure if one exists - :returns: The figure containing the plots + :param fig: If not None then use this Figure object to plot + :return: The figure containing the plots """ # check inputs _validate_plot_inputs(workspaces, spectrum_nums, wksp_indices) @@ -124,13 +155,16 @@ def plot(workspaces, spectrum_nums=None, wksp_indices=None, errors=False, else: kw, nums = 'wkspIndex', wksp_indices - # get/create the axes to hold the plot - if overplot: - ax = plt.gca(projection=PROJECTION) - fig = ax.figure + if fig is None: + # get/create the axes to hold the plot + if overplot: + ax = plt.gca(projection=PROJECTION) + fig = ax.figure + else: + fig = plt.figure() + ax = fig.add_subplot(111, projection=PROJECTION) else: - fig = plt.figure() - ax = fig.add_subplot(111, projection=PROJECTION) + ax = fig.gca() # do the plotting plot_fn = ax.errorbar if errors else ax.plot @@ -148,14 +182,28 @@ def plot(workspaces, spectrum_nums=None, wksp_indices=None, errors=False, return fig -def pcolormesh(workspaces): +def pcolormesh_from_names(names, fig=None): + """ + Create a figure containing pcolor subplots + + :param names: A list of workspace names + :param fig: An optional figure to contain the new plots. Its current contents will be cleared + :returns: The figure containing the plots + """ + try: + return pcolormesh(AnalysisDataService.retrieveWorkspaces(names, unrollGroups=True), + fig=fig) + except Exception as exc: + LOGGER.warning(format(str(exc))) + return None + + +def pcolormesh(workspaces, fig=None): """ - Create a figure containing subplots + Create a figure containing pcolor subplots :param workspaces: A list of workspace handles - :param spectrum_nums: A list of spectrum number identifiers (general start from 1) - :param wksp_indices: A list of workspace indexes (starts from 0) - :param errors: If true then error bars are added for each plot + :param fig: An optional figure to contain the new plots. Its current contents will be cleared :returns: The figure containing the plots """ # check inputs @@ -164,26 +212,17 @@ def pcolormesh(workspaces): # create a subplot of the appropriate number of dimensions # extend in number of columns if the number of plottables is not a square number workspaces_len = len(workspaces) - square_side_len = int(math.ceil(math.sqrt(workspaces_len))) - nrows, ncols = square_side_len, square_side_len - if square_side_len*square_side_len != workspaces_len: - # not a square number - square_side_len x square_side_len - # will be large enough but we could end up with an empty - # row so chop that off - if workspaces_len <= (nrows-1)*ncols: - nrows -= 1 + fig, axes, nrows, ncols = _create_subplots(workspaces_len, fig=fig) - fig, axes = plt.subplots(nrows, ncols, squeeze=False, - subplot_kw=dict(projection=PROJECTION)) row_idx, col_idx = 0, 0 for subplot_idx in range(nrows*ncols): ax = axes[row_idx][col_idx] if subplot_idx < workspaces_len: ws = workspaces[subplot_idx] ax.set_title(ws.name()) - pcm = ax.pcolormesh(ws, cmap=DEFAULT_COLORMAP) - xticks = ax.get_xticklabels() - map(lambda lbl: lbl.set_rotation(45), xticks) + pcm = ax.pcolormesh(ws, cmap=DEFAULT_CMAP) + for lbl in ax.get_xticklabels(): + lbl.set_rotation(45) if col_idx < ncols - 1: col_idx += 1 else: @@ -201,8 +240,9 @@ def pcolormesh(workspaces): fig.show() return fig +# ----------------- Compatability functions --------------------- + -# Compatibility function for existing MantidPlot functionality def plotSpectrum(workspaces, indices, distribution=None, error_bars=False, type=None, window=None, clearWindow=None, waterfall=False): @@ -229,3 +269,83 @@ def plotSpectrum(workspaces, indices, distribution=None, error_bars=False, return plot(workspaces, wksp_indices=indices, errors=error_bars, fmt=fmt) + + +# ----------------------------------------------------------------------------- +# 'Private' Functions +# ----------------------------------------------------------------------------- +def _raise_if_not_sequence(value, seq_name, element_type=None): + """ + Raise a ValueError if the given object is not a sequence + + :param value: The value object to validate + :param seq_name: The variable name of the sequence for the error message + :param element_type: An optional type to provide to check that each element + is an instance of this type + :raises ValueError: if the conditions are not met + """ + accepted_types = (list, tuple) + if type(value) not in accepted_types: + raise ValueError("{} should be a list or tuple".format(seq_name)) + if element_type is not None: + def raise_if_not_type(x): + if not isinstance(x, element_type): + raise ValueError("Unexpected type: '{}'".format(x.__class__.__name__)) + + map(raise_if_not_type, value) + + +def _validate_plot_inputs(workspaces, spectrum_nums, wksp_indices): + """Raises a ValueError if any arguments have the incorrect types""" + if spectrum_nums is not None and wksp_indices is not None: + raise ValueError("Both spectrum_nums and wksp_indices supplied. " + "Please supply only 1.") + + _raise_if_not_sequence(workspaces, 'workspaces', MatrixWorkspace) + + if spectrum_nums is not None: + _raise_if_not_sequence(spectrum_nums, 'spectrum_nums') + + if wksp_indices is not None: + _raise_if_not_sequence(wksp_indices, 'wksp_indices') + + +def _validate_pcolormesh_inputs(workspaces): + """Raises a ValueError if any arguments have the incorrect types""" + _raise_if_not_sequence(workspaces, 'workspaces', MatrixWorkspace) + + +def _create_subplots(nplots, fig=None): + """ + Create a set of subplots suitable for a given number of plots. A stripped down + version of plt.subplots that can accept an existing figure instance. + + :param nplots: The number of plots required + :param fig: An optional figure. It is cleared before plotting the new contents + :return: A 2-tuple of (fig, axes) + """ + square_side_len = int(math.ceil(math.sqrt(nplots))) + nrows, ncols = square_side_len, square_side_len + if square_side_len*square_side_len != nplots: + # not a square number - square_side_len x square_side_len + # will be large enough but we could end up with an empty + # row so chop that off + if nplots <= (nrows-1)*ncols: + nrows -= 1 + + if fig is None: + fig = plt.figure() + else: + fig.clf() + # annoyling this repl + nplots = nrows*ncols + gs = GridSpec(nrows, ncols) + axes = np.empty(nplots, dtype=object) + ax0 = fig.add_subplot(gs[0, 0], projection=PROJECTION) + axes[0] = ax0 + for i in range(1, nplots): + axes[i] = fig.add_subplot(gs[i // ncols, i % ncols], + projection=PROJECTION) + axes = axes.reshape(nrows, ncols) + + return fig, axes, nrows, ncols diff --git a/qt/applications/workbench/workbench/plotting/test/test_figuretype.py b/qt/applications/workbench/workbench/plotting/test/test_figuretype.py index 1d23fbbfd6b34c8b6535783de3bf6ac37532bda4..6e691d3783fd5b574eab2e1330f6f93a6c692a0d 100644 --- a/qt/applications/workbench/workbench/plotting/test/test_figuretype.py +++ b/qt/applications/workbench/workbench/plotting/test/test_figuretype.py @@ -42,6 +42,11 @@ class FigureTypeTest(TestCase): ax.plot([1]) self.assertEqual(FigureType.Line, figure_type(ax.figure)) + def test_error_plot_returns_error(self): + ax = plt.subplot(111) + ax.errorbar([1], [1], yerr=[0.01]) + self.assertEqual(FigureType.Errorbar, figure_type(ax.figure)) + def test_image_plot_returns_image(self): ax = plt.subplot(111) ax.imshow([[1],[1]]) diff --git a/qt/applications/workbench/workbench/plotting/test/test_functions.py b/qt/applications/workbench/workbench/plotting/test/test_functions.py index 54f17df00f093aa5529f5a3e8a98fdff96f4c190..a1c11d316eaf0a2860315f7c827485ec75821ce0 100644 --- a/qt/applications/workbench/workbench/plotting/test/test_functions.py +++ b/qt/applications/workbench/workbench/plotting/test/test_functions.py @@ -18,14 +18,22 @@ from __future__ import absolute_import # std imports from unittest import TestCase, main +try: + from unittest import mock +except ImportError: + import mock -# thirdparty imports +# third party imports +from mantid.api import AnalysisDataService, WorkspaceFactory import matplotlib matplotlib.use('AGG') # noqa import matplotlib.pyplot as plt +from mantidqt.dialogs.spectraselectordialog import SpectraSelection +import numpy as np # local imports -from workbench.plotting.functions import current_figure_or_none, figure_title +from workbench.plotting.functions import (can_overplot, current_figure_or_none, figure_title, + plot, plot_from_names, pcolormesh_from_names) # Avoid importing the whole of mantid for a single mock of the workspace class @@ -40,8 +48,30 @@ class FakeWorkspace(object): class FunctionsTest(TestCase): - def test_current_figure_or_none_returns_none_if_no_figures_exist(self): + _test_ws = None + + def setUp(self): + if self._test_ws is None: + self.__class__._test_ws = WorkspaceFactory.Instance().create("Workspace2D", NVectors=2, YLength=5, XLength=5) + + def tearDown(self): + AnalysisDataService.Instance().clear() plt.close('all') + + def test_can_overplot_returns_false_with_no_active_plots(self): + self.assertFalse(can_overplot()[0]) + + def test_can_overplot_returns_true_for_active_line_plot(self): + plt.plot([1, 2]) + self.assertTrue(can_overplot()[0]) + + def test_can_overplot_returns_false_for_active_patch_plot(self): + plt.pcolormesh(np.arange(9.).reshape(3,3)) + allowed, msg = can_overplot() + self.assertFalse(allowed) + self.assertTrue(len(msg) > 0) + + def test_current_figure_or_none_returns_none_if_no_figures_exist(self): self.assertTrue(current_figure_or_none() is None) def test_figure_title_with_single_string(self): @@ -61,6 +91,107 @@ class FunctionsTest(TestCase): with self.assertRaises(AssertionError): figure_title([], 5) + @mock.patch('workbench.plotting.functions.get_spectra_selection') + @mock.patch('workbench.plotting.functions.plot') + def test_plot_from_names_calls_plot(self, get_spectra_selection_mock, plot_mock): + ws_name = 'test_plot_from_names_calls_plot-1' + AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws) + selection = SpectraSelection([self._test_ws]) + selection.wksp_indices = [0] + get_spectra_selection_mock.return_value = selection + plot_from_names([ws_name], errors=False, overplot=False) + + self.assertEqual(1, plot_mock.call_count) + + @mock.patch('workbench.plotting.functions.get_spectra_selection') + def test_plot_from_names_produces_single_line_plot_for_valid_name(self, get_spectra_selection_mock): + self._do_plot_from_names_test(get_spectra_selection_mock, expected_labels=["spec 1"], wksp_indices=[0], + errors=False, overplot=False) + + @mock.patch('workbench.plotting.functions.get_spectra_selection') + def test_plot_from_names_produces_single_error_plot_for_valid_name(self, get_spectra_selection_mock): + fig = self._do_plot_from_names_test(get_spectra_selection_mock, + # matplotlib does not set labels on the lines for error plots + expected_labels=[None, None, None], + wksp_indices=[0], errors=True, overplot=False) + self.assertEqual(1, len(fig.gca().containers)) + + @mock.patch('workbench.plotting.functions.get_spectra_selection') + def test_plot_from_names_produces_overplot_for_valid_name(self, get_spectra_selection_mock): + # make first plot + plot([self._test_ws], wksp_indices=[0]) + self._do_plot_from_names_test(get_spectra_selection_mock, expected_labels=["spec 1", "spec 2"], + wksp_indices=[1], errors=False, overplot=True) + + @mock.patch('workbench.plotting.functions.get_spectra_selection') + def test_plot_from_names_within_existing_figure(self, get_spectra_selection_mock): + # make existing plot + fig = plot([self._test_ws], wksp_indices=[0]) + self._do_plot_from_names_test(get_spectra_selection_mock, expected_labels=["spec 1", "spec 2"], + wksp_indices=[1], errors=False, overplot=False, + target_fig=fig) + + @mock.patch('workbench.plotting.functions.pcolormesh') + def test_pcolormesh_from_names_calls_pcolormesh(self, pcolormesh_mock): + ws_name = 'test_pcolormesh_from_names_calls_pcolormesh-1' + AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws) + pcolormesh_from_names([ws_name]) + + self.assertEqual(1, pcolormesh_mock.call_count) + + def test_pcolormesh_from_names(self): + ws_name = 'test_pcolormesh_from_names-1' + AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws) + fig = pcolormesh_from_names([ws_name]) + + self.assertEqual(1, len(fig.gca().collections)) + + def test_pcolormesh_from_names_using_existing_figure(self): + ws_name = 'test_pcolormesh_from_names-1' + AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws) + target_fig = plt.figure() + fig = pcolormesh_from_names([ws_name], fig=target_fig) + + self.assertEqual(fig, target_fig) + self.assertEqual(1, len(fig.gca().collections)) + + # ------------- Failure tests ------------- + + def test_plot_from_names_with_non_plottable_workspaces_returns_None(self): + table = WorkspaceFactory.Instance().createTable() + table_name = 'test_plot_from_names_with_non_plottable_workspaces_returns_None' + AnalysisDataService.Instance().addOrReplace(table_name, table) + result = plot_from_names([table_name], errors=False, overplot=False) + self.assertTrue(result is None) + + def test_pcolormesh_from_names_with_non_plottable_workspaces_returns_None(self): + table = WorkspaceFactory.Instance().createTable() + table_name = 'test_pcolormesh_from_names_with_non_plottable_workspaces_returns_None' + AnalysisDataService.Instance().addOrReplace(table_name, table) + result = pcolormesh_from_names([table_name]) + self.assertTrue(result is None) + + # ------------- Private ------------------- + def _do_plot_from_names_test(self, get_spectra_selection_mock, expected_labels, + wksp_indices, errors, overplot, target_fig=None): + ws_name = 'test_plot_from_names-1' + AnalysisDataService.Instance().addOrReplace(ws_name, self._test_ws) + + selection = SpectraSelection([self._test_ws]) + selection.wksp_indices = wksp_indices + get_spectra_selection_mock.return_value = selection + fig = plot_from_names([ws_name], errors, overplot, target_fig) + if target_fig is not None: + self.assertEqual(target_fig, fig) + + plotted_lines = fig.gca().get_lines() + self.assertEqual(len(expected_labels), len(plotted_lines)) + for label_part, line in zip(expected_labels, plotted_lines): + if label_part is not None: + self.assertTrue(label_part in line.get_label(), + msg="Label fragment '{}' not found in line label".format(label_part)) + return fig + if __name__ == '__main__': main() diff --git a/qt/applications/workbench/workbench/plugins/workspacewidget.py b/qt/applications/workbench/workbench/plugins/workspacewidget.py index 2ffcc1c21be7aec370d6cf2f516fe73c82553943..c8c7698117ebd3be48430e689a097e49cddfb220 100644 --- a/qt/applications/workbench/workbench/plugins/workspacewidget.py +++ b/qt/applications/workbench/workbench/plugins/workspacewidget.py @@ -20,48 +20,13 @@ from __future__ import (absolute_import, unicode_literals) import functools # third-party library imports -from mantid.api import AnalysisDataService, MatrixWorkspace, WorkspaceGroup -from mantid.kernel import Logger -from mantidqt.dialogs.spectraselectordialog import get_spectra_selection +from mantid.api import AnalysisDataService from mantidqt.widgets.workspacewidget.workspacetreewidget import WorkspaceTreeWidget from qtpy.QtWidgets import QMessageBox, QVBoxLayout # local package imports from workbench.plugins.base import PluginWidget -from workbench.plotting.figuretype import figure_type, FigureType -from workbench.plotting.functions import current_figure_or_none, pcolormesh, plot - -LOGGER = Logger(b"workspacewidget") - - -def _workspaces_from_names(names): - """ - Convert a list of workspace names to a list of MatrixWorkspace handles. Any WorkspaceGroup - encountered is flattened and its members inserted into the list at this point - - Flattens any workspace groups with the list, preserving the order of the remaining elements - :param names: A list of workspace names - :return: A list where each element is a single MatrixWorkspace - """ - ads = AnalysisDataService.Instance() - flat = [] - for name in names: - try: - ws = ads[name.encode('utf-8')] - except KeyError: - LOGGER.warning("Skipping {} as it does not exist".format(name)) - continue - - if isinstance(ws, MatrixWorkspace): - flat.append(ws) - elif isinstance(ws, WorkspaceGroup): - group_len = len(ws) - for i in range(group_len): - flat.append(ws[i]) - else: - LOGGER.warning("{} it is not a MatrixWorkspace or WorkspaceGroup.".format(name)) - - return flat +from workbench.plotting.functions import can_overplot, pcolormesh, plot_from_names class WorkspaceWidget(PluginWidget): @@ -70,6 +35,8 @@ class WorkspaceWidget(PluginWidget): def __init__(self, parent): super(WorkspaceWidget, self).__init__(parent) + self._ads = AnalysisDataService.Instance() + # layout self.workspacewidget = WorkspaceTreeWidget() layout = QVBoxLayout() @@ -110,20 +77,12 @@ class WorkspaceWidget(PluginWidget): exists and it is a compatible figure """ if overplot: - compatible, error_msg = self._can_overplot() + compatible, error_msg = can_overplot() if not compatible: QMessageBox.warning(self, "", error_msg) return - try: - selection = get_spectra_selection(_workspaces_from_names(names), self) - if selection is not None: - plot(selection.workspaces, spectrum_nums=selection.spectra, - wksp_indices=selection.wksp_indices, - errors=errors, overplot=overplot) - except BaseException: - import traceback - traceback.print_exc() + plot_from_names(names, errors, overplot) def _do_plot_colorfill(self, names): """ @@ -132,26 +91,7 @@ class WorkspaceWidget(PluginWidget): :param names: A list of workspace names """ try: - pcolormesh(_workspaces_from_names(names)) + pcolormesh(self._ads.retrieveWorkspaces(names, unrollGroups=True)) except BaseException: import traceback traceback.print_exc() - - def _can_overplot(self): - """ - Checks if overplotting can proceed with the given options - - :return: A 2-tuple of boolean indicating compatability and - a string containing an error message if the current figure is not - compatible. - """ - compatible = False - msg = "Unable to overplot on currently active plot type.\n" \ - "Please select another plot." - fig = current_figure_or_none() - if fig is not None: - figtype = figure_type(fig) - if figtype is FigureType.Line or figtype is FigureType.Errorbar: - compatible, msg = True, None - - return compatible, msg diff --git a/qt/paraview_ext/VatesSimpleGui/ViewWidgets/src/MdViewerWidget.cpp b/qt/paraview_ext/VatesSimpleGui/ViewWidgets/src/MdViewerWidget.cpp index e4fdad8166d15c8f887accc64074d3376d8ba8ef..c99fc62342b593f723b5ec88f64cc2d1a2a06a12 100644 --- a/qt/paraview_ext/VatesSimpleGui/ViewWidgets/src/MdViewerWidget.cpp +++ b/qt/paraview_ext/VatesSimpleGui/ViewWidgets/src/MdViewerWidget.cpp @@ -1651,11 +1651,8 @@ bool otherWorkspacePresent() { void MdViewerWidget::handleDragAndDropPeaksWorkspaces(QEvent *e, const QString &text, QStringList &wsNames) { - int endIndex = 0; - while (text.indexOf("[\"", endIndex) > -1) { - int startIndex = text.indexOf("[\"", endIndex) + 2; - endIndex = text.indexOf("\"]", startIndex); - QString candidate = text.mid(startIndex, endIndex - startIndex); + const QStringList selectedWorkspaces = text.split("\n"); + for (const auto &candidate : selectedWorkspaces) { // Only append the candidate if SplattorPlotView is selected and an // MDWorkspace is loaded. if (currentView->getViewType() == ModeControlWidget::Views::SPLATTERPLOT && diff --git a/qt/python/mantidqt/dialogs/spectraselectordialog.py b/qt/python/mantidqt/dialogs/spectraselectordialog.py index 0bde6df74688864fa3641688a2c72d390d7902a1..4e81ccef4593d96e064d43d64283a93d36f27dc2 100644 --- a/qt/python/mantidqt/dialogs/spectraselectordialog.py +++ b/qt/python/mantidqt/dialogs/spectraselectordialog.py @@ -19,6 +19,7 @@ from __future__ import (absolute_import, unicode_literals) # std imports # 3rd party imports +from mantid.api import MatrixWorkspace import qtawesome as qta from qtpy.QtWidgets import QDialogButtonBox @@ -54,9 +55,17 @@ class SpectraSelection(object): class SpectraSelectionDialog(SpectraSelectionDialogUIBase): + @staticmethod + def raise_error_if_workspaces_not_compatible(workspaces): + for ws in workspaces: + if not isinstance(ws, MatrixWorkspace): + raise ValueError("Expected MatrixWorkspace, found {}.".format(ws.__class__.__name__)) + def __init__(self, workspaces, parent=None): super(SpectraSelectionDialog, self).__init__(parent) + self.raise_error_if_workspaces_not_compatible(workspaces) + # attributes self._workspaces = workspaces self.spec_min, self.spec_max = None, None @@ -78,6 +87,7 @@ class SpectraSelectionDialog(SpectraSelectionDialogUIBase): self.accept() # ------------------- Private ------------------------- + def _init_ui(self): ui = SpectraSelectionDialogUI() ui.setupUi(self) @@ -161,17 +171,19 @@ class SpectraSelectionDialog(SpectraSelectionDialogUIBase): return self.selection is not None -def get_spectra_selection(workspaces, parent_widget): +def get_spectra_selection(workspaces, parent_widget=None): """Decides whether it is necessary to request user input when asked to plot a list of workspaces. The input dialog will only be shown in the case where all workspaces have more than 1 spectrum :param workspaces: A list of MatrixWorkspaces that will be plotted - :param parent_widget: A parent_widget to use for the input selection dialog + :param parent_widget: An optional parent_widget to use for the input selection dialog :returns: Either a SpectraSelection object containing the details of workspaces to plot or None indicating the request was cancelled + :raises ValueError: if the workspaces are not of type MatrixWorkspace """ + SpectraSelectionDialog.raise_error_if_workspaces_not_compatible(workspaces) single_spectra_ws = [wksp.getNumberHistograms() for wksp in workspaces if wksp.getNumberHistograms() == 1] if len(single_spectra_ws) > 0: # At least 1 workspace contains only a single spectrum so this is all diff --git a/qt/python/mantidqt/dialogs/test/test_spectraselectiondialog.py b/qt/python/mantidqt/dialogs/test/test_spectraselectiondialog.py index fe899f3841ba35e7edbc38524b1b3a592c91734d..578ae5f66bd8fc7d60fa21468fff0f07348159c4 100644 --- a/qt/python/mantidqt/dialogs/test/test_spectraselectiondialog.py +++ b/qt/python/mantidqt/dialogs/test/test_spectraselectiondialog.py @@ -16,51 +16,63 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. # std imports -import time import unittest +try: + from unittest import mock +except ImportError: + import mock # 3rdparty imports -from mantid.simpleapi import CreateSampleWorkspace, CropWorkspace -from qtpy.QtWidgets import QDialogButtonBox +from mantid.api import WorkspaceFactory +from qtpy.QtWidgets import QDialog, QDialogButtonBox # local imports from mantidqt.utils.qt.test import requires_qapp -from mantidqt.dialogs.spectraselectordialog import parse_selection_str, SpectraSelectionDialog +from mantidqt.dialogs.spectraselectordialog import (get_spectra_selection, parse_selection_str, + SpectraSelectionDialog) @requires_qapp class SpectraSelectionDialogTest(unittest.TestCase): + _single_spec_ws = None + _multi_spec_ws = None + + def setUp(self): + if self._single_spec_ws is None: + self.__class__._single_spec_ws = WorkspaceFactory.Instance().create("Workspace2D", NVectors=1, + XLength=1, YLength=1) + self.__class__._multi_spec_ws = WorkspaceFactory.Instance().create("Workspace2D", NVectors=200, + XLength=1, YLength=1) + def test_initial_dialog_setup(self): - workspaces = [CreateSampleWorkspace(OutputWorkspace='ws', StoreInADS=False)] + workspaces = [self._multi_spec_ws] dlg = SpectraSelectionDialog(workspaces) self.assertFalse(dlg._ui.buttonBox.button(QDialogButtonBox.Ok).isEnabled()) def test_filling_workspace_details_single_workspace(self): - workspaces = [CreateSampleWorkspace(OutputWorkspace='ws', StoreInADS=False)] + workspaces = [self._multi_spec_ws] dlg = SpectraSelectionDialog(workspaces) self.assertEqual("valid range: 1-200", dlg._ui.specNums.placeholderText()) self.assertEqual("valid range: 0-199", dlg._ui.wkspIndices.placeholderText()) def test_filling_workspace_details_multiple_workspaces_of_same_size(self): - workspaces = [CreateSampleWorkspace(OutputWorkspace='ws', StoreInADS=False), - CreateSampleWorkspace(OutputWorkspace='ws2', StoreInADS=False)] + workspaces = [self._multi_spec_ws, + self._multi_spec_ws] dlg = SpectraSelectionDialog(workspaces) self.assertEqual("valid range: 1-200", dlg._ui.specNums.placeholderText()) self.assertEqual("valid range: 0-199", dlg._ui.wkspIndices.placeholderText()) def test_filling_workspace_details_multiple_workspaces_of_different_sizes(self): - ws1 = CreateSampleWorkspace(OutputWorkspace='ws', NumBanks=1, StoreInADS=False) - ws1 = CropWorkspace(ws1, StartWorkspaceIndex=50) - ws2 = CreateSampleWorkspace(OutputWorkspace='ws', StoreInADS=False) - - dlg = SpectraSelectionDialog([ws1, ws2]) + cropped_ws = WorkspaceFactory.Instance().create("Workspace2D", NVectors=50, XLength=1, YLength=1) + for i in range(cropped_ws.getNumberHistograms()): + cropped_ws.getSpectrum(i).setSpectrumNo(51 + i) + dlg = SpectraSelectionDialog([cropped_ws, self._multi_spec_ws]) self.assertEqual("valid range: 51-100", dlg._ui.specNums.placeholderText()) self.assertEqual("valid range: 0-49", dlg._ui.wkspIndices.placeholderText()) def test_valid_text_in_boxes_activates_ok(self): - workspaces = [CreateSampleWorkspace(OutputWorkspace='ws', StoreInADS=False)] - dlg = SpectraSelectionDialog(workspaces) + dlg = SpectraSelectionDialog([self._multi_spec_ws]) def do_test(input_box): input_box.setText("1") @@ -72,13 +84,51 @@ class SpectraSelectionDialogTest(unittest.TestCase): do_test(dlg._ui.specNums) def test_plot_all_gives_only_workspaces_indices(self): - ws = CreateSampleWorkspace(OutputWorkspace='ws', StoreInADS=False) - dlg = SpectraSelectionDialog([ws]) + dlg = SpectraSelectionDialog([self._multi_spec_ws]) dlg._ui.buttonBox.button(QDialogButtonBox.YesToAll).click() self.assertTrue(dlg.selection is not None) self.assertTrue(dlg.selection.spectra is None) self.assertEqual(range(200), dlg.selection.wksp_indices) + def test_entered_workspace_indices_gives_correct_selection_back(self): + dlg = SpectraSelectionDialog([self._multi_spec_ws]) + dlg._ui.wkspIndices.setText("1-3,5") + dlg._ui.buttonBox.button(QDialogButtonBox.Ok).click() + + self.assertTrue(dlg.selection is not None) + self.assertTrue(dlg.selection.spectra is None) + self.assertEqual([1, 2, 3, 5], dlg.selection.wksp_indices) + + def test_entered_spectra_gives_correct_selection_back(self): + dlg = SpectraSelectionDialog([self._multi_spec_ws]) + dlg._ui.wkspIndices.setText("50-60") + dlg._ui.buttonBox.button(QDialogButtonBox.Ok).click() + + self.assertTrue(dlg.selection is not None) + self.assertTrue(dlg.selection.spectra is None) + self.assertEqual(list(range(50, 61)), dlg.selection.wksp_indices) + + @mock.patch('mantidqt.dialogs.spectraselectordialog.SpectraSelectionDialog', autospec=True) + def test_get_spectra_selection_cancelled_returns_None(self, dialog_mock): + # a new instance of the mock created inside get_spectra_selection will return + # dialog_mock + dialog_mock.return_value = dialog_mock + dialog_mock.Rejected = QDialog.Rejected + dialog_mock.exec_.return_value = dialog_mock.Rejected + + selection = get_spectra_selection([self._multi_spec_ws]) + + dialog_mock.exec_.assert_called_once_with() + self.assertTrue(selection is None) + + @mock.patch('mantidqt.dialogs.spectraselectordialog.SpectraSelectionDialog') + def test_get_spectra_selection_does_not_use_dialog_for_single_spectrum(self, dialog_mock): + selection = get_spectra_selection([self._single_spec_ws]) + + dialog_mock.assert_not_called() + self.assertEqual([0], selection.wksp_indices) + self.assertEqual([self._single_spec_ws], selection.workspaces) + def test_parse_selection_str_single_number(self): s = '1' self.assertEqual([1], parse_selection_str(s, 1, 3)) @@ -108,8 +158,15 @@ class SpectraSelectionDialogTest(unittest.TestCase): self.assertEqual([1, 2, 3, 5, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], parse_selection_str(s, 1, 20)) + # --------------- failure tests ----------- + def test_construction_with_non_MatrixWorkspace_type_raises_exception(self): + table = WorkspaceFactory.Instance().createTable() + self.assertRaises(ValueError, SpectraSelectionDialog, [self._single_spec_ws, table]) + + def test_get_spectra_selection_raises_error_with_wrong_workspace_type(self): + table = WorkspaceFactory.Instance().createTable() + self.assertRaises(ValueError, get_spectra_selection, [self._single_spec_ws, table]) + if __name__ == '__main__': unittest.main() - # investigate why this is needed to avoid a segfault on Linux - time.sleep(0.5) diff --git a/qt/python/mantidqt/utils/qt/test/__init__.py b/qt/python/mantidqt/utils/qt/test/__init__.py index 19a8885cfe81a5134675f81103b61d5810b03e8e..9380bdb1feb93d780de7b546281198e887b8fafe 100644 --- a/qt/python/mantidqt/utils/qt/test/__init__.py +++ b/qt/python/mantidqt/utils/qt/test/__init__.py @@ -63,7 +63,7 @@ def requires_qapp(cls): qapp = QApplication.instance() if qapp is None: setup_library_paths() - self._qapp = QApplication(['']) + cls._qapp = QApplication([cls.__name__]) else: self._qapp = qapp orig_setUp(self) diff --git a/qt/python/mantidqt/widgets/codeeditor/test/test_execution.py b/qt/python/mantidqt/widgets/codeeditor/test/test_execution.py index e10529105be6e44c524b5b603224007499ea462c..9c202e8f99e65d258fec3ce06dd71e0074acd025 100644 --- a/qt/python/mantidqt/widgets/codeeditor/test/test_execution.py +++ b/qt/python/mantidqt/widgets/codeeditor/test/test_execution.py @@ -75,9 +75,9 @@ class PythonCodeExecutionTest(unittest.TestCase): def test_execute_places_output_in_globals(self): code = "_local=100" user_globals = self._verify_serial_execution_successful(code) - self.assertEquals(100, user_globals['_local']) + self.assertEqual(100, user_globals['_local']) user_globals = self._verify_async_execution_successful(code) - self.assertEquals(100, user_globals['_local']) + self.assertEqual(100, user_globals['_local']) def test_execute_async_calls_success_signal_on_completion(self): code = "x=1+2" diff --git a/qt/widgets/common/inc/MantidQtWidgets/Common/ScriptEditor.h b/qt/widgets/common/inc/MantidQtWidgets/Common/ScriptEditor.h index 4c4287ccf83a04c74b0a3d6d7f04a540328d8136..6228f18b215739ed50f8e12c7bb7511d190e7931 100644 --- a/qt/widgets/common/inc/MantidQtWidgets/Common/ScriptEditor.h +++ b/qt/widgets/common/inc/MantidQtWidgets/Common/ScriptEditor.h @@ -17,6 +17,7 @@ class FindReplaceDialog; class QAction; class QMenu; +class QMimeData; class QKeyEvent; class QMouseEvent; class QsciAPIs; @@ -158,6 +159,8 @@ protected: void dropEvent(QDropEvent *de) override; void dragMoveEvent(QDragMoveEvent *de) override; void dragEnterEvent(QDragEnterEvent *de) override; + QByteArray fromMimeData(const QMimeData *source, + bool &rectangular) const override; private slots: diff --git a/qt/widgets/common/src/MantidTreeWidget.cpp b/qt/widgets/common/src/MantidTreeWidget.cpp index 0b18d9debfcdf2eaffad098a3bfeee8aa6d4d5b8..f46559d17962d8cfc266e70cc61a00966c68b3f5 100644 --- a/qt/widgets/common/src/MantidTreeWidget.cpp +++ b/qt/widgets/common/src/MantidTreeWidget.cpp @@ -105,30 +105,18 @@ void MantidTreeWidget::mouseMoveEvent(QMouseEvent *e) { QApplication::startDragDistance()) return; - // Start dragging - QDrag *drag = new QDrag(this); - QMimeData *mimeData = new QMimeData; - - QStringList wsnames = getSelectedWorkspaceNames(); - if (wsnames.size() == 0) + QStringList wsNames = getSelectedWorkspaceNames(); + if (wsNames.isEmpty()) return; - QString importStatement = ""; - foreach (const QString wsname, wsnames) { - QString prefix = ""; - if (wsname[0].isDigit()) - prefix = "ws"; - if (importStatement.size() > 0) - importStatement += "\n"; - importStatement += prefix + wsname + " = mtd[\"" + wsname + "\"]"; - } - - mimeData->setText(importStatement); - mimeData->setObjectName("MantidWorkspace"); + // Start dragging - Qt docs say not to delete the QDrag object + // manually + QDrag *drag = new QDrag(this); + QMimeData *mimeData = new QMimeData; drag->setMimeData(mimeData); - - Qt::DropAction dropAction = drag->exec(Qt::CopyAction | Qt::MoveAction); - (void)dropAction; + mimeData->setObjectName("MantidWorkspace"); + mimeData->setText(wsNames.join("\n")); + drag->exec(Qt::CopyAction | Qt::MoveAction); } void MantidTreeWidget::mouseDoubleClickEvent(QMouseEvent *e) { diff --git a/qt/widgets/common/src/ScriptEditor.cpp b/qt/widgets/common/src/ScriptEditor.cpp index 4f27612803ffa5740c9766d5d9d855e5c8104a4e..ccac7fda8b66c5650c0884d8b20671e74a327f80 100644 --- a/qt/widgets/common/src/ScriptEditor.cpp +++ b/qt/widgets/common/src/ScriptEditor.cpp @@ -413,9 +413,22 @@ void ScriptEditor::dragMoveEvent(QDragMoveEvent *de) { * @param de :: The drag enter event */ void ScriptEditor::dragEnterEvent(QDragEnterEvent *de) { - if (!de->mimeData()->hasUrls()) - // pass to base class - This handles text appropriately + if (!de->mimeData()->hasUrls()) { QsciScintilla::dragEnterEvent(de); + } +} + +/** + * If the QMimeData object holds workspaces names then extract text from a + * QMimeData object and add the necessary wrapping text to import mantid. + * @param source An existing QMimeData object + * @param rectangular On return rectangular is set if the text corresponds to a + * rectangular selection. + * @return The text + */ +QByteArray ScriptEditor::fromMimeData(const QMimeData *source, + bool &rectangular) const { + return QsciScintilla::fromMimeData(source, rectangular); } /** @@ -423,11 +436,10 @@ void ScriptEditor::dragEnterEvent(QDragEnterEvent *de) { * @param de :: The drag drop event */ void ScriptEditor::dropEvent(QDropEvent *de) { - QStringList filenames; - const QMimeData *mimeData = de->mimeData(); - if (!mimeData->hasUrls()) { + if (!de->mimeData()->hasUrls()) { + QDropEvent localDrop(*de); // pass to base class - This handles text appropriately - QsciScintilla::dropEvent(de); + QsciScintilla::dropEvent(&localDrop); } } diff --git a/qt/widgets/common/src/WorkspacePresenter/WorkspaceTreeWidgetSimple.cpp b/qt/widgets/common/src/WorkspacePresenter/WorkspaceTreeWidgetSimple.cpp index f0f02481484e227a7277b24d879bbbbac2bbecc1..f73e0022f362e469bef361265c62cb5fa5a4309a 100644 --- a/qt/widgets/common/src/WorkspacePresenter/WorkspaceTreeWidgetSimple.cpp +++ b/qt/widgets/common/src/WorkspacePresenter/WorkspaceTreeWidgetSimple.cpp @@ -58,17 +58,26 @@ void WorkspaceTreeWidgetSimple::popupContextMenu() { menu = new QMenu(this); menu->setObjectName("WorkspaceContextMenu"); - // plot submenu first - QMenu *plotSubMenu(new QMenu("Plot", menu)); - plotSubMenu->addAction(m_plotSpectrum); - plotSubMenu->addAction(m_overplotSpectrum); - plotSubMenu->addAction(m_plotSpectrumWithErrs); - plotSubMenu->addAction(m_overplotSpectrumWithErrs); - plotSubMenu->addSeparator(); - plotSubMenu->addAction(m_plotColorfill); - menu->addMenu(plotSubMenu); - - menu->addSeparator(); + // plot submenu first for MatrixWorkspace. + // Check is defensive just in case the workspace has disappeared + Workspace_sptr workspace; + try { + workspace = AnalysisDataService::Instance().retrieve( + selectedWsName.toStdString()); + } catch (Exception::NotFoundError &) { + return; + } + if (boost::dynamic_pointer_cast<MatrixWorkspace>(workspace)) { + QMenu *plotSubMenu(new QMenu("Plot", menu)); + plotSubMenu->addAction(m_plotSpectrum); + plotSubMenu->addAction(m_overplotSpectrum); + plotSubMenu->addAction(m_plotSpectrumWithErrs); + plotSubMenu->addAction(m_overplotSpectrumWithErrs); + plotSubMenu->addSeparator(); + plotSubMenu->addAction(m_plotColorfill); + menu->addMenu(plotSubMenu); + menu->addSeparator(); + } menu->addAction(m_rename); menu->addAction(m_saveNexus);