Skip to content
Snippets Groups Projects
Commit 48d20882 authored by Samuel Jones's avatar Samuel Jones
Browse files

Re #24049 Further test incrementation

parent f4e69337
No related branches found
No related tags found
No related merge requests found
// Mantid Repository : https://github.com/mantidproject/mantid
//
// Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI,
// NScD Oak Ridge National Laboratory, European Spallation Source
// & Institut Laue - Langevin
// SPDX - License - Identifier: GPL - 3.0 +
#ifndef ANALYSISDATASERVICEOBSERVERTEST_H_
#define ANALYSISDATASERVICEOBSERVERTEST_H_
#include <cxxtest/TestSuite.h>
#import "MantidAPI/AnalysisDataServiceObserver.h"
class MockInheritingClass : public Mantid::API::AnalysisDataServiceObserver {
MockInheritingClass()
: m_anyChangeHandleCalled(false), m_addHandleCalled(false),
m_replaceHandleCalled(false), m_deleteHandleCalled(false),
m_clearHandleCalled(false), m_renameHandleCalled(false),
m_groupHandleCalled(false), m_unGroupHandleCalled(false),
m_groupUpdateHandleCalled(false) {
this->observeAll(false);
}
~MockInheritingClass() { this->observeAll(false); }
void anyChangeHandle() override { m_anyChangeHandleCalled = true; }
void addHandle(const std::string &wsName, const Workspace_sptr ws) override {
m_addHandleCalled = true;
}
void replaceHandle(const std::string &wsName,
const Workspace_sptr ws) override {
m_replaceHandleCalled = true;
}
void deleteHandle(const std::string &wsName,
const Workspace_sptr ws) override {
m_deleteHandleCalled = true;
}
void clearHandle() { m_clearHandleCalled = true; }
void renameHandle(const std::string &wsName, const std::string &newName) {
m_renameHandleCalled = true;
}
void groupHandle(const std::string &wsName, const Workspace_sptr ws) {
m_groupHandleCalled = true;
}
void unGroupHandle(const std::string &wsName, const Workspace_sptr ws) {
m_unGroupHandleCalled = true;
}
void groupUpdateHandle(const std::string &wsName, const Workspace_sptr ws) {
m_groupUpdateHandleCalled = true;
}
bool m_anyChangeHandleCalled, m_addHandleCalled, m_replaceHandleCalled,
m_deleteHandleCalled, m_clearHandleCalled, m_renameHandleCalled,
m_groupHandleCalled, m_unGroupHandleCalled, m_groupUpdateHandleCalled;
}
class AnalysisDataServiceObserverTest : public CxxTest::TestSuite {
private:
AnalysisDataServiceImpl &ads;
std::unique_ptr<MockInheritingClass> m_mockInheritingClass;
void setUp() {
ads.clear();
m_mockInheritingClass = std::make_unique<MockInheritingClass>()
}
void addWorkspaceToADS(std::string name = "dummy") {
CreateSampleWorkspace alg;
alg.setChild(true);
alg.initialize();
alg.setPropertyValue("OutputWorkspace", name);
alg.execute();
}
void test_anyChangeHandle_is_called_on_add() {
m_mockInheritingClass->observeAll();
addWorkspaceToADS();
TS_ASSERT(m_mockInheritingClass.m_anyChangeHandleCalled)
}
void test_addHandle_is_called_on_add() {
m_mockInheritingClass->observeAdd();
addWorkspaceToADS();
TS_ASSERT(m_mockInheritingClass.m_addHandleCalled)
}
void test_deleteHandle_is_called_on_delete() {
addWorkspaceToADS();
m_mockInheritingClass->observeDelete();
ads.remove("dummy");
TS_ASSERT(m_mockInheritingClass.m_deleteHandleCalled)
}
void test_replaceHandle_is_called_on_replace() {
addWorkspaceToADS();
m_mockInheritingClass->observeReplace();
addWorkspaceToADS();
TS_ASSERT(m_mockInheritingClass.m_replaceHandleCalled)
}
void test_clearHandle_is_called_on_clear() {
addWorkspaceToADS();
m_mockInheritingClass->observeClear();
ads.clear();
TS_ASSERT(m_mockInheritingClass.m_clearHandleCalled)
}
void test_renameHandle_is_called_on_rename() {
addWorkspaceToADS();
m_mockInheritingClass->observeRename();
Mantid::Algorithms::RenameWorkspace alg;
alg.initialize();
alg.setPropertyValue("InputWorkspace", "dummy");
alg.setPropertyValue("OutputWorkspace", "dummy2");
alg.execute();
TS_ASSERT(m_mockInheritingClass.m_renameHandleCalled)
}
void test_groupHandle_is_called_on_group_made() {
addWorkspaceToADS();
addWorkspaceToADS("dummy2");
m_mockInheritingClass->observeGroup();
Mantid::Algorithms::GroupWorkspaces alg;
alg.initialize();
alg.setPropertyValue("InputWorkspaces", "dummy,dummy2");
alg.setPropertyValue("OutputWorkspace", "newGroup");
alg.execute();
TS_ASSERT(m_mockInheritingClass.m_groupHandleCalled)
}
void test_unGroupHandle_is_called_on_un_grouping() {
addWorkspaceToADS();
addWorkspaceToADS("dummy2");
Mantid::Algorithms::GroupWorkspaces alg;
alg.initialize();
alg.setPropertyValue("InputWorkspaces", "dummy,dummy2");
alg.setPropertyValue("OutputWorkspace", "newGroup");
alg.execute();
m_mockInheritingClass->observeUnGroup();
Mantid::Algorithms::UnGroupWorkspaces alg2;
alg2.initialize();
alg2.setPropertyValue("InputWorkspace", "newGroup");
alg.exectute();
TS_ASSERT(m_mockInheritingClass.m_unGroupHandleCalled)
}
void test_groupUpdated_is_called_on_group_updated() {
addWorkspaceToADS();
addWorkspaceToADS("dummy2");
addWorkspaceToADS("dummy3");
Mantid::Algorithms::GroupWorkspaces alg;
alg.initialize();
alg.setPropertyValue("InputWorkspaces", "dummy,dummy2");
alg.setPropertyValue("OutputWorkspace", "newGroup");
alg.execute();
m_mockInheritingClass->observeGroup();
ads.addToGroup("newGroup", "dummy3");
TS_ASSERT(m_mockInheritingClass.m_groupUpdateHandleCalled)
}
};
#endif /* ANALYSISDATASERVICEOBSERVERTEST_H_ */
\ No newline at end of file
...@@ -6,15 +6,17 @@ ...@@ -6,15 +6,17 @@
# SPDX - License - Identifier: GPL - 3.0 + # SPDX - License - Identifier: GPL - 3.0 +
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
import six
import unittest import unittest
from testhelpers import run_algorithm from testhelpers import run_algorithm
from mantid.api import AnalysisDataService, AnalysisDataServiceImpl, MatrixWorkspace, Workspace from mantid.api import AnalysisDataService, AnalysisDataServiceImpl, MatrixWorkspace, Workspace
from mantid import mtd from mantid import mtd
class AnalysisDataServiceTest(unittest.TestCase): class AnalysisDataServiceTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
AnalysisDataService.Instance().clear() AnalysisDataService.Instance().clear()
def test_len_returns_correct_value(self): def test_len_returns_correct_value(self):
self.assertEquals(len(AnalysisDataService), 0) self.assertEquals(len(AnalysisDataService), 0)
...@@ -160,5 +162,34 @@ class AnalysisDataServiceTest(unittest.TestCase): ...@@ -160,5 +162,34 @@ class AnalysisDataServiceTest(unittest.TestCase):
for name in extra_names: for name in extra_names:
mtd.remove(name) mtd.remove(name)
def test_addToGroup_adds_workspace_to_group(self):
from mantid.simpleapi import CreateSampleWorkspace, GroupWorkspaces
CreateSampleWorkspace(OutputWorkspace="ws1")
CreateSampleWorkspace(OutputWorkspace="ws2")
GroupWorkspaces(InputWorkspaces="ws1,ws2", OutputWorkspace="NewGroup")
CreateSampleWorkspace(OutputWorkspace="ws3")
AnalysisDataService.addToGroup("NewGroup", "ws3")
group = mtd['NewGroup']
self.assertEquals(group.size(), 3)
six.assertCountEqual(self, group.getNames(), ["ws1", "ws2", "ws3"])
def test_removeFromGroup_removes_workspace_from_group(self):
from mantid.simpleapi import CreateSampleWorkspace, GroupWorkspaces
CreateSampleWorkspace(OutputWorkspace="ws1")
CreateSampleWorkspace(OutputWorkspace="ws2")
CreateSampleWorkspace(OutputWorkspace="ws3")
GroupWorkspaces(InputWorkspaces="ws1,ws2,ws3", OutputWorkspace="NewGroup")
AnalysisDataService.removeFromGroup("NewGroup", "ws3")
group = mtd['NewGroup']
self.assertEquals(group.size(), 2)
six.assertCountEqual(self, group.getNames(), ["ws1", "ws2"])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -21,8 +21,8 @@ from mantidqt.project.projectsaver import ProjectSaver ...@@ -21,8 +21,8 @@ from mantidqt.project.projectsaver import ProjectSaver
class Project(AnalysisDataServiceObserver): class Project(AnalysisDataServiceObserver):
def __init__(self): def __init__(self):
super(Project, self).__init__() super(Project, self).__init__()
# Has the project been saved # Has the project been saved, to Access this call .saved
self.saved = True self.__saved = True
# Last save locations # Last save locations
self.last_project_location = None self.last_project_location = None
...@@ -42,7 +42,7 @@ class Project(AnalysisDataServiceObserver): ...@@ -42,7 +42,7 @@ class Project(AnalysisDataServiceObserver):
project_saver = ProjectSaver(self.project_file_ext) project_saver = ProjectSaver(self.project_file_ext)
project_saver.save_project(directory=self.last_project_location, workspace_to_save=workspaces_to_save, project_saver.save_project(directory=self.last_project_location, workspace_to_save=workspaces_to_save,
interfaces_to_save=None) interfaces_to_save=None)
self.saved = True self.__saved = True
def save_as(self): def save_as(self):
directory = self._get_directory_finder(accept_mode=QFileDialog.AcceptSave) directory = self._get_directory_finder(accept_mode=QFileDialog.AcceptSave)
...@@ -56,7 +56,7 @@ class Project(AnalysisDataServiceObserver): ...@@ -56,7 +56,7 @@ class Project(AnalysisDataServiceObserver):
workspaces_to_save = AnalysisDataService.getObjectNames() workspaces_to_save = AnalysisDataService.getObjectNames()
project_saver = ProjectSaver(self.project_file_ext) project_saver = ProjectSaver(self.project_file_ext)
project_saver.save_project(directory=directory, workspace_to_save=workspaces_to_save, interfaces_to_save=None) project_saver.save_project(directory=directory, workspace_to_save=workspaces_to_save, interfaces_to_save=None)
self.saved = True self.__saved = True
@staticmethod @staticmethod
def _get_directory_finder(accept_mode): def _get_directory_finder(accept_mode):
...@@ -90,7 +90,7 @@ class Project(AnalysisDataServiceObserver): ...@@ -90,7 +90,7 @@ class Project(AnalysisDataServiceObserver):
:return: Bool; Returns false if no save needed/save complete. Returns True if need to cancel closing. :return: Bool; Returns false if no save needed/save complete. Returns True if need to cancel closing.
""" """
# If the current project is saved then return and don't do anything # If the current project is saved then return and don't do anything
if self.saved: if self.__saved:
return return
result = self._offer_save_message_box(parent) result = self._offer_save_message_box(parent)
...@@ -109,13 +109,16 @@ class Project(AnalysisDataServiceObserver): ...@@ -109,13 +109,16 @@ class Project(AnalysisDataServiceObserver):
QMessageBox.Yes) QMessageBox.Yes)
def modified_project(self): def modified_project(self):
if not self.saved: self.__saved = False
return
self.saved = False
def anyChangeHandle(self): def anyChangeHandle(self):
self.modified_project() self.modified_project()
def __get_saved(self):
return self.__saved
saved = property(__get_saved)
@staticmethod @staticmethod
def _clear_unused_workspaces(path): def _clear_unused_workspaces(path):
files_to_remove = [] files_to_remove = []
......
...@@ -12,7 +12,6 @@ import unittest ...@@ -12,7 +12,6 @@ import unittest
import sys import sys
import tempfile import tempfile
import os import os
from time import sleep
from qtpy.QtWidgets import QMessageBox from qtpy.QtWidgets import QMessageBox
...@@ -27,28 +26,28 @@ else: ...@@ -27,28 +26,28 @@ else:
class ProjectTest(unittest.TestCase): class ProjectTest(unittest.TestCase):
def setUp(self):
self.project = Project()
def tearDown(self): def tearDown(self):
ADS.clear() ADS.clear()
def test_save_calls_save_as_when_last_location_is_not_none(self): def test_save_calls_save_as_when_last_location_is_not_none(self):
self.project.save_as = mock.MagicMock() project = Project()
self.project.save() project.save_as = mock.MagicMock()
self.assertEqual(self.project.save_as.call_count, 1) project.save()
self.assertEqual(project.save_as.call_count, 1)
def test_save_does_not_call_save_as_when_last_location_is_not_none(self): def test_save_does_not_call_save_as_when_last_location_is_not_none(self):
self.project.save_as = mock.MagicMock() project = Project()
self.project.last_project_location = "1" project.save_as = mock.MagicMock()
self.assertEqual(self.project.save_as.call_count, 0) project.last_project_location = "1"
self.assertEqual(project.save_as.call_count, 0)
def test_save_saves_project_successfully(self): def test_save_saves_project_successfully(self):
project = Project()
working_directory = tempfile.mkdtemp() working_directory = tempfile.mkdtemp()
self.project.last_project_location = working_directory project.last_project_location = working_directory
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
self.project.save() project.save()
self.assertTrue(os.path.isdir(working_directory)) self.assertTrue(os.path.isdir(working_directory))
file_list = os.listdir(working_directory) file_list = os.listdir(working_directory)
...@@ -56,97 +55,87 @@ class ProjectTest(unittest.TestCase): ...@@ -56,97 +55,87 @@ class ProjectTest(unittest.TestCase):
self.assertTrue("ws1.nxs" in file_list) self.assertTrue("ws1.nxs" in file_list)
def test_save_as_saves_project_successfully(self): def test_save_as_saves_project_successfully(self):
project = Project()
working_directory = tempfile.mkdtemp() working_directory = tempfile.mkdtemp()
self.project._get_directory_finder = mock.MagicMock(return_value=working_directory) project._get_directory_finder = mock.MagicMock(return_value=working_directory)
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
self.project.save_as() project.save_as()
self.assertEqual(self.project._get_directory_finder.call_count, 1) self.assertEqual(project._get_directory_finder.call_count, 1)
self.assertTrue(os.path.isdir(working_directory)) self.assertTrue(os.path.isdir(working_directory))
file_list = os.listdir(working_directory) file_list = os.listdir(working_directory)
self.assertTrue(os.path.basename(working_directory) + ".mtdproj" in file_list) self.assertTrue(os.path.basename(working_directory) + ".mtdproj" in file_list)
self.assertTrue("ws1.nxs" in file_list) self.assertTrue("ws1.nxs" in file_list)
def test_load_calls_loads_successfully(self): def test_load_calls_loads_successfully(self):
project = Project()
working_directory = tempfile.mkdtemp() working_directory = tempfile.mkdtemp()
self.project._get_directory_finder = mock.MagicMock(return_value=working_directory) project._get_directory_finder = mock.MagicMock(return_value=working_directory)
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
self.project.save_as() project.save_as()
ADS.clear() ADS.clear()
self.project.load() project.load()
self.assertEqual(self.project._get_directory_finder.call_count, 2) self.assertEqual(project._get_directory_finder.call_count, 2)
self.assertEqual(["ws1"], ADS.getObjectNames()) self.assertEqual(["ws1"], ADS.getObjectNames())
def test_offer_save_does_nothing_if_saved_is_true(self): def test_offer_save_does_nothing_if_saved_is_true(self):
self.project.saved = True project = Project()
self.assertEqual(project.offer_save(None), None)
self.assertEqual(self.project.offer_save(None), None)
def test_offer_save_does_something_if_saved_is_false(self): def test_offer_save_does_something_if_saved_is_false(self):
self.project._offer_save_message_box = mock.MagicMock(return_value=QMessageBox.Yes) project = Project()
self.project.save = mock.MagicMock() project._offer_save_message_box = mock.MagicMock(return_value=QMessageBox.Yes)
self.project.saved = False project.save = mock.MagicMock()
# Add something to the ads so __saved is set to false
CreateSampleWorkspace(OutputWorkspace="ws1")
self.assertEqual(self.project.offer_save(None), False) self.assertEqual(project.offer_save(None), False)
self.assertEqual(self.project.save.call_count, 1) self.assertEqual(project.save.call_count, 1)
self.assertEqual(self.project._offer_save_message_box.call_count, 1) self.assertEqual(project._offer_save_message_box.call_count, 1)
def test_adding_to_ads_sets_saved_to_false(self): def test_adding_to_ads_sets_saved_to_false(self):
self.project.saved = True project = Project()
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
# It takes some time for the notification to move sometimes so wait for it self.assertTrue(not project.saved)
sleep(0.05)
self.assertTrue(not self.project.saved)
def test_removing_from_ads_sets_saved_to_false(self): def test_removing_from_ads_sets_saved_to_false(self):
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
self.project.saved = True project = Project()
DeleteWorkspace("ws1") ADS.remove("ws1")
# It takes some time for the notification to move sometimes so wait for it
sleep(0.05)
self.assertTrue(not self.project.saved) self.assertTrue(not project.saved)
def test_grouping_in_ads_sets_saved_to_false(self): def test_grouping_in_ads_sets_saved_to_false(self):
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
CreateSampleWorkspace(OutputWorkspace="ws2") CreateSampleWorkspace(OutputWorkspace="ws2")
self.project.saved = True project = Project()
GroupWorkspaces(InputWorkspaces="ws1,ws2", OutputWorkspace="NewGroup") GroupWorkspaces(InputWorkspaces="ws1,ws2", OutputWorkspace="NewGroup")
# It takes some time for the notification to move sometimes so wait for it self.assertTrue(not project.saved)
sleep(0.05)
self.assertTrue(not self.project.saved)
def test_renaming_in_ads_sets_saved_to_false(self): def test_renaming_in_ads_sets_saved_to_false(self):
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
self.project.saved = True
RenameWorkspace(InputWorkspace="ws1", OutputWorkspace="ws2")
# It takes some time for the notification to move sometimes so wait for it project = Project()
sleep(0.05) RenameWorkspace(InputWorkspace="ws1", OutputWorkspace="ws2")
self.assertTrue(not self.project.saved) self.assertTrue(not project.saved)
def test_ungrouping_in_ads_sets_saved_to_false(self): def test_ungrouping_in_ads_sets_saved_to_false(self):
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
CreateSampleWorkspace(OutputWorkspace="ws2") CreateSampleWorkspace(OutputWorkspace="ws2")
GroupWorkspaces(InputWorkspaces="ws1,ws2", OutputWorkspace="NewGroup") GroupWorkspaces(InputWorkspaces="ws1,ws2", OutputWorkspace="NewGroup")
self.project.saved = True project = Project()
UnGroupWorkspace(InputWorkspace="NewGroup") UnGroupWorkspace(InputWorkspace="NewGroup")
# It takes some time for the notification to move sometimes so wait for it self.assertTrue(not project.saved)
sleep(0.05)
self.assertTrue(not self.project.saved)
def test_group_updated_in_ads_sets_saved_to_false(self): def test_group_updated_in_ads_sets_saved_to_false(self):
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
...@@ -154,26 +143,28 @@ class ProjectTest(unittest.TestCase): ...@@ -154,26 +143,28 @@ class ProjectTest(unittest.TestCase):
GroupWorkspaces(InputWorkspaces="ws1,ws2", OutputWorkspace="NewGroup") GroupWorkspaces(InputWorkspaces="ws1,ws2", OutputWorkspace="NewGroup")
CreateSampleWorkspace(OutputWorkspace="ws3") CreateSampleWorkspace(OutputWorkspace="ws3")
self.project.saved = True project = Project()
ADS.addToGroup("NewGroup", "ws3") ADS.addToGroup("NewGroup", "ws3")
# It takes some time for the notification to move sometimes so wait for it self.assertTrue(not project.saved)
sleep(0.05)
self.assertTrue(not self.project.saved)
def test_removing_unused_workspaces_operates_as_expected_from_save(self): def test_removing_unused_workspaces_operates_as_expected_from_save(self):
project = Project()
working_directory = tempfile.mkdtemp() working_directory = tempfile.mkdtemp()
self.project.last_project_location = working_directory project.last_project_location = working_directory
CreateSampleWorkspace(OutputWorkspace="ws1") CreateSampleWorkspace(OutputWorkspace="ws1")
self.project.save() project.save()
ADS.clear() ADS.clear()
CreateSampleWorkspace(OutputWorkspace="ws2") CreateSampleWorkspace(OutputWorkspace="ws2")
self.project.save() project.save()
self.assertTrue(os.path.isdir(working_directory)) self.assertTrue(os.path.isdir(working_directory))
file_list = os.listdir(working_directory) file_list = os.listdir(working_directory)
self.assertTrue(os.path.basename(working_directory) + ".mtdproj" in file_list) self.assertTrue(os.path.basename(working_directory) + ".mtdproj" in file_list)
self.assertTrue("ws2.nxs" in file_list) self.assertTrue("ws2.nxs" in file_list)
self.assertTrue("ws1.nxs" not in file_list) self.assertTrue("ws1.nxs" not in file_list)
if __name__ == "__main__":
unittest.main()
...@@ -64,3 +64,7 @@ class ProjectReaderTest(unittest.TestCase): ...@@ -64,3 +64,7 @@ class ProjectReaderTest(unittest.TestCase):
project_reader.read_project(working_directory) project_reader.read_project(working_directory)
self.assertEqual(["ws1"], project_reader.workspace_names) self.assertEqual(["ws1"], project_reader.workspace_names)
self.assertEqual({}, project_reader.interfaces_dicts) self.assertEqual({}, project_reader.interfaces_dicts)
if __name__ == "__main__":
unittest.main()
...@@ -35,13 +35,17 @@ class ProjectSaverTest(unittest.TestCase): ...@@ -35,13 +35,17 @@ class ProjectSaverTest(unittest.TestCase):
ADS.addOrReplace(ws1_name, CreateSampleWorkspace(OutputWorkspace=ws1_name)) ADS.addOrReplace(ws1_name, CreateSampleWorkspace(OutputWorkspace=ws1_name))
project_saver = projectsaver.ProjectSaver(project_file_ext) project_saver = projectsaver.ProjectSaver(project_file_ext)
file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext
saved_file = "{\"interfaces\": {}, \"workspaces\": [\"ws1\"]}"
workspaces_string = "\"workspaces\": [\"ws1\"]"
interfaces_string = "\"interfaces\": {}"
project_saver.save_project(workspace_to_save=[ws1_name], directory=working_directory) project_saver.save_project(workspace_to_save=[ws1_name], directory=working_directory)
# Check project file is saved correctly # Check project file is saved correctly
f = open(file_name, "r") f = open(file_name, "r")
self.assertEqual(f.read(), saved_file) file_string = f.read()
self.assertTrue(workspaces_string in file_string)
self.assertTrue(interfaces_string in file_string)
# Check workspace is saved # Check workspace is saved
list_of_files = os.listdir(working_directory) list_of_files = os.listdir(working_directory)
...@@ -62,14 +66,17 @@ class ProjectSaverTest(unittest.TestCase): ...@@ -62,14 +66,17 @@ class ProjectSaverTest(unittest.TestCase):
CreateSampleWorkspace(OutputWorkspace=ws5_name) CreateSampleWorkspace(OutputWorkspace=ws5_name)
project_saver = projectsaver.ProjectSaver(project_file_ext) project_saver = projectsaver.ProjectSaver(project_file_ext)
file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext
saved_file = "{\"interfaces\": {}, \"workspaces\": [\"ws1\", \"ws2\", \"ws3\", \"ws4\", \"ws5\"]}"
workspaces_string = "\"workspaces\": [\"ws1\", \"ws2\", \"ws3\", \"ws4\", \"ws5\"]"
interfaces_string = "\"interfaces\": {}"
project_saver.save_project(workspace_to_save=[ws1_name, ws2_name, ws3_name, ws4_name, ws5_name], project_saver.save_project(workspace_to_save=[ws1_name, ws2_name, ws3_name, ws4_name, ws5_name],
directory=working_directory) directory=working_directory)
# Check project file is saved correctly # Check project file is saved correctly
f = open(file_name, "r") f = open(file_name, "r")
self.assertEqual(f.read(), saved_file) file_string = f.read()
self.assertTrue(workspaces_string in file_string)
self.assertTrue(interfaces_string in file_string)
# Check workspace is saved # Check workspace is saved
list_of_files = os.listdir(working_directory) list_of_files = os.listdir(working_directory)
...@@ -90,13 +97,15 @@ class ProjectSaverTest(unittest.TestCase): ...@@ -90,13 +97,15 @@ class ProjectSaverTest(unittest.TestCase):
CreateSampleWorkspace(OutputWorkspace=ws3_name) CreateSampleWorkspace(OutputWorkspace=ws3_name)
project_saver = projectsaver.ProjectSaver(project_file_ext) project_saver = projectsaver.ProjectSaver(project_file_ext)
file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext
saved_file = "{\"interfaces\": {}, \"workspaces\": [\"ws1\"]}" workspaces_string = "\"workspaces\": [\"ws1\"]"
interfaces_string = "\"interfaces\": {}"
project_saver.save_project(workspace_to_save=[ws1_name], directory=working_directory) project_saver.save_project(workspace_to_save=[ws1_name], directory=working_directory)
# Check project file is saved correctly # Check project file is saved correctly
f = open(file_name, "r") f = open(file_name, "r")
self.assertEqual(f.read(), saved_file) file_string = f.read()
self.assertTrue(workspaces_string in file_string)
self.assertTrue(interfaces_string in file_string)
# Check workspace is saved # Check workspace is saved
list_of_files = os.listdir(working_directory) list_of_files = os.listdir(working_directory)
...@@ -128,39 +137,54 @@ class ProjectWriterTest(unittest.TestCase): ...@@ -128,39 +137,54 @@ class ProjectWriterTest(unittest.TestCase):
if os.path.isdir(working_directory): if os.path.isdir(working_directory):
rmtree(working_directory) rmtree(working_directory)
def test_write_out_on_just_dicts(self): def test_write_out_on_just_interfaces(self):
workspace_list = [] workspace_list = []
small_dict = {"interface1": {"value1": 2, "value2": 3}, "interface2": {"value3": 4, "value4": 5}} small_dict = {"interface1": [2, 3]}
project_writer = projectsaver.ProjectWriter(small_dict, working_directory, workspace_list, project_file_ext) project_writer = projectsaver.ProjectWriter(small_dict, working_directory, workspace_list, project_file_ext)
file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext
saved_file = "{\"interfaces\": {\"interface1\": {\"value2\": 3, \"value1\": 2}, \"interface2\": {\"value4\"" \
": 5, \"value3\": 4}}, \"workspaces\": []}" workspaces_string = "\"workspaces\": []"
interfaces_string = "\"interfaces\": {\"interface1\": [2, 3]}"
project_writer.write_out() project_writer.write_out()
f = open(file_name, "r") f = open(file_name, "r")
self.assertEqual(f.read(), saved_file) file_string = f.read()
self.assertTrue(workspaces_string in file_string)
self.assertTrue(interfaces_string in file_string)
def test_write_out_on_just_workspaces(self): def test_write_out_on_just_workspaces(self):
workspace_list = ["ws1", "ws2", "ws3", "ws4"] workspace_list = ["ws1", "ws2", "ws3", "ws4"]
small_dict = {} small_dict = {}
project_writer = projectsaver.ProjectWriter(small_dict, working_directory, workspace_list, project_file_ext) project_writer = projectsaver.ProjectWriter(small_dict, working_directory, workspace_list, project_file_ext)
file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext
saved_file = "{\"interfaces\": {}, \"workspaces\": [\"ws1\", \"ws2\", \"ws3\", \"ws4\"]}"
workspaces_string = "\"workspaces\": [\"ws1\", \"ws2\", \"ws3\", \"ws4\"]"
interfaces_string = "\"interfaces\": {}"
project_writer.write_out() project_writer.write_out()
f = open(file_name, "r") f = open(file_name, "r")
self.assertEqual(f.read(), saved_file) file_string = f.read()
self.assertTrue(workspaces_string in file_string)
self.assertTrue(interfaces_string in file_string)
def test_write_out_on_both_workspaces_and_dicts(self): def test_write_out_on_both_workspaces_and_dicts(self):
workspace_list = ["ws1", "ws2", "ws3", "ws4"] workspace_list = ["ws1", "ws2", "ws3", "ws4"]
small_dict = {"interface1": {"value1": 2, "value2": 3}, "interface2": {"value3": 4, "value4": 5}} small_dict = {"interface1": [2, 3]}
project_writer = projectsaver.ProjectWriter(small_dict, working_directory, workspace_list, project_file_ext) project_writer = projectsaver.ProjectWriter(small_dict, working_directory, workspace_list, project_file_ext)
file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext file_name = working_directory + "/" + os.path.basename(working_directory) + project_file_ext
saved_file = "{\"interfaces\": {\"interface1\": {\"value2\": 3, \"value1\": 2}, \"interface2\": {\"value4\":" \
" 5, \"value3\": 4}}, \"workspaces\": [\"ws1\", \"ws2\", \"ws3\", \"ws4\"]}" workspaces_string = "\"workspaces\": [\"ws1\", \"ws2\", \"ws3\", \"ws4\"]"
interfaces_string = "\"interfaces\": {\"interface1\": [2, 3]}"
project_writer.write_out() project_writer.write_out()
f = open(file_name, "r") f = open(file_name, "r")
self.assertEqual(f.read(), saved_file) file_string = f.read()
self.assertTrue(workspaces_string in file_string)
self.assertTrue(interfaces_string in file_string)
if __name__ == "__main__":
unittest.main()
...@@ -36,3 +36,7 @@ class WorkspaceLoaderTest(unittest.TestCase): ...@@ -36,3 +36,7 @@ class WorkspaceLoaderTest(unittest.TestCase):
workspace_loader = workspaceloader.WorkspaceLoader() workspace_loader = workspaceloader.WorkspaceLoader()
workspace_loader.load_workspaces(self.working_directory, self.project_ext) workspace_loader.load_workspaces(self.working_directory, self.project_ext)
self.assertEqual(ADS.getObjectNames(), [self.ws1_name]) self.assertEqual(ADS.getObjectNames(), [self.ws1_name])
if __name__ == "__main__":
unittest.main()
...@@ -78,3 +78,7 @@ class WorkspaceSaverTest(unittest.TestCase): ...@@ -78,3 +78,7 @@ class WorkspaceSaverTest(unittest.TestCase):
ws = LoadMD(Filename=filename) ws = LoadMD(Filename=filename)
ws_is_a_mdworkspace = isinstance(ws, IMDEventWorkspace) or isinstance(ws, MDHistoWorkspace) ws_is_a_mdworkspace = isinstance(ws, IMDEventWorkspace) or isinstance(ws, MDHistoWorkspace)
self.assertEqual(ws_is_a_mdworkspace, True) self.assertEqual(ws_is_a_mdworkspace, True)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment