Skip to content
Snippets Groups Projects
Commit 98665aa3 authored by Matthew Andrew's avatar Matthew Andrew
Browse files

Added optional options to maxent Re #25546

parent 8ce0e290
No related merge requests found
......@@ -8,13 +8,15 @@ class PhaseTableContext(object):
self.phase_quad = []
def add_phase_table(self, name):
self.phase_tables.append(name)
if name not in self.phase_tables:
self.phase_tables.append(name)
def get_phase_table_list(self, instrument):
return [phase_table for phase_table in self.phase_tables if instrument in phase_table]
def add_phase_quad(self, name):
self.phase_quad.append(name)
if name not in self.phase_quad:
self.phase_quad.append(name)
def get_phase_quad(self, instrument, run):
return [phase_quad for phase_quad in self.phase_quad if instrument in phase_quad and run in phase_quad]
......
......@@ -123,6 +123,10 @@ def run_MuonMaxent(parameters_dict, alg):
alg.setAlwaysStoreInADS(False)
alg.setRethrows(True)
alg.setProperty("OutputWorkspace", "__NotUsed")
alg.setProperty("OutputPhaseTable", "__NotUsedPhase")
alg.setProperty("OutputDeadTimeTable", "__NotUsedDead")
alg.setProperty("ReconstructedSpectra", "__NotUsedRecon")
alg.setProperty("PhaseConvergenceTable", "__NotUsedConverge")
alg.setProperties(parameters_dict)
alg.execute()
return alg.getProperty("OutputWorkspace").value
......
......@@ -16,6 +16,7 @@ from Muon.GUI.Common.observer_pattern import Observer
from mantid.api import AnalysisDataService
from Muon.GUI.Common.thread_model_wrapper import ThreadModelWrapper
import functools
from mantid.api import WorkspaceGroup
raw_data = "_raw_data"
......@@ -27,6 +28,9 @@ class GenericObserver(Observer):
def update(self, observable, arg):
self.callback()
optional_output_suffixes = {'OutputPhaseTable': '_phase_table', 'OutputDeadTimeTable': '_dead_times',
'ReconstructedSpectra': '_reconstructed_spectra', 'PhaseConvergenceTable': '_phase_convergence'}
class MaxEntPresenter(object):
......@@ -108,7 +112,13 @@ class MaxEntPresenter(object):
maxent_workspace = run_MuonMaxent(maxent_parameters, alg)
self.add_maxent_workspace_to_ADS(maxent_parameters, maxent_workspace)
base_name, group = self.calculate_base_name_and_group(maxent_parameters['InputWorkspace'])
output_name = self.add_maxent_workspace_to_ADS(base_name, group, maxent_workspace)
maxent_output_options = self.get_maxent_output_options()
self.add_optional_outputs_to_ADS(alg, maxent_output_options, base_name, group)
def get_parameters_for_maxent_calculation(self):
inputs = {}
......@@ -149,13 +159,35 @@ class MaxEntPresenter(object):
self.view.update_phase_table_combo(phase_table_list)
def add_maxent_workspace_to_ADS(self, parameters, maxent_workspace):
run = re.search('[0-9]+', parameters['InputWorkspace']).group()
name = self.load.data_context._base_run_name(run) + '_MaxEnt'
def add_maxent_workspace_to_ADS(self, base_name, group, maxent_workspace):
AnalysisDataService.addOrReplace(base_name, maxent_workspace)
AnalysisDataService.addToGroup(group, base_name)
def get_maxent_output_options(self):
output_options = {}
output_options['OutputPhaseTable'] = self.view.output_phase_table
output_options['OutputDeadTimeTable'] = self.view.output_dead_times
output_options['ReconstructedSpectra'] = self.view.output_reconstructed_spectra
output_options['PhaseConvergenceTable'] = self.view.output_phase_convergence
AnalysisDataService.addOrReplace(name, maxent_workspace)
AnalysisDataService.addToGroup(self.load.data_context._base_run_name(run), name)
return output_options
def add_optional_outputs_to_ADS(self, alg, output_options, base_name, group):
for key in output_options:
if output_options[key]:
output = alg.getProperty(key).value
AnalysisDataService.addOrReplace(base_name + optional_output_suffixes[key], output)
AnalysisDataService.addToGroup(group, base_name + optional_output_suffixes[key])
def calculate_base_name_and_group(self, input_workspace):
run = re.search('[0-9]+', input_workspace).group()
base_name = self.load.data_context._base_run_name(run) + '_MaxEnt'
group = self.load.data_context._base_run_name(run) + ' MaxEnt Outputs'
if not AnalysisDataService.doesExist(group):
new_group = WorkspaceGroup()
AnalysisDataService.addOrReplace(group, new_group)
AnalysisDataService.addToGroup(self.load.data_context._base_run_name(run), group)
return base_name, group
......@@ -238,5 +238,5 @@ class MaxEntView(QtWidgets.QWidget):
return self.output_phase_evo_box.checkState() == QtCore.Qt.Checked
@property
def output_reconsturcted_spectra(self):
def output_reconstructed_spectra(self):
return self.output_data_box.checkState() == QtCore.Qt.Checked
......@@ -18,6 +18,8 @@ from Muon.GUI.Common.utilities import load_utils
from Muon.GUI.Common.muon_pair import MuonPair
from mantid.api import FileFinder
from Muon.GUI.Common.contexts.context_setup import setup_context_for_tests
from qtpy import QtCore
def retrieve_combobox_info(combo_box):
output_list = []
......@@ -41,7 +43,7 @@ class MaxEntPresenterTest(unittest.TestCase):
self.view = maxent_view_new.MaxEntView(self.obj)
self.presenter = maxent_presenter_new.MaxEntPresenter(self.view, self.model, self.context)
self.presenter = maxent_presenter_new.MaxEntPresenter(self.view, self.context)
file_path = FileFinder.findRuns('MUSR00022725.nxs')[0]
ws, run, filename = load_utils.load_workspace_from_filename(file_path)
......@@ -61,28 +63,6 @@ class MaxEntPresenterTest(unittest.TestCase):
self.assertEquals(retrieve_combobox_info(self.view.ws), ['MUSR22725_raw_data'])
self.assertEquals(retrieve_combobox_info(self.view.N_points), ['2048', '4096', '8192', '16384', '32768', '65536',
'131072', '262144', '524288', '1048576'])
# def test_get_phase_table_inputs_returns_correctly(self):
# self.presenter.getWorkspaceNames()
#
# self.assertEquals(self.presenter.get_phase_table_inputs(), {'DataFitted': 'fits', 'DetectorTable': 'PhaseTable',
# 'FirstGoodData': 0.1, 'InputWorkspace': 'MUSR22725_raw_data',
# 'LastGoodData': 15.0})
# def test_get_input_run_returns_correctly(self):
# self.presenter.getWorkspaceNames()
#
# self.assertEquals(self.presenter.get_input_run(), 'MUSR22725')
#
# def test_get_max_ent_inputs_return_correctly(self):
# self.presenter.getWorkspaceNames()
#
# self.assertEquals(self.presenter.getMaxEntInput(), {'DefaultLevel': 0.1, 'DoublePulse': False, 'Factor': 1.04,
# 'FirstGoodTime': 0.1, 'FitDeadTime': True, 'InnerIterations': 10,
# 'InputWorkspace': 'MUSR22725_raw_data', 'LastGoodTime': 15.0,
# 'MaxField': 1000.0, 'Npts': 2048, 'OuterIterations': 10,
# 'OutputWorkspace': 'MUSR22725_raw_data;FrequencyDomain;MaxEnt'})
def test_get_parameters_for_maxent_calculations(self):
self.presenter.getWorkspaceNames()
self.context.dead_time_table = mock.MagicMock(return_value='deadtime_table_name')
......@@ -123,6 +103,35 @@ class MaxEntPresenterTest(unittest.TestCase):
data_service_mock.addOrReplace.assert_called_once_with('MUSR22725_MaxEnt', maxent_workspace)
data_service_mock.addToGroup.assert_called_once_with('MUSR22725', 'MUSR22725_MaxEnt')
def test_get_output_options_defaults_returns_correctly(self):
self.presenter.getWorkspaceNames()
output_options = self.presenter.get_maxent_output_options()
self.assertEquals(output_options, {'OutputDeadTimeTable': False, 'PhaseConvergenceTable': False,
'OutputPhaseTable': False, 'ReconstructedSpectra': False})
def test_get_output_options_returns_correctly(self):
self.presenter.getWorkspaceNames()
self.view.output_dead_box.setCheckState(QtCore.Qt.Checked)
self.view.output_phase_box.setCheckState(QtCore.Qt.Checked)
self.view.output_phase_evo_box.setCheckState(QtCore.Qt.Checked)
self.view.output_data_box.setCheckState(QtCore.Qt.Checked)
output_options = self.presenter.get_maxent_output_options()
self.assertEquals(output_options, {'OutputDeadTimeTable': True, 'PhaseConvergenceTable': True,
'OutputPhaseTable': True, 'ReconstructedSpectra': True})
@mock.patch('Muon.GUI.FrequencyDomainAnalysis.MaxEnt.maxent_presenter_new.AnalysisDataService')
def test_calculate_base_name_and_group_returns_correctly(self, data_service_mock):
base_name, group = self.presenter.calculate_base_name_and_group('MUSR33333_something')
self.assertEquals(base_name, 'MUSR33333_MaxEnt')
self.assertEquals(group, 'MUSR33333_MaxEnt_Outputs')
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment