Skip to content
Snippets Groups Projects
Commit 70547248 authored by Conor Finn's avatar Conor Finn
Browse files

RE #26849 Add cropping to calibration model

parent 9b0991ca
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,7 @@ from mantid.api import AnalysisDataService as Ads
from mantid.kernel import logger
from mantid.simpleapi import EnggCalibrate, DeleteWorkspace, CloneWorkspace, \
CreateWorkspace, AppendSpectra, CreateEmptyTableWorkspace
from mantidqt.plotting.functions import plot
from Engineering.EnggUtils import write_ENGINX_GSAS_iparam_file
from Engineering.gui.engineering_diffraction.tabs.common import vanadium_corrections
from Engineering.gui.engineering_diffraction.tabs.common import path_handling
......@@ -35,7 +36,9 @@ class CalibrationModel(object):
sample_path,
plot_output,
instrument,
rb_num=None):
rb_num=None,
bank=None,
spectrum_numbers=None):
"""
Create a new calibration from a vanadium run and sample run
:param vanadium_path: Path to vanadium data file.
......@@ -43,6 +46,8 @@ class CalibrationModel(object):
:param plot_output: Whether the output should be plotted.
:param instrument: The instrument the data relates to.
:param rb_num: The RB number for file creation.
:param bank: Optional parameter to crop by bank
:param spectrum_numbers: Optional parameter to crop using spectrum numbers.
"""
van_integration, van_curves = vanadium_corrections.fetch_correction_workspaces(
vanadium_path, instrument, rb_num=rb_num)
......@@ -51,31 +56,43 @@ class CalibrationModel(object):
path_handling.ENGINEERING_PREFIX, "full_calibration")
if full_calib_path is not None and path.exists(full_calib_path):
full_calib = path_handling.load_workspace(full_calib_path)
output = self.run_calibration(sample_workspace, van_integration, van_curves, full_calib_ws=full_calib)
output = self.run_calibration(sample_workspace, van_integration, van_curves, bank, spectrum_numbers, full_calib_ws=full_calib)
else:
output = self.run_calibration(sample_workspace, van_integration, van_curves)
output = self.run_calibration(sample_workspace, van_integration, van_curves, bank, spectrum_numbers)
if plot_output:
self._plot_vanadium_curves()
for i in range(2):
difc = [output[i].DIFC]
tzero = [output[i].TZERO]
self._generate_difc_tzero_workspace(difc, tzero, i + 1)
self._plot_difc_tzero()
difc = [output[0].DIFC, output[1].DIFC]
tzero = [output[0].TZERO, output[1].TZERO]
for i in range(len(output)):
if spectrum_numbers:
bank_name = "cropped"
elif bank is None:
bank_name = str(i + 1)
else:
bank_name = bank
difc = output[i].DIFC
tzero = output[i].TZERO
self._generate_difc_tzero_workspace(difc, tzero, bank_name)
if bank is None and spectrum_numbers is None:
self._plot_difc_tzero()
elif spectrum_numbers is None:
self._plot_difc_tzero_single_bank_or_custom(bank, False)
else:
self._plot_difc_tzero_single_bank_or_custom("", True)
difc = [i.DIFC for i in output]
tzero = [i.TZERO for i in output]
params_table = []
for i in range(2):
for i in range(len(difc)):
params_table.append([i, difc[i], 0.0, tzero[i]])
self.update_calibration_params_table(params_table)
calib_dir = path.join(path_handling.get_output_path(), "Calibration", "")
self.create_output_files(calib_dir, difc, tzero, sample_path, vanadium_path, instrument)
self.create_output_files(calib_dir, difc, tzero, sample_path, vanadium_path, instrument, bank, spectrum_numbers)
if rb_num:
user_calib_dir = path.join(path_handling.get_output_path(), "User", rb_num,
"Calibration", "")
self.create_output_files(user_calib_dir, difc, tzero, sample_path, vanadium_path,
instrument)
instrument, bank, spectrum_numbers)
def load_existing_gsas_parameters(self, file_path):
if not path.exists(file_path):
......@@ -139,14 +156,14 @@ class CalibrationModel(object):
@staticmethod
def _generate_difc_tzero_workspace(difc, tzero, bank):
bank_ws = Ads.retrieve(CalibrationModel._generate_table_workspace_name(bank - 1))
bank_ws = Ads.retrieve(CalibrationModel._generate_table_workspace_name(bank))
x_val = []
y_val = []
y2_val = []
difc_to_plot = difc[0]
tzero_to_plot = tzero[0]
difc_to_plot = difc
tzero_to_plot = tzero
for irow in range(0, bank_ws.rowCount()):
x_val.append(bank_ws.cell(irow, 0))
......@@ -185,35 +202,82 @@ class CalibrationModel(object):
ax.set_xlabel("Expected Peaks Centre(dSpacing, A)")
fig.show()
def run_calibration(self, sample_ws, van_integration, van_curves, full_calib_ws=None):
@staticmethod
def _plot_difc_tzero_single_bank_or_custom(bank, custom):
if not custom:
bank_ws = Ads.retrieve("engggui_difc_zero_peaks_bank_" + str(bank))
else:
bank_ws = Ads.retrieve("engggui_difc_zero_peaks_bank_cropped")
bank = "Cropped"
ax = plot([bank_ws], [0, 1],
plot_kwargs={
"linestyle": "--",
"marker": "o",
"markersize": "3"
}).gca()
ax.set_title("Engg Gui Difc Zero Peaks Bank " + str(bank))
ax.legend(("Peaks Fitted", "DifC/TZero Fitted Straight Line"))
ax.set_xlabel("Expected Peaks Centre(dSpacing, A)")
def run_calibration(self, sample_ws, van_integration, van_curves, bank, spectrum_numbers, full_calib_ws=None):
"""
Runs the main Engineering calibration algorithm.
:param sample_ws: The workspace with the sample data.
:param van_integration: The integration values from the vanadium corrections
:param van_curves: The curves from the vanadium corrections.
:param full_calib_ws: Full pixel calibration of the detector (optional)
:param bank: The bank to crop to, both if none.
:param spectrum_numbers: The spectrum numbers to crop to, no crop if none.
:return: The output of the algorithm.
"""
output = [None] * 2
for i in range(2):
table_name = self._generate_table_workspace_name(i)
def run_engg_calibrate(calib_bank):
table_name = self._generate_table_workspace_name(calib_bank)
if full_calib_ws is not None:
output[i] = EnggCalibrate(InputWorkspace=sample_ws,
return EnggCalibrate(InputWorkspace=sample_ws,
VanIntegrationWorkspace=van_integration,
VanCurvesWorkspace=van_curves,
Bank=calib_bank,
FittedPeaks=table_name,
DetectorPositions=full_calib_ws)
else:
return EnggCalibrate(InputWorkspace=sample_ws,
VanIntegrationWorkspace=van_integration,
VanCurvesWorkspace=van_curves,
Bank=calib_bank,
FittedPeaks=table_name)
if spectrum_numbers is None:
if bank is None:
output = [None] * 2
for i in range(len(output)):
output[i] = run_engg_calibrate(str(i + 1))
else:
output = [None]
output[0] = run_engg_calibrate(bank)
else:
if full_calib_ws is not None:
output = [None]
cropped_table_name = self._generate_table_workspace_name("cropped")
output[0] = EnggCalibrate(InputWorkspace=sample_ws,
VanIntegrationWorkspace=van_integration,
VanCurvesWorkspace=van_curves,
Bank=str(i + 1),
FittedPeaks=table_name)
SpectrumNumbers=spectrum_numbers,
FittedPeaks=cropped_table_name,
DetectorPositions=full_calib_ws)
else:
output[i] = EnggCalibrate(InputWorkspace=sample_ws,
output = [None]
cropped_table_name = self._generate_table_workspace_name("cropped")
output[0] = EnggCalibrate(InputWorkspace=sample_ws,
VanIntegrationWorkspace=van_integration,
VanCurvesWorkspace=van_curves,
Bank=str(i + 1),
FittedPeaks=table_name,
DetectorPositions=full_calib_ws)
SpectrumNumbers=spectrum_numbers,
FittedPeaks=cropped_table_name)
return output
def create_output_files(self, calibration_dir, difc, tzero, sample_path, vanadium_path,
instrument):
instrument, bank, spectrum_numbers):
"""
Create output files from the algorithms in the specified directory
:param calibration_dir: The directory to save the files into.
......@@ -222,36 +286,49 @@ class CalibrationModel(object):
:param sample_path: The path to the sample data file.
:param vanadium_path: The path to the vanadium data file.
:param instrument: The instrument (ENGINX or IMAT)
:param bank: Optional parameter to crop by bank
:param spectrum_numbers: Optional parameter to crop using spectrum numbers.
"""
def generate_both_banks_file(difc_list, tzero_list):
file_path = calibration_dir + self._generate_output_file_name(
vanadium_path, sample_path, instrument, bank="all")
write_ENGINX_GSAS_iparam_file(file_path,
difc_list,
tzero_list,
ceria_run=sample_path,
vanadium_run=vanadium_path)
def generate_north_or_custom_bank_file(difc_north, tzero_north, bank_name="north"):
file_path = calibration_dir + self._generate_output_file_name(
vanadium_path, sample_path, instrument, bank=bank_name)
write_ENGINX_GSAS_iparam_file(file_path, [difc_north], [tzero_north],
ceria_run=sample_path,
vanadium_run=vanadium_path,
template_file=NORTH_BANK_TEMPLATE_FILE,
bank_names=["North"])
def generate_south_bank_file(difc_south, tzero_south):
file_path = calibration_dir + self._generate_output_file_name(
vanadium_path, sample_path, instrument, bank="south")
write_ENGINX_GSAS_iparam_file(file_path, [difc_south], [tzero_south],
ceria_run=sample_path,
vanadium_run=vanadium_path,
template_file=SOUTH_BANK_TEMPLATE_FILE,
bank_names=["South"])
if not path.exists(calibration_dir):
makedirs(calibration_dir)
filename = self._generate_output_file_name(vanadium_path,
sample_path,
instrument,
bank="all")
# Both Banks
file_path = calibration_dir + filename
write_ENGINX_GSAS_iparam_file(file_path,
difc,
tzero,
ceria_run=sample_path,
vanadium_run=vanadium_path)
# North Bank
file_path = calibration_dir + self._generate_output_file_name(
vanadium_path, sample_path, instrument, bank="north")
write_ENGINX_GSAS_iparam_file(file_path, [difc[0]], [tzero[0]],
ceria_run=sample_path,
vanadium_run=vanadium_path,
template_file=NORTH_BANK_TEMPLATE_FILE,
bank_names=["North"])
# South Bank
file_path = calibration_dir + self._generate_output_file_name(
vanadium_path, sample_path, instrument, bank="south")
write_ENGINX_GSAS_iparam_file(file_path, [difc[1]], [tzero[1]],
ceria_run=sample_path,
vanadium_run=vanadium_path,
template_file=SOUTH_BANK_TEMPLATE_FILE,
bank_names=["South"])
if bank is None and spectrum_numbers is None:
generate_both_banks_file(difc, tzero)
generate_north_or_custom_bank_file(difc[0], tzero[0])
generate_south_bank_file(difc[1], tzero[1])
elif bank == "1":
generate_north_or_custom_bank_file(difc[0], tzero[0])
elif bank == "2":
generate_south_bank_file(difc[0], tzero[0])
elif bank is None:
generate_north_or_custom_bank_file(difc[0], tzero[0], "cropped")
@staticmethod
def get_info_from_file(file_path):
......@@ -284,7 +361,7 @@ class CalibrationModel(object):
@staticmethod
def _generate_table_workspace_name(bank_num):
return "engggui_calibration_bank_" + str(bank_num + 1)
return "engggui_calibration_bank_" + str(bank_num)
@staticmethod
def _generate_output_file_name(vanadium_path, sample_path, instrument, bank):
......@@ -305,6 +382,8 @@ class CalibrationModel(object):
filename = filename + "bank_North.prm"
elif bank == "south":
filename = filename + "bank_South.prm"
elif bank == "cropped":
filename = filename + "cropped.prm"
else:
raise ValueError("Invalid bank name entered")
return filename
......@@ -10,6 +10,7 @@ from __future__ import (absolute_import, division, print_function)
import unittest
from mantid.py3compat.mock import patch
from mantid.py3compat.mock import MagicMock
from Engineering.gui.engineering_diffraction.tabs.calibration.model import CalibrationModel
VANADIUM_NUMBER = "307521"
......@@ -67,7 +68,7 @@ class CalibrationModelTest(unittest.TestCase):
van_corr.return_value = ("mocked_integration", "mocked_curves")
load_workspace.return_value = "mocked_workspace"
self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, False, "ENGINX")
calibrate_alg.assert_called_with("mocked_workspace", "mocked_integration", "mocked_curves",
calibrate_alg.assert_called_with("mocked_workspace", "mocked_integration", "mocked_curves", None, None,
full_calib_ws="mocked_workspace")
@patch(class_path + '.update_calibration_params_table')
......@@ -80,6 +81,7 @@ class CalibrationModelTest(unittest.TestCase):
@patch(class_path + '.run_calibration')
def test_plotting_check(self, calib, plot_difc_zero, gen_difc, plot_van, van, sample,
output_files, update_table):
calib.return_value = [MagicMock(), MagicMock()]
van.return_value = ("A", "B")
self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, False, "ENGINX")
plot_van.assert_not_called()
......@@ -145,7 +147,7 @@ class CalibrationModelTest(unittest.TestCase):
output_name.return_value = filename
self.model.create_output_files("test/", [0, 0], [1, 1], sample_path, vanadium_path,
"ENGINX")
"ENGINX", bank=None, spectrum_numbers=None)
self.assertEqual(make_dirs.call_count, 1)
self.assertEqual(write_file.call_count, 3)
......@@ -157,7 +159,7 @@ class CalibrationModelTest(unittest.TestCase):
def test_generate_table_workspace_name(self):
self.assertEqual(self.model._generate_table_workspace_name(20),
"engggui_calibration_bank_21")
"engggui_calibration_bank_20")
def test_generate_output_file_name_for_north_bank(self):
filename = self.model._generate_output_file_name("test/20.raw", "test/10.raw", "ENGINX",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment