diff --git a/Testing/Data/UnitTest/LARMOR00002260.nxs.md5 b/Testing/Data/UnitTest/LARMOR00002260.nxs.md5 new file mode 100644 index 0000000000000000000000000000000000000000..d49f44b50a2d6cd916a5940a9375dd471541ab71 --- /dev/null +++ b/Testing/Data/UnitTest/LARMOR00002260.nxs.md5 @@ -0,0 +1 @@ +ecedafb102297a4db6b28d98c7269911 diff --git a/Testing/Data/UnitTest/LOQ74044.nxs.md5 b/Testing/Data/UnitTest/LOQ74044.nxs.md5 new file mode 100644 index 0000000000000000000000000000000000000000..8d152470cffe576ab959f68c1363f713a32d8190 --- /dev/null +++ b/Testing/Data/UnitTest/LOQ74044.nxs.md5 @@ -0,0 +1 @@ +79020b3973e727f535dd90295773b589 diff --git a/Testing/Data/UnitTest/SANS2D00022024.nxs.md5 b/Testing/Data/UnitTest/SANS2D00022024.nxs.md5 new file mode 100644 index 0000000000000000000000000000000000000000..750d54bfc729cbd04462c301426a745e9823f5b4 --- /dev/null +++ b/Testing/Data/UnitTest/SANS2D00022024.nxs.md5 @@ -0,0 +1 @@ +d8df0c8d545bb4462e88e501f1677170 diff --git a/Testing/Data/UnitTest/SANS2D00022048.nxs.md5 b/Testing/Data/UnitTest/SANS2D00022048.nxs.md5 new file mode 100644 index 0000000000000000000000000000000000000000..1f3de952d451eef3f09710a6f23ef5e1d4fea3e3 --- /dev/null +++ b/Testing/Data/UnitTest/SANS2D00022048.nxs.md5 @@ -0,0 +1 @@ +bb63c083b4fb7373f1a2a33d1b8946bb diff --git a/scripts/CMakeLists.txt b/scripts/CMakeLists.txt index bd69f64d171f3a5574ea2477b53aa41870d9ff61..08739451aa9db276091c7e5864d2fdb85e7384b8 100644 --- a/scripts/CMakeLists.txt +++ b/scripts/CMakeLists.txt @@ -57,7 +57,9 @@ set ( TEST_PY_FILES test/ReductionSettingsTest.py ) - +# Addition tests for SANS components +add_subdirectory(test/SANS) + # python unit tests if (PYUNITTEST_FOUND) pyunittest_add_test ( ${CMAKE_CURRENT_SOURCE_DIR}/test PythonScriptsTest ${TEST_PY_FILES} ) diff --git a/scripts/SANS/sans/README.md b/scripts/SANS/sans/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8b7a21d65a7ae82c7c26fd2999589358d779a71c --- /dev/null +++ b/scripts/SANS/sans/README.md @@ -0,0 +1,12 @@ +# `sans` + +The `sans` package contains the elements of the second version of the ISIS SANS reduction, except for Python algorithms +which can be found in the `WorkflowAlgorithm` section of Mantid's `PythonInterface`. + +## `common` + +The elements in the common package include widely used general purpose functions, constants and SANS-wide type definitions. + +## `state` + +The elements in the `state` package contain the definition of the reduction configuration and the corresponding builders. 
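+
+## Intended usage
+
+A rough sketch of how these pieces fit together (illustrative only; `data_info` is assumed to be an existing
+data state object describing the run, and the `set_...` methods shown are generated automatically by the builders):
+
+```python
+from sans.common.enums import SANSInstrument
+from sans.state.adjustment import get_adjustment_builder
+
+# String <-> enum conversion, as used at the algorithm input boundary
+instrument = SANSInstrument.from_string("LARMOR")
+
+# Builders expose auto-generated set_XXX methods for the typed parameters of their state object
+builder = get_adjustment_builder(data_info)
+builder.set_wide_angle_correction(False)
+# ... set the remaining sub-states, then validate and retrieve a copy of the state:
+adjustment_state = builder.build()
+```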
\ No newline at end of file diff --git a/scripts/SANS/sans/__init__.py b/scripts/SANS/sans/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/SANS/sans/common/__init__.py b/scripts/SANS/sans/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/SANS/sans/common/configurations.py b/scripts/SANS/sans/common/configurations.py new file mode 100644 index 0000000000000000000000000000000000000000..1586f21f1bb8860dd07b4b61a703fdc3703442a7 --- /dev/null +++ b/scripts/SANS/sans/common/configurations.py @@ -0,0 +1,24 @@ +""" The SANSConfigurations class holds instrument-specific configs to centralize instrument-specific magic numbers""" +# pylint: disable=too-few-public-methods + + +class Configurations(object): + + class LARMOR(object): + # The full wavelength range of the instrument + wavelength_full_range_low = 0.5 + wavelength_full_range_high = 13.5 + + class SANS2D(object): + # The full wavelength range of the instrument + wavelength_full_range_low = 2.0 + wavelength_full_range_high = 14.0 + + class LOQ(object): + # The full wavelength range of the instrument + wavelength_full_range_low = 2.2 + wavelength_full_range_high = 10.0 + + # The default prompt peak range for LOQ + prompt_peak_correction_min = 19000.0 + prompt_peak_correction_max = 20500.0 diff --git a/scripts/SANS/sans/common/constants.py b/scripts/SANS/sans/common/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..015990ba94ff23437b167a88b6a7910bd03af8e2 --- /dev/null +++ b/scripts/SANS/sans/common/constants.py @@ -0,0 +1,42 @@ +""" These constants are used in the SANS reducer framework. 
We want a central place for them.""" + +# pylint: disable=too-few-public-methods + +# ---------------------------------------- +# Proeprty names for Algorithms +# --------------------------------------- +MONITOR_SUFFIX = "_monitors" +INPUT_WORKSPACE = "InputWorkspace" + +FILE_NAME = "Filename" + +OUTPUT_WORKSPACE = "OutputWorkspace" +OUTPUT_WORKSPACE_GROUP = OUTPUT_WORKSPACE + "_" + +OUTPUT_MONITOR_WORKSPACE = "MonitorWorkspace" +OUTPUT_MONITOR_WORKSPACE_GROUP = OUTPUT_MONITOR_WORKSPACE + "_" + +WORKSPACE = "Workspace" +EMPTY_NAME = "dummy" + + +# ---------------------------------------- +# Other +# --------------------------------------- +SANS_SUFFIX = "sans" +TRANS_SUFFIX = "trans" + +high_angle_bank = "HAB" +low_angle_bank = "LAB" + +SANS2D = "SANS2D" +LARMOR = "LARMOR" +LOQ = "LOQ" + +REDUCED_WORKSPACE_NAME_IN_LOGS = "reduced_workspace_name" +SANS_FILE_TAG = "sans_file_tag" +REDUCED_CAN_TAG = "reduced_can_hash" + +ALL_PERIODS = 0 + +CALIBRATION_WORKSPACE_TAG = "sans_applied_calibration_file" diff --git a/scripts/SANS/sans/common/enums.py b/scripts/SANS/sans/common/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..e4374c8634c8ae5134da5bbb02415dd91107e97c --- /dev/null +++ b/scripts/SANS/sans/common/enums.py @@ -0,0 +1,297 @@ +""" The elements of this module define typed enums which are used in the SANS reduction framework.""" + +# pylint: disable=too-few-public-methods, invalid-name + +from inspect import isclass +from functools import partial + + +# ---------------------------------------------------------------------------------------------------------------------- +# Serializable Enum decorator +# ---------------------------------------------------------------------------------------------------------------------- +def serializable_enum(*inner_classes): + """ + Class decorator which changes the name of an inner class to include the name of the outer class. The inner class + gets a method to determine the name of the outer class. This information is needed for serialization at the + algorithm input boundary. + """ + def inner_class_builder(cls): + # Add each inner class to the outer class + for inner_class in inner_classes: + new_class = type(inner_class, (cls, ), {"outer_class_name": cls.__name__}) + # We set the module of the inner class to the module of the outer class. We have to do this since we + # are dynamically adding the inner class which gets its module name from the module where it was added, + # but not where the outer class lives. + module_of_outer_class = getattr(cls, "__module__") + setattr(new_class, "__module__", module_of_outer_class) + # Add the inner class to the outer class + setattr(cls, inner_class, new_class) + return cls + return inner_class_builder + + +# ---------------------------------------------------------------------------------------------------------------------- +# String conversion decorator +# ---------------------------------------------------------------------------------------------------------------------- +def string_convertible(cls): + """ + Class decorator to make the enum/sub-class entries string convertible. + + We do this by creating a static from_string and to_string method on the class. + IMPORTANT: It is important that the enum values are added to the class before applying this decorator. In general + the order has to be: + @string_convertible + @serializable_enum + class MyClass(object): + ... 
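+    The decorated class then exposes two static helpers (the member name below is illustrative):
+        MyClass.to_string(MyClass.SomeValue)   # -> "SomeValue"
+        MyClass.from_string("SomeValue")       # -> MyClass.SomeValue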
+ @param cls: a reference to the class + @return: the class + """ + def to_string(elements, convert_to_string): + for key, value in elements.items(): + if convert_to_string is value: + return key + raise RuntimeError("Could not convert {0} to string. Unknown value.".format(convert_to_string)) + + def from_string(elements, convert_from_string): + for key, value in elements.items(): + if convert_from_string == key: + return value + raise RuntimeError("Could not convert {0} from string. Unknown value.".format(convert_from_string)) + + # First get all enum/sub-class elements + convertible_elements = {} + for attribute_name, attribute_value in cls.__dict__.items(): + if isclass(attribute_value) and issubclass(attribute_value, cls): + convertible_elements.update({attribute_name: attribute_value}) + + # Add the new static methods to the class + partial_to_string = partial(to_string, convertible_elements) + partial_from_string = partial(from_string, convertible_elements) + setattr(cls, "to_string", staticmethod(partial_to_string)) + setattr(cls, "from_string", staticmethod(partial_from_string)) + return cls + + +# -------------------------------- +# Instrument and facility types +# -------------------------------- +@string_convertible +@serializable_enum("LOQ", "LARMOR", "SANS2D", "NoInstrument") +class SANSInstrument(object): + pass + + +@serializable_enum("ISIS", "NoFacility") +class SANSFacility(object): + pass + + +# ------------------------------------ +# Data Types +# ------------------------------------ +@string_convertible +@serializable_enum("SampleScatter", "SampleTransmission", "SampleDirect", "CanScatter", "CanTransmission", "CanDirect", + "Calibration") +class SANSDataType(object): + """ + Defines the different data types which are required for the reduction. Besides the fundamental data of the + sample and the can, we can also specify a calibration. + """ + pass + + +# --------------------------- +# Coordinate Definitions (3D) +# -------------------------- +class Coordinates(object): + pass + + +@serializable_enum("X", "Y", "Z") +class CanonicalCoordinates(Coordinates): + pass + + +# -------------------------- +# ReductionMode +# -------------------------- +@serializable_enum("Merged", "All") +class ReductionMode(object): + """ + Defines the reduction modes which should be common to all implementations, namely All and Merged. + """ + pass + + +@string_convertible +@serializable_enum("HAB", "LAB") +class ISISReductionMode(ReductionMode): + """ + Defines the different reduction modes. This can be the high-angle bank, the low-angle bank + """ + pass + + +# -------------------------- +# Reduction dimensionality +# -------------------------- +@serializable_enum("OneDim", "TwoDim") +class ReductionDimensionality(object): + """ + Defines the dimensionality for reduction. This can be either 1D or 2D + """ + pass + + +# -------------------------- +# Reduction data +# -------------------------- +@serializable_enum("Scatter", "Transmission", "Direct") +class ReductionData(object): + """ + Defines the workspace type of the reduction data. For all known instances this can be scatter, transmission + or direct + """ + pass + + +# -------------------------- +# Type of data +# -------------------------- +@string_convertible +@serializable_enum("Sample", "Can") +class DataType(object): + """ + Defines the type of reduction data. This can either the sample or only the can. 
+ """ + pass + + +# --------------------------------- +# Partial reduction output setting +# --------------------------------- +@serializable_enum("Count", "Norm") +class OutputParts(object): + """ + Defines the partial outputs of a reduction. They are the numerator (Count) and denominator (Norm) of a division. + """ + pass + + +# ----------------------------------------------------- +# The fit type during merge of HAB and LAB reductions +# ----------------------------------------------------- +@string_convertible +@serializable_enum("Both", "NoFit", "ShiftOnly", "ScaleOnly") +class FitModeForMerge(object): + """ + Defines which fit operation to use during the merge of two reductions. + """ + pass + + +# -------------------------- +# Detectors +# -------------------------- +@serializable_enum("Horizontal", "Vertical", "Rotated") +class DetectorOrientation(object): + """ + Defines the detector orientation. + """ + pass + + +# -------------------------- +# Detector Type +# -------------------------- +@string_convertible +@serializable_enum("HAB", "LAB") +class DetectorType(object): + """ + Defines the detector type + """ + pass + + +# -------------------------- +# Ranges +# -------------------------- +@string_convertible +@serializable_enum("Lin", "Log") +class RangeStepType(object): + """ + Defines the step type of a range + """ + pass + + +# -------------------------- +# Rebin +# -------------------------- +@string_convertible +@serializable_enum("Rebin", "InterpolatingRebin") +class RebinType(object): + """ + Defines the rebin types available + """ + pass + + +# -------------------------- +# SaveType +# -------------------------- +@string_convertible +@serializable_enum("Nexus", "NistQxy", "CanSAS", "RKH", "CSV", "NXcanSAS") +class SaveType(object): + """ + Defines the save types available + """ + pass + + +# ------------------------------------------ +# Fit type for the transmission calculation +# ------------------------------------------ +@string_convertible +@serializable_enum("Linear", "Log", "Polynomial", "NoFit") +class FitType(object): + """ + Defines possible fit types + """ + pass + + +# -------------------------- +# SampleShape +# -------------------------- +@serializable_enum("CylinderAxisUp", "Cuboid", "CylinderAxisAlong") +class SampleShape(object): + """ + Defines the sample shape types + """ + pass + + +def convert_int_to_shape(shape_int): + """ + Note that we convert the sample shape to an integer here. This is required for the workspace, hence we don't + use the string_convertible decorator. 
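+    The integer convention used below is: 1 -> CylinderAxisUp, 2 -> Cuboid, 3 -> CylinderAxisAlong.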
+ """ + if shape_int == 1: + as_type = SampleShape.CylinderAxisUp + elif shape_int == 2: + as_type = SampleShape.Cuboid + elif shape_int == 3: + as_type = SampleShape.CylinderAxisAlong + else: + raise ValueError("SampleShape: Cannot convert unknown sample shape integer: {0}".format(shape_int)) + return as_type + + +# --------------------------- +# FileTypes +# --------------------------- +@serializable_enum("ISISNexus", "ISISNexusAdded", "ISISRaw", "NoFileType") +class FileType(object): + pass diff --git a/scripts/SANS/sans/common/file_information.py b/scripts/SANS/sans/common/file_information.py new file mode 100644 index 0000000000000000000000000000000000000000..68c33bdca0ea704b9c6e632398cc61fbf856c119 --- /dev/null +++ b/scripts/SANS/sans/common/file_information.py @@ -0,0 +1,523 @@ +""" The elements of this module coordinate file access and information extraction from files.""" + +# pylint: disable=too-few-public-methods, invalid-name + +import os +import h5py as h5 +from abc import (ABCMeta, abstractmethod) + +from mantid.api import FileFinder +from mantid.kernel import (DateAndTime, ConfigService) +from mantid.api import (AlgorithmManager, ExperimentInfo) +from sans.common.enums import (SANSInstrument, FileType) + + +# ----------------------------------- +# Free Functions +# ----------------------------------- +def find_full_file_path(file_name): + """ + Gets the full path of a file name if it is available on the Mantid paths. + + :param file_name: the name of the file. + :return: the full file path. + """ + return FileFinder.getFullPath(file_name) + + +def find_sans_file(file_name): + """ + Finds a SANS file. + The file can be specified as: + 1. file.ext or path1 path2 file.ext + 2. run number + :param file_name: a file name or a run number. + :return: the full path. + """ + full_path = find_full_file_path(file_name) + if not full_path: + runs = FileFinder.findRuns(file_name) + if runs: + full_path = runs[0] + if not full_path: + raise RuntimeError("Trying to find the SANS file {0}, but cannot find it. Make sure that " + "the relevant paths are added.".format(file_name)) + return full_path + + +def get_extension_for_file_type(file_info): + """ + Get the extension for a specific file type. + + :param file_info: a SANSFileInformation object. + :return: the extension a stirng. This can be either nxs or raw. + """ + if file_info.get_type() is FileType.ISISNexus or file_info.get_type() is FileType.ISISNexusAdded: + extension = "nxs" + elif file_info.get_type() is FileType.ISISRaw: + extension = "raw" + else: + raise RuntimeError("The file extension type for a file of type {0} is unknown" + "".format(str(file_info.get_type()))) + return extension + + +def get_number_of_periods(func, file_name): + """ + Get the number of periods of the data in a file. + + :param func: a function handle which extracts the relevant information. + :param file_name: the file name to the relevant file. + :return: the number of periods if it is applicable else 0. + """ + is_file_type, number_of_periods = func(file_name) + return number_of_periods if is_file_type else 0 + + +def is_single_period(func, file_name): + """ + Checks if a file contains only single period data. + + :param func: a function handle which extracts the number of periods. + :param file_name: the name of the file. + :return: true if the number of periods is 1 else false. 
+ """ + is_file_type, number_of_periods = func(file_name) + return is_file_type and number_of_periods == 1 + + +def is_multi_period(func, file_name): + """ + Checks if a file contains multi-period data. + + :param func: a function handle which extracts the number of periods. + :param file_name: the name of the file. + :return: true if the number of periods is larger than one else false. + """ + is_file_type, number_of_periods = func(file_name) + return is_file_type and number_of_periods >= 1 + + +def get_instrument_paths_for_sans_file(file_name): + """ + Gets the Instrument Definition File (IDF) path and the Instrument Parameter Path (IPF) path associated with a file. + + :param file_name: the file name is a name fo a SANS data file, e.g. SANS2D0001234 + :return: the IDF path and the IPF path + """ + def get_file_location(path): + return os.path.dirname(path) + + def get_ipf_equivalent_name(path): + # If XXX_Definition_Yyy.xml is the IDF name, then the equivalent IPF name is: XXX_Parameters_Yyy.xml + base_file_name = os.path.basename(path) + return base_file_name.replace("Definition", "Parameters") + + def get_ipf_standard_name(path): + # If XXX_Definition_Yyy.xml is the IDF name, then the standard IPF name is: XXX_Parameters.xml + base_file_name = os.path.basename(path) + elements = base_file_name.split("_") + return elements[0] + "_Parameters.xml" + + def check_for_files(directory, path): + # Check if XXX_Parameters_Yyy.xml exists in the same folder + ipf_equivalent_name = get_ipf_equivalent_name(path) + ipf_equivalent = os.path.join(directory, ipf_equivalent_name) + if os.path.exists(ipf_equivalent): + return ipf_equivalent + + # Check if XXX_Parameters.xml exists in the same folder + ipf_standard_name = get_ipf_standard_name(path) + ipf_standard = os.path.join(directory, ipf_standard_name) + if os.path.exists(ipf_standard): + return ipf_standard + # Does not seem to be in the folder + return None + + def get_ipf_for_rule_1(path): + # Check if can be found in the same folder + directory = get_file_location(path) + return check_for_files(directory, path) + + def get_ipf_for_rule_2(path): + # Check if can be found in the instrument folder + directory = ConfigService.getInstrumentDirectory() + return check_for_files(directory, path) + + # Get the measurement date + file_information_factory = SANSFileInformationFactory() + file_information = file_information_factory.create_sans_file_information(file_name) + measurement_time = file_information.get_date() + # For some odd reason the __str__ method of DateAndTime adds a space which we need to strip here. It seems + # to be on purpose though since the export method is called IS08601StringPlusSpace --> hence we need to strip it + # ourselves + measurement_time_as_string = str(measurement_time).strip() + + # Get the instrument + instrument = file_information.get_instrument() + instrument_as_string = SANSInstrument.to_string(instrument) + + # Get the idf file path + idf_path = ExperimentInfo.getInstrumentFilename(instrument_as_string, measurement_time_as_string) + idf_path = os.path.normpath(idf_path) + + if not os.path.exists(idf_path): + raise RuntimeError("SANSFileInformation: The instrument definition file {0} does not seem to " + "exist.".format(str(idf_path))) + + # Get the ipf path. This is slightly more complicated. See the Mantid documentation for the naming rules. Currently + # they are: + # 1. 
If the IDF is not in the instrument folder and there is another X_Parameters.xml in the same folder, + # this one in the same folder will be used instead of any parameter file in the instrument folder. + # 2. If you want one parameter file for your IDF file, name your IDF file X_Definition_Yyy.xml and the parameter + # file X_Parameters_Yyy.xml , where Yyy is any combination a characters you find appropriate. If your IDF + # file is not in the instrument folder, the parameter file can be in either the same folder or in the instrument + # folder, but it can only be in the instrument folder, if the same folder has no X_Parameters.xml or + # X_Parameters_Yyy.xml file. If there is no X_Parameters_Yyy.xml file, X_Parameters.xml would be used. + ipf_rule1 = get_ipf_for_rule_1(idf_path) + if ipf_rule1: + return idf_path, ipf_rule1 + + ipf_rule2 = get_ipf_for_rule_2(idf_path) + if ipf_rule2: + return idf_path, ipf_rule2 + + raise RuntimeError("SANSFileInformation: There does not seem to be a corresponding instrument parameter file " + "available for {0}".format(str(idf_path))) + + +# ---------------------------------------------- +# Methods for ISIS Nexus +# --------------------------------------------- +def get_isis_nexus_info(file_name): + """ + Get information if is ISIS Nexus and the number of periods. + + :param file_name: the full file path. + :return: if the file was a Nexus file and the number of periods. + """ + try: + with h5.File(file_name) as h5_file: + keys = h5_file.keys() + is_isis_nexus = u"raw_data_1" in keys + first_entry = h5_file["raw_data_1"] + period_group = first_entry["periods"] + proton_charge_data_set = period_group["proton_charge"] + number_of_periods = len(proton_charge_data_set) + except IOError: + is_isis_nexus = False + number_of_periods = -1 + return is_isis_nexus, number_of_periods + + +def is_isis_nexus_single_period(file_name): + return is_single_period(get_isis_nexus_info, file_name) + + +def is_isis_nexus_multi_period(file_name): + return is_multi_period(get_isis_nexus_info, file_name) + + +def get_number_of_periods_for_isis_nexus(file_name): + return get_number_of_periods(get_isis_nexus_info, file_name) + + +def get_instrument_name_for_isis_nexus(file_name): + """ + Instrument inforamtion is + file| + |--mantid_workspace_1/raw_data_1| + |--instrument| + |--name + """ + with h5.File(file_name) as h5_file: + # Open first entry + keys = h5_file.keys() + first_entry = h5_file[keys[0]] + # Open instrument group + instrument_group = first_entry["instrument"] + # Open name data set + name_data_set = instrument_group["name"] + # Read value + instrument_name = name_data_set[0] + return instrument_name + + +def get_top_level_nexus_entry(file_name, entry_name): + """ + Gets the first entry in a Nexus file. 
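+    For example (illustrative): get_top_level_nexus_entry(file_name, "start_time") reads the first value of the
+    "start_time" entry in the first top-level group (e.g. raw_data_1 for an ISIS Nexus file).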
+ + :param file_name: The file name + :param entry_name: the entry name + :return: + """ + with h5.File(file_name) as h5_file: + # Open first entry + keys = h5_file.keys() + top_level = h5_file[keys[0]] + entry = top_level[entry_name] + value = entry[0] + return value + + +def get_date_for_isis_nexus(file_name): + value = get_top_level_nexus_entry(file_name, "start_time") + return DateAndTime(value) + + +def get_run_number_for_isis_nexus(file_name): + return int(get_top_level_nexus_entry(file_name, "run_number")) + + +def get_event_mode_information(file_name): + """ + Event mode files have a class with a "NXevent_data" type + Structure: + |--mantid_workspace_1/raw_data_1| + |--some_group| + |--Attribute: NX_class = NXevent_data + """ + with h5.File(file_name) as h5_file: + # Open first entry + keys = h5_file.keys() + first_entry = h5_file[keys[0]] + # Open instrument group + is_event_mode = False + for value in first_entry.values(): + if "NX_class" in value.attrs and "NXevent_data" == value.attrs["NX_class"]: + is_event_mode = True + break + return is_event_mode + + +# --------- +# ISIS Raw +# --------- +def get_raw_info(file_name): + try: + alg_info = AlgorithmManager.createUnmanaged("RawFileInfo") + alg_info.initialize() + alg_info.setChild(True) + alg_info.setProperty("Filename", file_name) + alg_info.setProperty("GetRunParameters", True) + alg_info.execute() + + periods = alg_info.getProperty("PeriodCount").value + is_raw = True + number_of_periods = periods + except IOError: + is_raw = False + number_of_periods = -1 + + return is_raw, number_of_periods + + +def is_raw_single_period(file_name): + return is_single_period(get_raw_info, file_name) + + +def is_raw_multi_period(file_name): + return is_multi_period(get_raw_info, file_name) + + +def get_from_raw_header(file_name, index): + alg_info = AlgorithmManager.createUnmanaged("RawFileInfo") + alg_info.initialize() + alg_info.setChild(True) + alg_info.setProperty("Filename", file_name) + alg_info.setProperty("GetRunParameters", True) + alg_info.execute() + + header = alg_info.getProperty("RunHeader").value + element = header.split()[index] + return element + + +def instrument_name_correction(instrument_name): + return "SANS2D" if instrument_name == "SAN" else instrument_name + + +def get_instrument_name_for_raw(file_name): + instrument_name = get_from_raw_header(file_name, 0) + return instrument_name_correction(instrument_name) + + +def get_run_number_for_raw(file_name): + return int(get_from_raw_header(file_name, 1)) + + +def get_number_of_periods_for_raw(file_name): + return get_number_of_periods(get_raw_info, file_name) + + +def get_date_for_raw(file_name): + def get_month(month_string): + month_conversion = {"JAN": "01", "FEB": "02", "MAR": "03", "APR": "04", + "MAY": "05", "JUN": "06", "JUL": "07", "AUG": "08", + "SEP": "09", "OCT": "10", "NOV": "11", "DEC": "12"} + month_upper = month_string.upper() + if month_upper in month_conversion: + return month_conversion[month_upper] + else: + raise RuntimeError("Cannot get measurement time. 
Invalid month in Raw file: " + month_upper) + + def get_raw_measurement_time(date_input, time_input): + year = date_input[7:(7 + 4)] + day = date_input[0:2] + month_string = date_input[3:6] + month = get_month(month_string) + + date_and_time_string = year + "-" + month + "-" + day + "T" + time_input + return DateAndTime(date_and_time_string) + + alg_info = AlgorithmManager.createUnmanaged("RawFileInfo") + alg_info.initialize() + alg_info.setChild(True) + alg_info.setProperty("Filename", file_name) + alg_info.setProperty("GetRunParameters", True) + alg_info.execute() + + run_parameters = alg_info.getProperty("RunParameterTable").value + + keys = run_parameters.getColumnNames() + + time_id = "r_endtime" + date_id = "r_enddate" + + time = run_parameters.column(keys.index(time_id)) + date = run_parameters.column(keys.index(date_id)) + time = time[0] + date = date[0] + return get_raw_measurement_time(date, time) + + +# ----------------------------------------------- +# SANS file Information +# ----------------------------------------------- +class SANSFileInformation(object): + __metaclass__ = ABCMeta + + def __init__(self, file_name): + self._file_name = file_name + + @abstractmethod + def get_file_name(self): + pass + + @abstractmethod + def get_instrument(self): + pass + + @abstractmethod + def get_date(self): + pass + + @abstractmethod + def get_number_of_periods(self): + pass + + @abstractmethod + def get_type(self): + pass + + @abstractmethod + def get_run_number(self): + pass + + @staticmethod + def get_full_file_name(file_name): + return find_sans_file(file_name) + + +class SANSFileInformationISISNexus(SANSFileInformation): + def __init__(self, file_name): + super(SANSFileInformationISISNexus, self).__init__(file_name) + # Setup instrument name + self._full_file_name = SANSFileInformation.get_full_file_name(self._file_name) + instrument_name = get_instrument_name_for_isis_nexus(self._full_file_name) + self._instrument = SANSInstrument.from_string(instrument_name) + + # Setup date + self._date = get_date_for_isis_nexus(self._full_file_name) + + # Setup number of periods + self._number_of_periods = get_number_of_periods_for_isis_nexus(self._full_file_name) + + # Setup run number + self._run_number = get_run_number_for_isis_nexus(self._full_file_name) + + # Setup event mode check + self._is_event_mode = get_event_mode_information(self._full_file_name) + + def get_file_name(self): + return self._full_file_name + + def get_instrument(self): + return self._instrument + + def get_date(self): + return self._date + + def get_number_of_periods(self): + return self._number_of_periods + + def get_run_number(self): + return self._run_number + + def get_type(self): + return FileType.ISISNexus + + def is_event_mode(self): + return self._is_event_mode + + +class SANSFileInformationRaw(SANSFileInformation): + def __init__(self, file_name): + super(SANSFileInformationRaw, self).__init__(file_name) + # Setup instrument name + self._full_file_name = SANSFileInformation.get_full_file_name(self._file_name) + instrument_name = get_instrument_name_for_raw(self._full_file_name) + self._instrument = SANSInstrument.from_string(instrument_name) + + # Setup date + self._date = get_date_for_raw(self._full_file_name) + + # Setup number of periods + self._number_of_periods = get_number_of_periods_for_raw(self._full_file_name) + + # Setup run number + self._run_number = get_run_number_for_raw(self._full_file_name) + + def get_file_name(self): + return self._full_file_name + + def get_instrument(self): + return 
self._instrument + + def get_date(self): + return self._date + + def get_number_of_periods(self): + return self._number_of_periods + + def get_run_number(self): + return self._run_number + + def get_type(self): + return FileType.ISISRaw + + +class SANSFileInformationFactory(object): + def __init__(self): + super(SANSFileInformationFactory, self).__init__() + + def create_sans_file_information(self, file_name): + full_file_name = find_sans_file(file_name) + if is_isis_nexus_single_period(full_file_name) or is_isis_nexus_multi_period(full_file_name): + file_information = SANSFileInformationISISNexus(full_file_name) + elif is_raw_single_period(full_file_name) or is_raw_multi_period(full_file_name): + file_information = SANSFileInformationRaw(full_file_name) + # TODO: ADD added nexus files here + else: + raise NotImplementedError("The file type you have provided is not implemented yet.") + return file_information diff --git a/scripts/SANS/sans/common/general_functions.py b/scripts/SANS/sans/common/general_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..e6776e6a018fb1d1572c4aff8a03418e1773295c --- /dev/null +++ b/scripts/SANS/sans/common/general_functions.py @@ -0,0 +1,181 @@ +""" The elements of this module contain various general-purpose functions for the SANS reduction framework.""" + +# pylint: disable=invalid-name + +from math import (acos, sqrt, degrees) +from mantid.api import AlgorithmManager, AnalysisDataService +from mantid.kernel import (DateAndTime) +from sans.common.constants import SANS_FILE_TAG +from sans.common.log_tagger import (get_tag, has_tag, set_tag) + + +# ------------------------------------------- +# Free functions +# ------------------------------------------- +def get_log_value(run, log_name, log_type): + """ + Find a log value. + + There are two options here. Either the log is a scalar or a vector. In the case of a scalar there is not much + left to do. In the case of a vector we select the first element whose time_stamp is after the start time of the run + @param run: a Run object. + @param log_name: the name of the log entry + @param log_type: the expected type fo the log entry + @return: the log entry + """ + try: + # Scalar case + output = log_type(run.getLogData(log_name).value) + except TypeError: + # We must be dealing with a vectorized case, ie a time series + log_property = run.getLogData(log_name) + number_of_entries = len(log_property.value) + + # If we have only one item, then there is nothing left to do + output = None + if number_of_entries == 1: + output = log_type(run.getLogData(log_name).value[0]) + else: + # Get the first entry which is past the start time log + start_time = DateAndTime(run.getLogData('run_start').value) + times = log_property.times + values = log_property.value + + has_found_value = False + for index in range(0, number_of_entries): + if times[index] > start_time: + # Not fully clear why we take index - 1, but this follows the old implementation rules + value_index = index if index == 0 else index - 1 + has_found_value = True + output = log_type(values[value_index]) + break + + # Not fully clear why we take index - 1, but this follows the old implementation rules + if not has_found_value: + output = float(values[number_of_entries - 1]) + return output + + +def get_single_valued_logs_from_workspace(workspace, log_names, log_types, convert_from_millimeter_to_meter=False): + """ + Gets non-array valued entries from the sample logs. + + :param workspace: the workspace with the sample log. 
+    :param log_names: the log names which are to be extracted.
+    :param log_types: the types of the log entries, i.e. strings or numeric types
+    :param convert_from_millimeter_to_meter: if True, the numeric log values are converted from mm to m.
+    :return: the log results
+    """
+    assert len(log_names) == len(log_types)
+    # Find the desired log names.
+    run = workspace.getRun()
+    log_results = {}
+    for log_name, log_type in zip(log_names, log_types):
+        log_value = get_log_value(run, log_name, log_type)
+        log_results.update({log_name: log_value})
+    if convert_from_millimeter_to_meter:
+        for key in log_results.keys():
+            log_results[key] /= 1000.
+    return log_results
+
+
+def create_unmanaged_algorithm(name, **kwargs):
+    """
+    Creates an unmanaged child algorithm and initializes it.
+
+    :param name: the name of the algorithm
+    :param kwargs: settings for the algorithm
+    :return: an initialized algorithm instance.
+    """
+    alg = AlgorithmManager.createUnmanaged(name)
+    alg.initialize()
+    alg.setChild(True)
+    for key, value in kwargs.items():
+        alg.setProperty(key, value)
+    return alg
+
+
+def quaternion_to_angle_and_axis(quaternion):
+    """
+    Converts a quaternion to an angle + an axis
+
+    The conversion from a quaternion to an angle + axis is explained here:
+    http://www.euclideanspace.com/maths/geometry/rotations/conversions/quaternionToAngle/
+    """
+    angle = 2*acos(quaternion[0])
+    s_parameter = sqrt(1 - quaternion[0]*quaternion[0])
+
+    axis = []
+    # If the angle is zero, then it does not make sense to have an axis
+    if s_parameter < 1e-8:
+        axis.append(quaternion[1])
+        axis.append(quaternion[2])
+        axis.append(quaternion[3])
+    else:
+        axis.append(quaternion[1]/s_parameter)
+        axis.append(quaternion[2]/s_parameter)
+        axis.append(quaternion[3]/s_parameter)
+    return degrees(angle), axis
+
+
+def get_charge_and_time(workspace):
+    """
+    Gets the total charge and time from a workspace
+
+    :param workspace: the workspace from which we extract the charge and time.
+    :return: the charge, the time
+    """
+    run = workspace.getRun()
+    charges = run.getLogData('proton_charge')
+    total_charge = sum(charges.value)
+    time_passed = (charges.times[-1] - charges.times[0]).total_microseconds()
+    time_passed /= 1e6
+    return total_charge, time_passed
+
+
+def add_to_sample_log(workspace, log_name, log_value, log_type):
+    """
+    Adds a sample log to the workspace
+
+    :param workspace: the workspace to which the sample log is added
+    :param log_name: the name of the log
+    :param log_value: the value of the log in string format
+    :param log_type: the log value type which can be String, Number, Number Series
+    """
+    if log_type not in ["String", "Number", "Number Series"]:
+        raise ValueError("Trying to add {0} to the sample logs, but it was passed "
+                         "as an unknown type {1}".format(log_value, log_type))
+    if not isinstance(log_value, str):
+        raise TypeError("The value which is added to the sample logs needs to be passed as a string,"
+                        " but it is passed as {0}".format(type(log_value)))
+
+    add_log_name = "AddSampleLog"
+    add_log_options = {"Workspace": workspace,
+                       "LogName": log_name,
+                       "LogText": log_value,
+                       "LogType": log_type}
+    add_log_alg = create_unmanaged_algorithm(add_log_name, **add_log_options)
+    add_log_alg.execute()
+
+
+def append_to_sans_file_tag(workspace, to_append):
+    """
+    Appends a string to the existing sans file tag.
+
+    :param workspace: the workspace which contains the sample logs with the sans file tag.
+    :param to_append: the additional tag
+    """
+    if has_tag(SANS_FILE_TAG, workspace):
+        value = get_tag(SANS_FILE_TAG, workspace)
+        value += to_append
+        set_tag(SANS_FILE_TAG, value, workspace)
+
+
+def get_ads_workspace_references():
+    """
+    Gets a list of handles of available workspaces on the ADS
+
+    @return: the workspaces on the ADS.
+    """
+    for workspace_name in AnalysisDataService.getObjectNames():
+        yield AnalysisDataService.retrieve(workspace_name)
diff --git a/scripts/SANS/sans/common/log_tagger.py b/scripts/SANS/sans/common/log_tagger.py
new file mode 100644
index 0000000000000000000000000000000000000000..fccd2a4682dd69d7a35fd6e81195a39a220e85cf
--- /dev/null
+++ b/scripts/SANS/sans/common/log_tagger.py
@@ -0,0 +1,101 @@
+""" The elements of this module manage and add specific entries in the sample log."""
+
+# pylint: disable=invalid-name
+
+from hashlib import sha224
+from mantid.api import MatrixWorkspace
+
+
+def get_hash_value(value):
+    """
+    Converts a value into a hash
+
+    :param value: a hashable value
+    :return: a hashed value.
+    """
+    hash_value = sha224(str(value).encode("utf8")).hexdigest()
+    if not hash_value or hash_value is None:
+        raise RuntimeError("SANSLogTagger: Something went wrong when trying to get the hash"
+                           " for {0}.".format(str(value)))
+    return hash_value
+
+
+def check_if_valid_tag_and_workspace(tag, workspace):
+    """
+    Checks if a tag and a workspace are valid for tagging
+    :param tag: the tag
+    :param workspace: the workspace
+    """
+    if not(isinstance(tag, basestring) and isinstance(workspace, MatrixWorkspace)):
+        raise RuntimeError("SANSLogTagger: Either tag {0} or workspace are invalid. The tag has to be a string"
+                           " and the workspace {1} has to be a MatrixWorkspace".format(str(tag), str(workspace)))
+
+
+def set_tag(tag, value, workspace):
+    """
+    Adds/ sets a tag on a workspace
+
+    :param tag: the tag name
+    :param value: the tag value
+    :param workspace: the workspace to which the tag will be added.
+    """
+    check_if_valid_tag_and_workspace(tag, workspace)
+    run = workspace.getRun()
+    run.addProperty(tag, value, True)
+
+
+def get_tag(tag, workspace):
+    """
+    Extracts a tag from the workspace.
+
+    :param tag: the tag name
+    :param workspace: the workspace with the tag.
+    :return: the tag value if it exists else None
+    """
+    check_if_valid_tag_and_workspace(tag, workspace)
+    run = workspace.getRun()
+    return run[tag].value if tag in run else None
+
+
+def has_tag(tag, workspace):
+    """
+    Checks if workspace has a certain tag.
+
+    :param tag: the tag name.
+    :param workspace: the workspace to check.
+    :return: true if the tag exists else false.
+    """
+    check_if_valid_tag_and_workspace(tag, workspace)
+    run = workspace.getRun()
+    return tag in run
+
+
+def set_hash(tag, value, workspace):
+    """
+    Sets a value as a hashed tag on a workspace.
+
+    :param tag: the tag name
+    :param value: the tag value (which is converted to a hash)
+    :param workspace: the workspace
+    """
+    check_if_valid_tag_and_workspace(tag, workspace)
+    hash_value = get_hash_value(str(value))
+    set_tag(tag, hash_value, workspace)
+
+
+def has_hash(tag, value, workspace):
+    """
+    Checks if a certain hash exists on a workspace.
+
+    :param tag: the tag as a hash.
+    :param value: the value which is converted to a hash and checked.
+    :param workspace: the workspace.
+    :return: true if the hash exists on the workspace else false.
+    """
+    check_if_valid_tag_and_workspace(tag, workspace)
+    same_hash = False
+    if has_tag(tag, workspace):
+        saved_hash = get_tag(tag, workspace)
+        to_check_hash = get_hash_value(str(value))
+        same_hash = True if saved_hash == to_check_hash else False
+    return same_hash
diff --git a/scripts/SANS/sans/common/xml_parsing.py b/scripts/SANS/sans/common/xml_parsing.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf4aad14425f73411f5054037f7bd711152c83a8
--- /dev/null
+++ b/scripts/SANS/sans/common/xml_parsing.py
@@ -0,0 +1,69 @@
+""" The elements of this module are used to extract information from IDF and IPF files."""
+
+# pylint: disable=invalid-name
+
+try:
+    import xml.etree.cElementTree as eTree
+except ImportError:
+    import xml.etree.ElementTree as eTree
+
+
+def get_named_elements_from_ipf_file(ipf_file, names_to_search, value_type):
+    """
+    Gets named elements from the IPF
+
+    This is useful for detector names etc.
+    :param ipf_file: the path to the IPF
+    :param names_to_search: the names we want to search for on the XML file.
+    :param value_type: the type we expect for the names.
+    :return: an ElementName vs Value map
+    """
+    output = {}
+    number_of_elements_to_search = len(names_to_search)
+    for _, element in eTree.iterparse(ipf_file):
+        if element.tag == "parameter" and "name" in element.keys():
+            if element.get("name") in names_to_search:
+                sub_element = element.find("value")
+                value = sub_element.get("val")
+                output.update({element.get("name"): value_type(value)})
+                element.clear()
+                if number_of_elements_to_search == len(output):
+                    break
+    return output
+
+
+def get_monitor_names_from_idf_file(idf_file):
+    """
+    Gets the monitor names from the IDF
+
+    :param idf_file: the path to the IDF
+    :return: a NumberAsString vs Monitor Name map
+    """
+    def get_tag(tag_in):
+        return "{http://www.mantidproject.org/IDF/1.0}" + tag_in
+    output = {}
+    tag = "idlist"
+    idname = "idname"
+    id_tag = "id"
+    for _, element in eTree.iterparse(idf_file):
+        if element.tag == get_tag(tag) and idname in element.keys():
+            name = element.get(idname)
+            if "monitor" in name:
+                sub_element = element.find(get_tag(id_tag))
+                # We can have two situations here:
+                # 1. either monitors are separate, e.g. <idlist idname="monitor1"> <id val="1" /> </idlist>, ..
+                # 2. or in a range, e.g. 
<idlist idname="monitors"> <id start="1" end="8" /> </idlist> + val = sub_element.get("val") + start = sub_element.get("start") + end = sub_element.get("end") + if val: + output.update({val: name}) + element.clear() + elif start and end: + for index in range(int(start), int(end) + 1): + monitor_id = "monitor" + str(index) + output.update({str(index): monitor_id}) + element.clear() + else: + continue + return output diff --git a/scripts/SANS/sans/state/__init__.py b/scripts/SANS/sans/state/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/SANS/sans/state/adjustment.py b/scripts/SANS/sans/state/adjustment.py new file mode 100644 index 0000000000000000000000000000000000000000..3efcc340c75640f6a200f6eab210e949e31daa87 --- /dev/null +++ b/scripts/SANS/sans/state/adjustment.py @@ -0,0 +1,89 @@ +# pylint: disable=too-few-public-methods + +"""State describing the adjustment workspace creation of the SANS reduction.""" + +import json +import copy +from sans.state.state_base import (StateBase, TypedParameter, rename_descriptor_names, BoolParameter, + validator_sub_state) +from sans.state.calculate_transmission import StateCalculateTransmission +from sans.state.normalize_to_monitor import StateNormalizeToMonitor +from sans.state.wavelength_and_pixel_adjustment import StateWavelengthAndPixelAdjustment +from sans.state.automatic_setters import (automatic_setters) +from sans.common.enums import SANSInstrument + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateAdjustment(StateBase): + calculate_transmission = TypedParameter(StateCalculateTransmission, validator_sub_state) + normalize_to_monitor = TypedParameter(StateNormalizeToMonitor, validator_sub_state) + wavelength_and_pixel_adjustment = TypedParameter(StateWavelengthAndPixelAdjustment, validator_sub_state) + wide_angle_correction = BoolParameter() + + def __init__(self): + super(StateAdjustment, self).__init__() + self.wide_angle_correction = False + + def validate(self): + is_invalid = {} + + # Calculate transmission + if not self.calculate_transmission: + is_invalid.update({"StateAdjustment": "The StateCalculateTransmission object is missing."}) + if self.calculate_transmission: + try: + self.calculate_transmission.validate() + except ValueError as e: + is_invalid.update({"StateAdjustment": "The sub-CalculateTransmission state is invalid," + " see here {0}".format(str(e))}) + + # Normalize to monitor + if not self.normalize_to_monitor: + is_invalid.update({"StateAdjustment": "The StateNormalizeToMonitor object is missing."}) + if self.normalize_to_monitor: + try: + self.normalize_to_monitor.validate() + except ValueError as e: + is_invalid.update({"StateAdjustment": "The sub-NormalizeToMonitor state is invalid," + " see here {0}".format(str(e))}) + + # Wavelength and pixel adjustment + if not self.wavelength_and_pixel_adjustment: + is_invalid.update({"StateAdjustment": "The StateWavelengthAndPixelAdjustment object is missing."}) + if self.wavelength_and_pixel_adjustment: + try: + self.wavelength_and_pixel_adjustment.validate() + except ValueError as e: + is_invalid.update({"StateAdjustment": "The sub-WavelengthAndPixelAdjustment state is invalid," + " see here {0}".format(str(e))}) + if is_invalid: + raise 
ValueError("StateAdjustment: The provided inputs are illegal. " + "Please see: {0}".format(json.dumps(is_invalid))) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateAdjustmentBuilder(object): + @automatic_setters(StateAdjustment) + def __init__(self): + super(StateAdjustmentBuilder, self).__init__() + self.state = StateAdjustment() + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +def get_adjustment_builder(data_info): + # The data state has most of the information that we require to define the move. For the factory method, only + # the instrument is of relevance. + instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D: + return StateAdjustmentBuilder() + else: + raise NotImplementedError("StateAdjustmentBuilder: Could not find any valid adjustment builder for the " + "specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/automatic_setters.py b/scripts/SANS/sans/state/automatic_setters.py new file mode 100644 index 0000000000000000000000000000000000000000..427678c919cb23d91ae4e70f5ad8f8264abc794e --- /dev/null +++ b/scripts/SANS/sans/state/automatic_setters.py @@ -0,0 +1,146 @@ +from functools import (partial, wraps) +import inspect + +from sans.state.state_base import (TypedParameter, DictParameter) +# ------------------------------------------------------------------------------------------------------------- +# Automatic Setter functionality +# This creates setters on a builder/director instance for parameters of a state object. +# The setter name generation is fairly simple. There are two scenarios +# 1. Standard parameter: state -> parameter results in setter name: "set_" + name of parameter +# 2. Parameter which is buried in dictionary: state->dict->parameter +# This results in "set_" + dictionary key + name of parameter, e.g. set_HAB_x_translation_correction, where +# HAB is a key of a dictionary and x_translation_correction is the parameter name +# +# The resulting decorator automatic_setters takes the type of class (essentially which state) to operate on and +# an exclusion list. Elements in the exclusion list will not have an automatic setter generated for them. This is +# desirable for parameters which are set internally during the initialization phase of the builder. +# ------------------------------------------------------------------------------------------------------------- + + +def forwarding_setter(value, builder_instance, attribute_name_list): + # The first element of the attribute list is the state object itself + instance = getattr(builder_instance, attribute_name_list[0]) + + # We need to exclude the first element, since we have already used. We also need to exclude the last + # element since we don't want to get it but rather set it. We need to treat the instance differently if it is + # a dictionary. 
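+    # Illustrative walk (attribute names are made up): for attribute_name_list = ["state", "detectors", "HAB",
+    # "x_translation_correction"] we start at builder_instance.state, then step into state.detectors["HAB"], and
+    # finally set x_translation_correction on that object.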
+    for index in range(1, len(attribute_name_list) - 1):
+        if isinstance(instance, dict):
+            instance = instance[attribute_name_list[index]]
+        else:
+            instance = getattr(instance, attribute_name_list[index])
+    # Finally, the last attribute name is used to set the value
+    setattr(instance, attribute_name_list[-1], value)
+
+
+def update_the_method(builder_instance, new_methods, setter_name, attribute_name, attribute_name_list):
+    setter_name_copy = list(setter_name)
+    setter_name_copy.append(attribute_name)
+    method_name = "_".join(setter_name_copy)
+
+    attribute_name_list_copy = list(attribute_name_list)
+    attribute_name_list_copy.append(attribute_name)
+
+    new_method = partial(forwarding_setter, builder_instance=builder_instance,
+                         attribute_name_list=attribute_name_list_copy)
+    new_methods.update({method_name: new_method})
+
+
+def get_all_typed_parameter_descriptors(instance):
+    descriptor_types = {}
+    for descriptor_name, descriptor_object in inspect.getmembers(type(instance)):
+        if inspect.isdatadescriptor(descriptor_object) and isinstance(descriptor_object, TypedParameter):
+            descriptor_types.update({descriptor_name: descriptor_object})
+    return descriptor_types
+
+
+def create_automatic_setters_for_state(attribute_value, builder_instance, attribute_name_list,
+                                       setter_name, exclusions, new_methods):
+    # Find all typed parameter descriptors which are on the instance, i.e. on attribute_value.
+    all_descriptors = get_all_typed_parameter_descriptors(attribute_value)
+
+    # Go through each descriptor and create a setter for it.
+    for name, value in all_descriptors.items():
+        # If the name is in the exclusion list, then we don't want to create a setter for this attribute
+        if name in exclusions:
+            continue
+
+        # There are three scenarios. The attribute can be:
+        # 1. A dictionary which is empty or None -> install a setter
+        # 2. A dictionary containing elements -> for each element apply a recursion
+        # 3. 
A regular attribute -> install the setter + if isinstance(value, DictParameter): + dict_parameter_value = getattr(attribute_value, name) + if dict_parameter_value is None or len(dict_parameter_value) == 0: + update_the_method(builder_instance, new_methods, setter_name, name, attribute_name_list) + else: + for dict_key, dict_value in dict_parameter_value.items(): + setter_name_copy = list(setter_name) + setter_name_copy.append(dict_key) + + # We need to add the dict name to the attribute list and the key we are looking at now + attribute_name_list_copy = list(attribute_name_list) + attribute_name_list_copy.append(name) + attribute_name_list_copy.append(dict_key) + create_automatic_setters_for_state(dict_value, builder_instance, attribute_name_list_copy, + setter_name_copy, exclusions, new_methods) + else: + update_the_method(builder_instance, new_methods, setter_name, name, attribute_name_list) + + +def create_automatic_setters(builder_instance, state_class, exclusions): + # Get the name of the state object + new_methods = {} + for attribute_name, attribute_value in builder_instance.__dict__.items(): + if isinstance(attribute_value, state_class): + attribute_name_list = [attribute_name] + setter_name = ["set"] + create_automatic_setters_for_state(attribute_value, builder_instance, attribute_name_list, + setter_name, exclusions, new_methods) + + # Apply the methods + for method_name, method in new_methods.items(): + builder_instance.__dict__.update({method_name: method}) + + +def automatic_setters(state_class, exclusions=None): + if exclusions is None: + exclusions = [] + + def automatic_setters_decorator(func): + @wraps(func) + def func_wrapper(self, *args, **kwargs): + func(self, *args, **kwargs) + create_automatic_setters(self, state_class, exclusions) + return func_wrapper + return automatic_setters_decorator + + +# ---------------------------------------------------------------------------------------------------------------------- +# Automatic setter for director object +# Note that this is not a decorator, but just a free function for monkey patching. +# ---------------------------------------------------------------------------------------------------------------------- +def forwarding_setter_for_director(value, builder, method_name): + method = getattr(builder, method_name) + method(value) + + +def set_up_setter_forwarding_from_director_to_builder(director, builder_name): + """ + This function creates setter forwarding from a director object to builder objects. + + The method will look for any set_XXX method in the builder and add an equivalent method set_builder_XXX which is + forwarded to set_XXX. 
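+    For example (illustrative): if the builder is stored on the director under the attribute name "_builder", then a
+    builder method set_wavelength_low(value) becomes director.set_builder_wavelength_low(value).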
+ @param director: a director object + @param builder_name: the name of the builder on the director + """ + set_tag = "set" + builder_instance = getattr(director, builder_name) + new_methods = {} + for method in dir(builder_instance): + if method.startswith(set_tag): + method_name_director = set_tag + builder_name + "_" + method[4:] + new_method = partial(forwarding_setter_for_director, builder=builder_instance, + method_name=method) + new_methods.update({method_name_director: new_method}) + director.__dict__.update(new_methods) diff --git a/scripts/SANS/sans/state/calculate_transmission.py b/scripts/SANS/sans/state/calculate_transmission.py new file mode 100644 index 0000000000000000000000000000000000000000..1388a6978831de4536f725b63c0939aa017cfdc7 --- /dev/null +++ b/scripts/SANS/sans/state/calculate_transmission.py @@ -0,0 +1,388 @@ +# pylint: disable=too-few-public-methods + +"""State describing the calculation of the transmission for SANS reduction.""" + +import json +import copy +from sans.state.state_base import (StateBase, rename_descriptor_names, PositiveIntegerParameter, BoolParameter, + PositiveFloatParameter, ClassTypeParameter, FloatParameter, DictParameter, + StringListParameter, PositiveFloatWithNoneParameter) +from sans.common.enums import (RebinType, RangeStepType, FitType, DataType, SANSInstrument) +from sans.common.configurations import Configurations +from sans.state.state_functions import (is_pure_none_or_not_none, validation_message, + is_not_none_and_first_larger_than_second, one_is_none) +from sans.state.automatic_setters import (automatic_setters) +from sans.common.file_information import (get_instrument_paths_for_sans_file) +from sans.common.xml_parsing import get_named_elements_from_ipf_file + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateTransmissionFit(StateBase): + fit_type = ClassTypeParameter(FitType) + polynomial_order = PositiveIntegerParameter() + wavelength_low = PositiveFloatWithNoneParameter() + wavelength_high = PositiveFloatWithNoneParameter() + + def __init__(self): + super(StateTransmissionFit, self).__init__() + self.fit_type = FitType.Linear + self.polynomial_order = 0 + + def validate(self): # noqa + is_invalid = {} + if self.fit_type is not FitType.Polynomial and self.polynomial_order != 0: + entry = validation_message("You can only set a polynomial order of you selected polynomial fitting.", + "Make sure that you select polynomial fitting.", + {"fit_type": self.fit_type, + "polynomial_order": self.polynomial_order}) + is_invalid.update(entry) + + if not is_pure_none_or_not_none([self.wavelength_low, self.wavelength_high]): + entry = validation_message("Inconsistent wavelength setting.", + "Make sure that you have specified both wavelength bounds (or none).", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.wavelength_low, self.wavelength_high]): + entry = validation_message("Incorrect wavelength bounds.", + "Make sure that lower wavelength bound is smaller then upper bound.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high}) + is_invalid.update(entry) + if is_invalid: + raise ValueError("StateTransmissionFit: The provided inputs are illegal. 
" + "Please see: {0}".format(json.dumps(is_invalid))) + + +@rename_descriptor_names +class StateCalculateTransmission(StateBase): + # ----------------------- + # Transmission + # ----------------------- + transmission_radius_on_detector = PositiveFloatParameter() + transmission_roi_files = StringListParameter() + transmission_mask_files = StringListParameter() + + default_transmission_monitor = PositiveIntegerParameter() + transmission_monitor = PositiveIntegerParameter() + + default_incident_monitor = PositiveIntegerParameter() + incident_monitor = PositiveIntegerParameter() + + # ---------------------- + # Prompt peak correction + # ---------------------- + prompt_peak_correction_min = PositiveFloatParameter() + prompt_peak_correction_max = PositiveFloatParameter() + + # ---------------- + # Wavelength rebin + # ---------------- + rebin_type = ClassTypeParameter(RebinType) + wavelength_low = PositiveFloatParameter() + wavelength_high = PositiveFloatParameter() + wavelength_step = PositiveFloatParameter() + wavelength_step_type = ClassTypeParameter(RangeStepType) + + use_full_wavelength_range = BoolParameter() + wavelength_full_range_low = PositiveFloatParameter() + wavelength_full_range_high = PositiveFloatParameter() + + # ----------------------- + # Background correction + # ---------------------- + background_TOF_general_start = FloatParameter() + background_TOF_general_stop = FloatParameter() + background_TOF_monitor_start = DictParameter() + background_TOF_monitor_stop = DictParameter() + background_TOF_roi_start = FloatParameter() + background_TOF_roi_stop = FloatParameter() + + # ----------------------- + # Fit + # ---------------------- + fit = DictParameter() + + def __init__(self): + super(StateCalculateTransmission, self).__init__() + # The keys of this dictionaries are the spectrum number of the monitors (as a string) + self.background_TOF_monitor_start = {} + self.background_TOF_monitor_stop = {} + + self.fit = {DataType.to_string(DataType.Sample): StateTransmissionFit(), + DataType.to_string(DataType.Can): StateTransmissionFit()} + self.use_full_wavelength_range = False + + def validate(self): # noqa + is_invalid = {} + # ----------------- + # Incident monitor + # ----------------- + if self.incident_monitor is None and self.default_incident_monitor is None: + entry = validation_message("No incident monitor was specified.", + "Make sure that incident monitor has been specified.", + {"incident_monitor": self.incident_monitor, + "default_incident_monitor": self.default_incident_monitor}) + is_invalid.update(entry) + + # -------------- + # Transmission, either we need some ROI (ie radius, roi files /mask files) or a transmission monitor + # -------------- + has_no_transmission_monitor_setting = self.transmission_monitor is None and\ + self.default_transmission_monitor is None # noqa + has_no_transmission_roi_setting = self.transmission_radius_on_detector is None and\ + self.transmission_roi_files is None # noqa + if has_no_transmission_monitor_setting and has_no_transmission_roi_setting: + entry = validation_message("No transmission settings were specified.", + "Make sure that transmission settings are specified.", + {"transmission_monitor": self.transmission_monitor, + "default_transmission_monitor": self.default_transmission_monitor, + "transmission_radius_on_detector": self.transmission_radius_on_detector, + "transmission_roi_files": self.transmission_roi_files}) + is_invalid.update(entry) + + # ----------------- + # Prompt peak + # ----------------- + if not 
is_pure_none_or_not_none([self.prompt_peak_correction_min, self.prompt_peak_correction_max]): + entry = validation_message("Inconsistent prompt peak setting.", + "Make sure that you have specified both prompt peak bounds (or none).", + {"prompt_peak_correction_min": self.prompt_peak_correction_min, + "prompt_peak_correction_max": self.prompt_peak_correction_max}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.prompt_peak_correction_min, self.prompt_peak_correction_max]): + entry = validation_message("Incorrect prompt peak bounds.", + "Make sure that lower prompt peak bound is smaller then upper bound.", + {"prompt_peak_correction_min": self.prompt_peak_correction_min, + "prompt_peak_correction_max": self.prompt_peak_correction_max}) + is_invalid.update(entry) + + # ----------------- + # Wavelength rebin + # ----------------- + if one_is_none([self.wavelength_low, self.wavelength_high, self.wavelength_step, self.wavelength_step_type, + self.rebin_type]): + entry = validation_message("A wavelength entry has not been set.", + "Make sure that all entries are set.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high, + "wavelength_step": self.wavelength_step, + "wavelength_step_type": self.wavelength_step_type, + "rebin_type": self.rebin_type}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.wavelength_low, self.wavelength_high]): + entry = validation_message("Incorrect wavelength bounds.", + "Make sure that lower wavelength bound is smaller then upper bound.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high}) + is_invalid.update(entry) + + if self.use_full_wavelength_range: + if self.wavelength_full_range_low is None or self.wavelength_full_range_high is None: + entry = validation_message("Incorrect full wavelength settings.", + "Make sure that both full wavelength entries have been set.", + {"wavelength_full_range_low": self.wavelength_full_range_low, + "wavelength_full_range_high": self.wavelength_full_range_high}) + is_invalid.update(entry) + if is_not_none_and_first_larger_than_second([self.wavelength_full_range_low, + self.wavelength_full_range_high]): + entry = validation_message("Incorrect wavelength bounds.", + "Make sure that lower full wavelength bound is smaller then upper bound.", + {"wavelength_full_range_low": self.wavelength_full_range_low, + "wavelength_full_range_high": self.wavelength_full_range_high}) + is_invalid.update(entry) + + # ---------------------- + # Background correction + # ---------------------- + if not is_pure_none_or_not_none([self.background_TOF_general_start, self.background_TOF_general_stop]): + entry = validation_message("A general background TOF entry has not been set.", + "Make sure that either all general background TOF entries are set or none.", + {"background_TOF_general_start": self.background_TOF_general_start, + "background_TOF_general_stop": self.background_TOF_general_stop}) + is_invalid.update(entry) + if is_not_none_and_first_larger_than_second([self.background_TOF_general_start, + self.background_TOF_general_stop]): + entry = validation_message("Incorrect general background TOF bounds.", + "Make sure that lower general background TOF bound is smaller then upper bound.", + {"background_TOF_general_start": self.background_TOF_general_start, + "background_TOF_general_stop": self.background_TOF_general_stop}) + is_invalid.update(entry) + + if not is_pure_none_or_not_none([self.background_TOF_roi_start, 
self.background_TOF_roi_stop]): + entry = validation_message("A ROI background TOF entry has not been set.", + "Make sure that either all ROI background TOF entries are set or none.", + {"background_TOF_roi_start": self.background_TOF_roi_start, + "background_TOF_roi_stop": self.background_TOF_roi_stop}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.background_TOF_roi_start, + self.background_TOF_roi_stop]): + entry = validation_message("Incorrect ROI background TOF bounds.", + "Make sure that lower ROI background TOF bound is smaller then upper bound.", + {"background_TOF_roi_start": self.background_TOF_roi_start, + "background_TOF_roi_stop": self.background_TOF_roi_stop}) + is_invalid.update(entry) + + if not is_pure_none_or_not_none([self.background_TOF_monitor_start, self.background_TOF_monitor_stop]): + entry = validation_message("A monitor background TOF entry has not been set.", + "Make sure that either all monitor background TOF entries are set or none.", + {"background_TOF_monitor_start": self.background_TOF_monitor_start, + "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) + is_invalid.update(entry) + + if self.background_TOF_monitor_start is not None and self.background_TOF_monitor_stop is not None: + if len(self.background_TOF_monitor_start) != len(self.background_TOF_monitor_stop): + entry = validation_message("The monitor background TOF entries have a length mismatch.", + "Make sure that all monitor background TOF entries have the same length.", + {"background_TOF_monitor_start": self.background_TOF_monitor_start, + "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) + is_invalid.update(entry) + for key_start, value_start in self.background_TOF_monitor_start.items(): + if key_start not in self.background_TOF_monitor_stop: + entry = validation_message("The monitor background TOF had spectrum number mismatch.", + "Make sure that all monitors have entries for start and stop.", + {"background_TOF_monitor_start": self.background_TOF_monitor_start, + "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) + is_invalid.update(entry) + else: + value_stop = self.background_TOF_monitor_stop[key_start] + if value_start > value_stop: + entry = validation_message("Incorrect monitor background TOF bounds.", + "Make sure that lower monitor background TOF bound is" + " smaller then upper bound.", + {"background_TOF_monitor_start": self.background_TOF_monitor_start, + "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) + is_invalid.update(entry) + + # ----- + # Fit + # ----- + self.fit[DataType.to_string(DataType.Sample)].validate() + self.fit[DataType.to_string(DataType.Can)].validate() + + if is_invalid: + raise ValueError("StateCalculateTransmission: The provided inputs are illegal. 
" + "Please see: {0}".format(json.dumps(is_invalid))) + + +class StateCalculateTransmissionLOQ(StateCalculateTransmission): + def __init__(self): + super(StateCalculateTransmissionLOQ, self).__init__() + # Set the LOQ full wavelength range + self.wavelength_full_range_low = Configurations.LOQ.wavelength_full_range_low + self.wavelength_full_range_high = Configurations.LOQ.wavelength_full_range_high + + # Set the LOQ default range for prompt peak correction + self.prompt_peak_correction_min = Configurations.LOQ.prompt_peak_correction_min + self.prompt_peak_correction_max = Configurations.LOQ.prompt_peak_correction_max + + def validate(self): + super(StateCalculateTransmissionLOQ, self).validate() + + +class StateCalculateTransmissionSANS2D(StateCalculateTransmission): + def __init__(self): + super(StateCalculateTransmissionSANS2D, self).__init__() + # Set the LOQ full wavelength range + self.wavelength_full_range_low = Configurations.SANS2D.wavelength_full_range_low + self.wavelength_full_range_high = Configurations.SANS2D.wavelength_full_range_high + + def validate(self): + super(StateCalculateTransmissionSANS2D, self).validate() + + +class StateCalculateTransmissionLARMOR(StateCalculateTransmission): + def __init__(self): + super(StateCalculateTransmissionLARMOR, self).__init__() + # Set the LOQ full wavelength range + self.wavelength_full_range_low = Configurations.LARMOR.wavelength_full_range_low + self.wavelength_full_range_high = Configurations.LARMOR.wavelength_full_range_high + + def validate(self): + super(StateCalculateTransmissionLARMOR, self).validate() + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +def set_default_monitors(calculate_transmission_info, data_info): + """ + The default incident monitor is stored on the IPF. 
+ :param calculate_transmission_info: a StateCalculateTransmission object on which we set the default value + :param data_info: a StateData object + """ + file_name = data_info.sample_scatter + _, ipf_path = get_instrument_paths_for_sans_file(file_name) + incident_tag = "default-incident-monitor-spectrum" + transmission_tag = "default-transmission-monitor-spectrum" + monitors_to_search = [incident_tag, transmission_tag] + found_monitor_spectrum = get_named_elements_from_ipf_file(ipf_path, monitors_to_search, int) + if incident_tag in found_monitor_spectrum: + calculate_transmission_info.default_incident_monitor = found_monitor_spectrum[incident_tag] + if transmission_tag in found_monitor_spectrum: + calculate_transmission_info.default_transmission_monitor = found_monitor_spectrum[transmission_tag] + + +# --------------------------------------- +# State builders +# --------------------------------------- +class StateCalculateTransmissionBuilderLOQ(object): + @automatic_setters(StateCalculateTransmissionLOQ) + def __init__(self, data_info): + super(StateCalculateTransmissionBuilderLOQ, self).__init__() + self._data = data_info + self.state = StateCalculateTransmissionLOQ() + set_default_monitors(self.state, self._data) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +class StateCalculateTransmissionBuilderSANS2D(object): + @automatic_setters(StateCalculateTransmissionSANS2D) + def __init__(self, data_info): + super(StateCalculateTransmissionBuilderSANS2D, self).__init__() + self._data = data_info + self.state = StateCalculateTransmissionSANS2D() + set_default_monitors(self.state, self._data) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +class StateCalculateTransmissionBuilderLARMOR(object): + @automatic_setters(StateCalculateTransmissionLARMOR) + def __init__(self, data_info): + super(StateCalculateTransmissionBuilderLARMOR, self).__init__() + self._data = data_info + self.state = StateCalculateTransmissionLARMOR() + set_default_monitors(self.state, self._data) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +# ------------------------------------------ +# Factory method for StateCalculateTransmissionBuilder +# ------------------------------------------ +def get_calculate_transmission_builder(data_info): + instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR: + return StateCalculateTransmissionBuilderLARMOR(data_info) + elif instrument is SANSInstrument.SANS2D: + return StateCalculateTransmissionBuilderSANS2D(data_info) + elif instrument is SANSInstrument.LOQ: + return StateCalculateTransmissionBuilderLOQ(data_info) + else: + raise NotImplementedError("StateCalculateTransmissionBuilder: Could not find any valid transmission " + "builder for the specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/convert_to_q.py b/scripts/SANS/sans/state/convert_to_q.py new file mode 100644 index 0000000000000000000000000000000000000000..a41a566f6624f7e9b8b7fb0caee746b5c6986587 --- /dev/null +++ b/scripts/SANS/sans/state/convert_to_q.py @@ -0,0 +1,171 @@ +# pylint: disable=too-few-public-methods + +"""State describing the conversion to momentum transfer""" + +import json +import copy +from sans.state.state_base import (StateBase, rename_descriptor_names, BoolParameter, PositiveFloatParameter, + ClassTypeParameter, StringParameter) +from sans.common.enums import (ReductionDimensionality, RangeStepType, SANSInstrument) +from sans.state.state_functions 
import (is_pure_none_or_not_none, is_not_none_and_first_larger_than_second, + validation_message) +from sans.state.automatic_setters import (automatic_setters) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateConvertToQ(StateBase): + reduction_dimensionality = ClassTypeParameter(ReductionDimensionality) + use_gravity = BoolParameter() + gravity_extra_length = PositiveFloatParameter() + radius_cutoff = PositiveFloatParameter() + wavelength_cutoff = PositiveFloatParameter() + + # 1D settings + # The complex binning instructions require a second step and a mid point, which produces: + # start -> step -> mid -> step2 -> stop + # The simple form is: + # start -> step -> stop + q_min = PositiveFloatParameter() + q_max = PositiveFloatParameter() + q_step = PositiveFloatParameter() + q_step_type = ClassTypeParameter(RangeStepType) + q_step2 = PositiveFloatParameter() + q_step_type2 = ClassTypeParameter(RangeStepType) + q_mid = PositiveFloatParameter() + + # 2D settings + q_xy_max = PositiveFloatParameter() + q_xy_step = PositiveFloatParameter() + q_xy_step_type = ClassTypeParameter(RangeStepType) + + # ----------------------- + # Q Resolution specific + # --------------------- + use_q_resolution = BoolParameter() + q_resolution_collimation_length = PositiveFloatParameter() + q_resolution_delta_r = PositiveFloatParameter() + moderator_file = StringParameter() + + # Circular aperture settings + q_resolution_a1 = PositiveFloatParameter() + q_resolution_a2 = PositiveFloatParameter() + + # Rectangular aperture settings + q_resolution_h1 = PositiveFloatParameter() + q_resolution_h2 = PositiveFloatParameter() + q_resolution_w1 = PositiveFloatParameter() + q_resolution_w2 = PositiveFloatParameter() + + def __init__(self): + super(StateConvertToQ, self).__init__() + self.reduction_dimensionality = ReductionDimensionality.OneDim + self.use_gravity = False + self.gravity_extra_length = 0.0 + self.use_q_resolution = False + self.radius_cutoff = 0.0 + self.wavelength_cutoff = 0.0 + + def validate(self): + is_invalid = {} + + # 1D Q settings + if not is_pure_none_or_not_none([self.q_min, self.q_max]): + entry = validation_message("The q boundaries for the 1D reduction are inconsistent.", + "Make sure that both q boundaries are set (or none).", + {"q_min": self.q_min, + "q_max": self.q_max}) + is_invalid.update(entry) + if is_not_none_and_first_larger_than_second([self.q_min, self.q_max]): + entry = validation_message("Incorrect q bounds for 1D reduction.", + "Make sure that the lower q bound is smaller than the upper q bound.", + {"q_min": self.q_min, + "q_max": self.q_max}) + is_invalid.update(entry) + + if self.reduction_dimensionality is ReductionDimensionality.OneDim: + if self.q_min is None or self.q_max is None: + entry = validation_message("Q bounds not set for 1D reduction.", + "Make sure to set the q boundaries when using a 1D reduction.", + {"q_min": self.q_min, + "q_max": self.q_max}) + is_invalid.update(entry) + + if self.reduction_dimensionality is ReductionDimensionality.TwoDim: + if self.q_xy_max is None or self.q_xy_step is None: + entry = validation_message("Q bounds not set for 2D reduction.", + "Make sure that the q_max value bound and the step for the 2D reduction.", + {"q_xy_max": self.q_xy_max, + "q_xy_step": self.q_xy_step}) + 
is_invalid.update(entry) + + # Q Resolution settings + if self.use_q_resolution: + if not is_pure_none_or_not_none([self.q_resolution_a1, self.q_resolution_a2]): + entry = validation_message("Inconsistent circular geometry.", + "Make sure that both diameters for the circular apertures are set.", + {"q_resolution_a1": self.q_resolution_a1, + "q_resolution_a2": self.q_resolution_a2}) + is_invalid.update(entry) + if not is_pure_none_or_not_none([self.q_resolution_h1, self.q_resolution_h2, self.q_resolution_w1, + self.q_resolution_w2]): + entry = validation_message("Inconsistent rectangular geometry.", + "Make sure that both heights and widths for the rectangular aperture are set.", + {"q_resolution_h1": self.q_resolution_h1, + "q_resolution_h2": self.q_resolution_h2, + "q_resolution_w1": self.q_resolution_w1, + "q_resolution_w2": self.q_resolution_w2}) + is_invalid.update(entry) + + if all(element is None for element in [self.q_resolution_a1, self.q_resolution_a2, self.q_resolution_w1, + self.q_resolution_w2, self.q_resolution_h1, self.q_resolution_h2]): + entry = validation_message("Aperture is undefined.", + "Make sure that you set the geometry for a circular or a " + "rectangular aperture.", + {"q_resolution_a1": self.q_resolution_a1, + "q_resolution_a2": self.q_resolution_a2, + "q_resolution_h1": self.q_resolution_h1, + "q_resolution_h2": self.q_resolution_h2, + "q_resolution_w1": self.q_resolution_w1, + "q_resolution_w2": self.q_resolution_w2}) + is_invalid.update(entry) + if self.moderator_file is None: + entry = validation_message("Missing moderator file.", + "Make sure to specify a moderator file when using q resolution.", + {"moderator_file": self.moderator_file}) + is_invalid.update(entry) + is_invalid.update({"moderator_file": "A moderator file is required for the q resolution calculation."}) + + if is_invalid: + raise ValueError("StateConvertToQ: The provided inputs are illegal. " + "Please see: {0}".format(json.dumps(is_invalid))) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateConvertToQBuilder(object): + @automatic_setters(StateConvertToQ) + def __init__(self): + super(StateConvertToQBuilder, self).__init__() + self.state = StateConvertToQ() + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +# ------------------------------------------ +# Factory method for StateConvertToQBuilder +# ------------------------------------------ +def get_convert_to_q_builder(data_info): + # The data state has most of the information that we require to define the conversion to momentum transfer. For + # the factory method, only the instrument is of relevance. 
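Before the factory body continues, a rough usage sketch of the builder defined above. This is not part of the patch; it assumes the generated setters follow the `set_<descriptor>` naming and that `RangeStepType` provides a `Lin` member, and the q values are made up (inverse Angstrom).

```python
from sans.common.enums import ReductionDimensionality, RangeStepType
from sans.state.convert_to_q import StateConvertToQBuilder

builder = StateConvertToQBuilder()
builder.set_reduction_dimensionality(ReductionDimensionality.OneDim)
builder.set_q_min(0.001)   # illustrative momentum transfer bounds
builder.set_q_max(1.0)
builder.set_q_step(0.01)
builder.set_q_step_type(RangeStepType.Lin)  # assumed enum member
state = builder.build()  # build() runs StateConvertToQ.validate() and returns a copy
```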
+ instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D: + return StateConvertToQBuilder() + else: + raise NotImplementedError("StateConvertToQBuilder: Could not find any valid save builder for the " + "specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/data.py b/scripts/SANS/sans/state/data.py new file mode 100644 index 0000000000000000000000000000000000000000..8042264a3168d9acf5b5e5aa81cf7b2a4356795f --- /dev/null +++ b/scripts/SANS/sans/state/data.py @@ -0,0 +1,140 @@ +# pylint: disable=too-few-public-methods + +"""State about the actual data which is to be reduced.""" + +import json +import copy + +from sans.state.state_base import (StateBase, StringParameter, PositiveIntegerParameter, + ClassTypeParameter, rename_descriptor_names) +from sans.common.enums import (SANSInstrument, SANSFacility) +from sans.common.constants import ALL_PERIODS +from sans.state.state_functions import (is_pure_none_or_not_none, validation_message) +from sans.common.file_information import SANSFileInformationFactory +from sans.state.automatic_setters import automatic_setters + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateData(StateBase): + ALL_PERIODS = ALL_PERIODS + sample_scatter = StringParameter() + sample_scatter_period = PositiveIntegerParameter() + sample_transmission = StringParameter() + sample_transmission_period = PositiveIntegerParameter() + sample_direct = StringParameter() + sample_direct_period = PositiveIntegerParameter() + + can_scatter = StringParameter() + can_scatter_period = PositiveIntegerParameter() + can_transmission = StringParameter() + can_transmission_period = PositiveIntegerParameter() + can_direct = StringParameter() + can_direct_period = PositiveIntegerParameter() + + calibration = StringParameter() + + sample_scatter_run_number = PositiveIntegerParameter() + instrument = ClassTypeParameter(SANSInstrument) + + def __init__(self): + super(StateData, self).__init__() + + # Setup default values for periods + self.sample_scatter_period = StateData.ALL_PERIODS + self.sample_transmission_period = StateData.ALL_PERIODS + self.sample_direct_period = StateData.ALL_PERIODS + + self.can_scatter_period = StateData.ALL_PERIODS + self.can_transmission_period = StateData.ALL_PERIODS + self.can_direct_period = StateData.ALL_PERIODS + + # This should be reset by the builder. Setting this to NoInstrument ensure that we will trip early on, + # in case this is not set, for example by not using the builders. 
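A short sketch of how this plays out in practice, using the `get_data_builder` factory defined further down in this file; the run file name is hypothetical and `set_sample_scatter` is assumed to be one of the generated setters.

```python
from sans.common.enums import SANSFacility
from sans.state.data import get_data_builder

data_builder = get_data_builder(SANSFacility.ISIS)
data_builder.set_sample_scatter("SANS2D00012345")  # hypothetical run, must be resolvable by Mantid
data_state = data_builder.build()  # validates, then fills instrument and run number from the file
```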
+ self.instrument = SANSInstrument.NoInstrument + + def validate(self): + is_invalid = dict() + + # A sample scatter must be specified + if self.sample_scatter is None: + entry = validation_message("Sample scatter was not specified.", + "Make sure that the sample scatter file is specified.", + {"sample_scatter": self.sample_scatter}) + is_invalid.update(entry) + + # If the sample transmission/direct was specified, then a sample direct/transmission is required + if not is_pure_none_or_not_none([self.sample_transmission, self.sample_direct]): + entry = validation_message("If the sample transmission is specified then, the direct run needs to be " + "specified too.", + "Make sure that the transmission and direct runs are both specified (or none).", + {"sample_transmission": self.sample_transmission, + "sample_direct": self.sample_direct}) + is_invalid.update(entry) + + # If the can transmission/direct was specified, then this requires the can scatter + if (self.can_direct or self.can_transmission) and (not self.can_scatter): + entry = validation_message("If the can transmission is specified then the can scatter run needs to be " + "specified too.", + "Make sure that the can scatter file is set.", + {"can_scatter": self.can_scatter, + "can_transmission": self.can_transmission, + "can_direct": self.can_direct}) + is_invalid.update(entry) + + # If a can transmission/direct was specified, then the other can entries need to be specified as well. + if self.can_scatter and not is_pure_none_or_not_none([self.can_transmission, self.can_direct]): + entry = validation_message("Inconsistent can transmission setting.", + "Make sure that the can transmission and can direct runs are set (or none of" + " them).", + {"can_transmission": self.can_transmission, + "can_direct": self.can_direct}) + is_invalid.update(entry) + + if is_invalid: + raise ValueError("StateData: The provided inputs are illegal. " + "Please see: {0}".format(json.dumps(is_invalid))) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +def set_information_from_file(data_info): + file_name = data_info.sample_scatter + file_information_factory = SANSFileInformationFactory() + file_information = file_information_factory.create_sans_file_information(file_name) + instrument = file_information.get_instrument() + run_number = file_information.get_run_number() + data_info.instrument = instrument + data_info.sample_scatter_run_number = run_number + + +class StateDataBuilder(object): + @automatic_setters(StateData) + def __init__(self): + super(StateDataBuilder, self).__init__() + self.state = StateData() + + def build(self): + # Make sure that the product is in a valid state, ie not incomplete + self.state.validate() + + # There are some elements which need to be read from the file. This is currently: + # 1. instrument + # 2. 
sample_scatter_run_number + set_information_from_file(self.state) + + return copy.copy(self.state) + + +# ------------------------------------------ +# Factory method for StateDataBuilder +# ------------------------------------------ +def get_data_builder(facility): + if facility is SANSFacility.ISIS: + return StateDataBuilder() + else: + raise NotImplementedError("StateDataBuilder: The selected facility {0} does not seem" + " to exist".format(str(facility))) diff --git a/scripts/SANS/sans/state/mask.py b/scripts/SANS/sans/state/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..8c96299d50518ff2bbfa470632ed288a793eddcc --- /dev/null +++ b/scripts/SANS/sans/state/mask.py @@ -0,0 +1,273 @@ +# pylint: disable=too-few-public-methods + +"""State describing the masking behaviour of the SANS reduction.""" + +import json +import copy +from sans.state.state_base import (StateBase, BoolParameter, StringListParameter, StringParameter, + PositiveFloatParameter, FloatParameter, FloatListParameter, + DictParameter, PositiveIntegerListParameter, rename_descriptor_names) +from sans.state.state_functions import (is_pure_none_or_not_none, validation_message, set_detector_names) +from sans.state.automatic_setters import (automatic_setters) +from sans.common.file_information import find_full_file_path +from sans.common.enums import (SANSInstrument, DetectorType) +from sans.common.file_information import (get_instrument_paths_for_sans_file) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +def range_check(start, stop, invalid_dict, start_name, stop_name, general_name=None): + """ + Checks a start container against a stop container + + :param start: The start container + :param stop: the stop container + :param invalid_dict: The invalid dict to which we write our error messages + :param start_name: The name of the start container + :param stop_name: The name of the stop container + :param general_name: The general name of this container family + :return: A (potentially) updated invalid_dict + """ + if not is_pure_none_or_not_none([start, stop]): + entry = validation_message("A range element has not been set.", + "Make sure that all entries are set.", + {start_name: start, + stop_name: stop}) + invalid_dict.update(entry) + + if start is not None and stop is not None: + # Start and stop need to have the same length + if len(start) != len(stop): + entry = validation_message("Start and stop ranges have different lengths.", + "Make sure that all entries for {0} he same length.".format(general_name), + {start_name: start, + stop_name: stop}) + invalid_dict.update(entry) + # Start values need to be smaller than the stop values + for a, b in zip(start, stop): + if a > b: + entry = validation_message("Incorrect start-stop bounds.", + "Make sure the lower bound is smaller than the upper bound for {0}." 
+ "".format(general_name), + {start_name: start, + stop_name: stop}) + invalid_dict.update(entry) + return invalid_dict + + +# ------------------------------------------------ +# StateData +# ------------------------------------------------ +@rename_descriptor_names +class StateMaskDetector(StateBase): + # Vertical strip masks + single_vertical_strip_mask = PositiveIntegerListParameter() + range_vertical_strip_start = PositiveIntegerListParameter() + range_vertical_strip_stop = PositiveIntegerListParameter() + + # Horizontal strip masks + single_horizontal_strip_mask = PositiveIntegerListParameter() + range_horizontal_strip_start = PositiveIntegerListParameter() + range_horizontal_strip_stop = PositiveIntegerListParameter() + + # Spectrum Block + block_horizontal_start = PositiveIntegerListParameter() + block_horizontal_stop = PositiveIntegerListParameter() + block_vertical_start = PositiveIntegerListParameter() + block_vertical_stop = PositiveIntegerListParameter() + + # Spectrum block cross + block_cross_horizontal = PositiveIntegerListParameter() + block_cross_vertical = PositiveIntegerListParameter() + + # Time/Bin mask + bin_mask_start = FloatListParameter() + bin_mask_stop = FloatListParameter() + + # Name of the detector + detector_name = StringParameter() + detector_name_short = StringParameter() + + def __init__(self): + super(StateMaskDetector, self).__init__() + + def validate(self): + is_invalid = {} + # -------------------- + # Vertical strip mask + # -------------------- + range_check(self.range_vertical_strip_start, self.range_vertical_strip_stop, + is_invalid, "range_vertical_strip_start", "range_vertical_strip_stop", "range_vertical_strip") + + # -------------------- + # Horizontal strip mask + # -------------------- + range_check(self.range_horizontal_strip_start, self.range_horizontal_strip_stop, + is_invalid, "range_horizontal_strip_start", "range_horizontal_strip_stop", "range_horizontal_strip") + + # -------------------- + # Block mask + # -------------------- + range_check(self.block_horizontal_start, self.block_horizontal_stop, + is_invalid, "block_horizontal_start", "block_horizontal_stop", "block_horizontal") + range_check(self.block_vertical_start, self.block_vertical_stop, + is_invalid, "block_vertical_start", "block_vertical_stop", "block_vertical") + + # -------------------- + # Time/Bin mask + # -------------------- + range_check(self.bin_mask_start, self.bin_mask_stop, + is_invalid, "bin_mask_start", "bin_mask_stop", "bin_mask") + + if not self.detector_name: + entry = validation_message("Missing detector name.", + "Make sure that the detector names are set.", + {"detector_name": self.detector_name}) + is_invalid.update(entry) + if not self.detector_name_short: + entry = validation_message("Missing short detector name.", + "Make sure that the short detector names are set.", + {"detector_name_short": self.detector_name_short}) + is_invalid.update(entry) + if is_invalid: + raise ValueError("StateMoveDetectorISIS: The provided inputs are illegal. 
" + "Please see: {0}".format(json.dumps(is_invalid))) + + +@rename_descriptor_names +class StateMask(StateBase): + # Radius Mask + radius_min = FloatParameter() + radius_max = FloatParameter() + + # Bin mask + bin_mask_general_start = FloatListParameter() + bin_mask_general_stop = FloatListParameter() + + # Mask files + mask_files = StringListParameter() + + # Angle masking + phi_min = FloatParameter() + phi_max = FloatParameter() + use_mask_phi_mirror = BoolParameter() + + # Beam stop + beam_stop_arm_width = PositiveFloatParameter() + beam_stop_arm_angle = FloatParameter() + beam_stop_arm_pos1 = FloatParameter() + beam_stop_arm_pos2 = FloatParameter() + + # Clear commands + clear = BoolParameter() + clear_time = BoolParameter() + + # Single Spectra + single_spectra = PositiveIntegerListParameter() + + # Spectrum Range + spectrum_range_start = PositiveIntegerListParameter() + spectrum_range_stop = PositiveIntegerListParameter() + + # The detector dependent masks + detectors = DictParameter() + + # The idf path of the instrument + idf_path = StringParameter() + + def __init__(self): + super(StateMask, self).__init__() + # Setup the detectors + self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMaskDetector(), + DetectorType.to_string(DetectorType.HAB): StateMaskDetector()} + + # IDF Path + self.idf_path = "" + + def validate(self): + is_invalid = dict() + + # -------------------- + # Radius Mask + # -------------------- + # Radius mask rule: the min radius must be less or equal to the max radius + if self.radius_max is not None and self.radius_min is not None and\ + self.radius_max != -1 and self.radius_min != -1: # noqa + if self.radius_min > 0 and self.radius_max > 0 and (self.radius_min > self.radius_max): + entry = validation_message("Incorrect radius bounds.", + "Makes sure that the lower radius bound is smaller than the" + " upper radius bound.", + {"radius_min": self.radius_min, + "radius_max": self.radius_max}) + is_invalid.update(entry) + + # -------------------- + # General bin mask + # -------------------- + range_check(self.bin_mask_general_start, self.bin_mask_general_stop, + is_invalid, "bin_mask_general_start", "bin_mask_general_stop", "bin_mask_general") + + # -------------------- + # Mask files + # -------------------- + if self.mask_files: + for mask_file in self.mask_files: + if not find_full_file_path(mask_file): + entry = validation_message("Mask file not found.", + "Makes sure that the mask file is in your path", + {"mask_file": self.mask_file}) + is_invalid.update(entry) + + # -------------------- + # Spectrum Range + # -------------------- + range_check(self.spectrum_range_start, self.spectrum_range_stop, + is_invalid, "spectrum_range_start", "spectrum_range_stop", "spectrum_range") + + # -------------------- + # Detectors + # -------------------- + for _, value in self.detectors.items(): + value.validate() + + if is_invalid: + raise ValueError("StateMask: The provided inputs are illegal. 
" + "Please see: {0}".format(json.dumps(is_invalid))) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +def setup_idf_and_ipf_content(move_info, data_info): + # Get the IDF and IPF path since they contain most of the import information + file_name = data_info.sample_scatter + idf_path, ipf_path = get_instrument_paths_for_sans_file(file_name) + # Set the detector names + set_detector_names(move_info, ipf_path) + # Set the idf path + move_info.idf_path = idf_path + + +class StateMaskBuilder(object): + @automatic_setters(StateMask) + def __init__(self, data_info): + super(StateMaskBuilder, self).__init__() + self._data = data_info + self.state = StateMask() + setup_idf_and_ipf_content(self.state, data_info) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +def get_mask_builder(data_info): + # The data state has most of the information that we require to define the move. For the factory method, only + # the instrument is of relevance. + instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D: + return StateMaskBuilder(data_info) + else: + raise NotImplementedError("StateMaskBuilder: Could not find any valid mask builder for the " + "specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/move.py b/scripts/SANS/sans/state/move.py new file mode 100644 index 0000000000000000000000000000000000000000..9fff81134be4ec03c07db46aace66d8e4baab425 --- /dev/null +++ b/scripts/SANS/sans/state/move.py @@ -0,0 +1,266 @@ +# pylint: disable=too-few-public-methods, too-many-instance-attributes + +"""State for moving workspaces.""" + +import json +import copy + +from sans.state.state_base import (StateBase, FloatParameter, DictParameter, ClassTypeParameter, + StringParameter, rename_descriptor_names) +from sans.common.enums import (Coordinates, CanonicalCoordinates, SANSInstrument, DetectorType) +from sans.common.file_information import (get_instrument_paths_for_sans_file) +from sans.state.automatic_setters import automatic_setters +from sans.state.state_functions import (validation_message, set_detector_names, set_monitor_names) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateMoveDetector(StateBase): + x_translation_correction = FloatParameter() + y_translation_correction = FloatParameter() + z_translation_correction = FloatParameter() + + rotation_correction = FloatParameter() + side_correction = FloatParameter() + radius_correction = FloatParameter() + + x_tilt_correction = FloatParameter() + y_tilt_correction = FloatParameter() + z_tilt_correction = FloatParameter() + + sample_centre_pos1 = FloatParameter() + sample_centre_pos2 = FloatParameter() + + # Name of the detector + detector_name = StringParameter() + detector_name_short = StringParameter() + + def __init__(self): + super(StateMoveDetector, self).__init__() + # Translation correction + self.x_translation_correction = 0.0 + self.y_translation_correction = 0.0 + self.z_translation_correction = 0.0 + + self.rotation_correction = 0.0 + 
self.side_correction = 0.0 + self.radius_correction = 0.0 + + self.x_tilt_correction = 0.0 + self.y_tilt_correction = 0.0 + self.z_tilt_correction = 0.0 + + # Sample centre Pos 1 + Pos 2 + self.sample_centre_pos1 = 0.0 + self.sample_centre_pos2 = 0.0 + + def validate(self): + is_invalid = {} + if not self.detector_name: + entry = validation_message("Missing detector name", + "Make sure that a detector name was specified.", + {"detector_name": self.detector_name}) + is_invalid.update(entry) + if not self.detector_name_short: + entry = validation_message("Missing short detector name", + "Make sure that a short detector name was specified.", + {"detector_name_short": self.detector_name_short}) + is_invalid.update(entry) + if is_invalid: + raise ValueError("StateMoveDetectorISIS: The provided inputs are illegal. " + "Please see: {0}".format(json.dumps(is_invalid))) + + +@rename_descriptor_names +class StateMove(StateBase): + sample_offset = FloatParameter() + sample_offset_direction = ClassTypeParameter(Coordinates) + detectors = DictParameter() + + def __init__(self): + super(StateMove, self).__init__() + + # Setup the sample offset + self.sample_offset = 0.0 + + # The sample offset direction is Z for the ISIS instruments + self.sample_offset_direction = CanonicalCoordinates.Z + + # Setup the detectors + self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMoveDetector(), + DetectorType.to_string(DetectorType.HAB): StateMoveDetector()} + + def validate(self): + # No validation of the descriptors on this level, let potential exceptions from detectors "bubble" up + for key in self.detectors: + self.detectors[key].validate() + + +@rename_descriptor_names +class StateMoveLOQ(StateMove): + monitor_names = DictParameter() + center_position = FloatParameter() + + def __init__(self): + super(StateMoveLOQ, self).__init__() + # Set the center_position in meter + self.center_position = 317.5 / 1000. + + # Set the monitor names + self.monitor_names = {} + + def validate(self): + # No validation of the descriptors on this level, let potential exceptions from detectors "bubble" up + super(StateMoveLOQ, self).validate() + + +@rename_descriptor_names +class StateMoveSANS2D(StateMove): + monitor_names = DictParameter() + + hab_detector_radius = FloatParameter() + hab_detector_default_sd_m = FloatParameter() + hab_detector_default_x_m = FloatParameter() + + lab_detector_default_sd_m = FloatParameter() + + hab_detector_x = FloatParameter() + hab_detector_z = FloatParameter() + + hab_detector_rotation = FloatParameter() + + lab_detector_x = FloatParameter() + lab_detector_z = FloatParameter() + + monitor_4_offset = FloatParameter() + + def __init__(self): + super(StateMoveSANS2D, self).__init__() + # Set the descriptors which corresponds to information which we gain through the IPF + self.hab_detector_radius = 306.0 / 1000. + self.hab_detector_default_sd_m = 4.0 + self.hab_detector_default_x_m = 1.1 + self.lab_detector_default_sd_m = 4.0 + + # The actual values are found on the workspace and should be used from there. This is only a fall back. 
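For orientation, the per-bank corrections of the move states above are reached through the `detectors` dictionary; a small sketch, not part of the patch, with a made-up correction value that is assumed to be in metres:

```python
from sans.common.enums import DetectorType
from sans.state.move import StateMoveSANS2D

move_state = StateMoveSANS2D()
lab_detector = move_state.detectors[DetectorType.to_string(DetectorType.LAB)]
lab_detector.x_translation_correction = 0.003  # illustrative, assumed to be in metres
```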
+ self.hab_detector_x = 0.0 + self.hab_detector_z = 0.0 + self.hab_detector_rotation = 0.0 + self.lab_detector_x = 0.0 + self.lab_detector_z = 0.0 + + # Set the monitor names + self.monitor_names = {} + + self.monitor_4_offset = 0.0 + + def validate(self): + super(StateMoveSANS2D, self).validate() + + +@rename_descriptor_names +class StateMoveLARMOR(StateMove): + monitor_names = DictParameter() + bench_rotation = FloatParameter() + + def __init__(self): + super(StateMoveLARMOR, self).__init__() + + # Set a default for the bench rotation + self.bench_rotation = 0.0 + + # Set the monitor names + self.monitor_names = {} + + def validate(self): + super(StateMoveLARMOR, self).validate() + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +def setup_idf_and_ipf_content(move_info, data_info): + # Get the IDF and IPF path since they contain most of the import information + file_name = data_info.sample_scatter + idf_path, ipf_path = get_instrument_paths_for_sans_file(file_name) + # Set the detector names + set_detector_names(move_info, ipf_path) + # Set the monitor names + set_monitor_names(move_info, idf_path) + + +class StateMoveLOQBuilder(object): + @automatic_setters(StateMoveLOQ, exclusions=["detector_name", "detector_name_short", "monitor_names"]) + def __init__(self, data_info): + super(StateMoveLOQBuilder, self).__init__() + self.state = StateMoveLOQ() + setup_idf_and_ipf_content(self.state, data_info) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + def convert_pos1(self, value): + return value / 1000. + + def convert_pos2(self, value): + return value / 1000. + + +class StateMoveSANS2DBuilder(object): + @automatic_setters(StateMoveSANS2D, exclusions=["detector_name", "detector_name_short", "monitor_names"]) + def __init__(self, data_info): + super(StateMoveSANS2DBuilder, self).__init__() + self.state = StateMoveSANS2D() + setup_idf_and_ipf_content(self.state, data_info) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + def convert_pos1(self, value): + return value / 1000. + + def convert_pos2(self, value): + return value / 1000. + + +class StateMoveLARMORBuilder(object): + @automatic_setters(StateMoveLARMOR, exclusions=["detector_name", "detector_name_short", "monitor_names"]) + def __init__(self, data_info): + super(StateMoveLARMORBuilder, self).__init__() + self.state = StateMoveLARMOR() + setup_idf_and_ipf_content(self.state, data_info) + self.conversion_value = 1000. + self._set_conversion_value(data_info) + + def _set_conversion_value(self, data_info): + run_number = data_info.sample_scatter_run_number + self.conversion_value = 1000. if run_number >= 2217 else 1. + + def build(self): + self.state.validate() + return copy.copy(self.state) + + def convert_pos1(self, value): + return value / self.conversion_value + + def convert_pos2(self, value): + return value / 1000. + + +def get_move_builder(data_info): + # The data state has most of the information that we require to define the move. For the factory method, only + # the instrument is of relevance. 
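A rough usage sketch of the `get_move_builder` factory; `data_state` is assumed to be a StateData object built as in the earlier data.py sketch, and the beam-centre inputs are made-up millimetre values.

```python
from sans.state.move import get_move_builder

move_builder = get_move_builder(data_state)
# convert_pos1/convert_pos2 translate user input (mm) into the metres stored on the state;
# for LARMOR the conversion of pos1 additionally depends on the run number.
centre_1 = move_builder.convert_pos1(100.0)
centre_2 = move_builder.convert_pos2(-60.0)
move_state = move_builder.build()
```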
+ instrument = data_info.instrument + if instrument is SANSInstrument.LOQ: + return StateMoveLOQBuilder(data_info) + elif instrument is SANSInstrument.SANS2D: + return StateMoveSANS2DBuilder(data_info) + elif instrument is SANSInstrument.LARMOR: + return StateMoveLARMORBuilder(data_info) + else: + raise NotImplementedError("StateMoveBuilder: Could not find any valid move builder for the " + "specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/normalize_to_monitor.py b/scripts/SANS/sans/state/normalize_to_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..3b037e399d237aacfdfca7a15b6864c466949906 --- /dev/null +++ b/scripts/SANS/sans/state/normalize_to_monitor.py @@ -0,0 +1,206 @@ +# pylint: disable=too-few-public-methods + +"""State describing the normalization to the incident monitor for SANS reduction.""" + +import json +import copy +from sans.state.state_base import (StateBase, rename_descriptor_names, PositiveIntegerParameter, + PositiveFloatParameter, FloatParameter, ClassTypeParameter, DictParameter, + PositiveFloatWithNoneParameter) +from sans.state.automatic_setters import (automatic_setters) +from sans.common.enums import (RebinType, RangeStepType, SANSInstrument) +from sans.state.state_functions import (is_pure_none_or_not_none, is_not_none_and_first_larger_than_second, + one_is_none, validation_message) +from sans.common.xml_parsing import get_named_elements_from_ipf_file +from sans.common.file_information import (get_instrument_paths_for_sans_file) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateNormalizeToMonitor(StateBase): + prompt_peak_correction_min = PositiveFloatWithNoneParameter() + prompt_peak_correction_max = PositiveFloatWithNoneParameter() + + rebin_type = ClassTypeParameter(RebinType) + wavelength_low = PositiveFloatParameter() + wavelength_high = PositiveFloatParameter() + wavelength_step = PositiveFloatParameter() + wavelength_step_type = ClassTypeParameter(RangeStepType) + + background_TOF_general_start = FloatParameter() + background_TOF_general_stop = FloatParameter() + background_TOF_monitor_start = DictParameter() + background_TOF_monitor_stop = DictParameter() + + incident_monitor = PositiveIntegerParameter() + + def __init__(self): + super(StateNormalizeToMonitor, self).__init__() + self.background_TOF_monitor_start = {} + self.background_TOF_monitor_stop = {} + + def validate(self): + is_invalid = {} + # ----------------- + # incident Monitor + # ----------------- + if self.incident_monitor is None: + is_invalid.update({"incident_monitor": "An incident monitor must be specified."}) + + # ----------------- + # Prompt peak + # ----------------- + if not is_pure_none_or_not_none([self.prompt_peak_correction_min, self.prompt_peak_correction_max]): + entry = validation_message("A prompt peak correction entry has not been set.", + "Make sure that either all prompt peak entries have been set or none.", + {"prompt_peak_correction_min": self.prompt_peak_correction_min, + "prompt_peak_correction_max": self.prompt_peak_correction_max}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.prompt_peak_correction_min, self.prompt_peak_correction_max]): + entry = validation_message("Incorrect prompt peak correction bounds.", + "Make 
sure that lower prompt peak time bound is smaller then upper bound.", + {"prompt_peak_correction_min": self.prompt_peak_correction_min, + "prompt_peak_correction_max": self.prompt_peak_correction_max}) + is_invalid.update(entry) + + # ----------------- + # Wavelength rebin + # ----------------- + if one_is_none([self.wavelength_low, self.wavelength_high, self.wavelength_step, self.wavelength_step_type]): + entry = validation_message("A wavelength entry has not been set.", + "Make sure that all entries are set.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high, + "wavelength_step": self.wavelength_step, + "wavelength_step_type": self.wavelength_step_type}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.wavelength_low, self.wavelength_high]): + entry = validation_message("Incorrect wavelength bounds.", + "Make sure that lower wavelength bound is smaller then upper bound.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high}) + is_invalid.update(entry) + + # ---------------------- + # Background correction + # ---------------------- + if not is_pure_none_or_not_none([self.background_TOF_general_start, self.background_TOF_general_stop]): + entry = validation_message("A general background TOF entry has not been set.", + "Make sure that either all general background TOF entries are set or none.", + {"background_TOF_general_start": self.background_TOF_general_start, + "background_TOF_general_stop": self.background_TOF_general_stop}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.background_TOF_general_start, + self.background_TOF_general_stop]): + entry = validation_message("Incorrect general background TOF bounds.", + "Make sure that lower general background TOF bound is smaller then upper bound.", + {"background_TOF_general_start": self.background_TOF_general_start, + "background_TOF_general_stop": self.background_TOF_general_stop}) + is_invalid.update(entry) + + if not is_pure_none_or_not_none([self.background_TOF_monitor_start, self.background_TOF_monitor_stop]): + entry = validation_message("A monitor background TOF entry has not been set.", + "Make sure that either all monitor background TOF entries are set or none.", + {"background_TOF_monitor_start": self.background_TOF_monitor_start, + "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) + is_invalid.update(entry) + + if self.background_TOF_monitor_start is not None and self.background_TOF_monitor_stop is not None: + if len(self.background_TOF_monitor_start) != len(self.background_TOF_monitor_stop): + entry = validation_message("The monitor background TOF entries have a length mismatch.", + "Make sure that all monitor background TOF entries have the same length.", + {"background_TOF_monitor_start": self.background_TOF_monitor_start, + "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) + is_invalid.update(entry) + for key_start, value_start in self.background_TOF_monitor_start.items(): + if key_start not in self.background_TOF_monitor_stop: + entry = validation_message("The monitor background TOF had spectrum number mismatch.", + "Make sure that all monitors have entries for start and stop.", + {"background_TOF_monitor_start": self.background_TOF_monitor_start, + "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) + is_invalid.update(entry) + else: + value_stop = self.background_TOF_monitor_stop[key_start] + if value_start > value_stop: + entry = 
validation_message("Incorrect monitor background TOF bounds.", + "Make sure that lower monitor background TOF bound is" + " smaller then upper bound.", + {"background_TOF_monitor_start": self.background_TOF_monitor_start, + "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) + is_invalid.update(entry) + + if is_invalid: + raise ValueError("StateMoveDetector: The provided inputs are illegal. " + "Please see: {0}".format(json.dumps(is_invalid))) + + +@rename_descriptor_names +class StateNormalizeToMonitorLOQ(StateNormalizeToMonitor): + def __init__(self): + super(StateNormalizeToMonitorLOQ, self).__init__() + # Set the LOQ default range for prompt peak correction + self.prompt_peak_correction_min = 19000.0 + self.prompt_peak_correction_max = 20500.0 + + def validate(self): + super(StateNormalizeToMonitorLOQ, self).validate() + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +def set_default_incident_monitor(normalize_monitor_info, data_info): + """ + The default incident monitor is stored on the IPF. + :param normalize_monitor_info: a StateNormalizeMonitor object on which we set the default value + :param data_info: a StateData object + """ + file_name = data_info.sample_scatter + _, ipf_path = get_instrument_paths_for_sans_file(file_name) + named_element = "default-incident-monitor-spectrum" + monitor_spectrum_tag_to_search = [named_element] + found_monitor_spectrum = get_named_elements_from_ipf_file(ipf_path, monitor_spectrum_tag_to_search, int) + if named_element in found_monitor_spectrum: + normalize_monitor_info.incident_monitor = found_monitor_spectrum[named_element] + + +class StateNormalizeToMonitorBuilder(object): + @automatic_setters(StateNormalizeToMonitor, exclusions=["default_incident_monitor"]) + def __init__(self, data_info): + super(StateNormalizeToMonitorBuilder, self).__init__() + self._data = data_info + self.state = StateNormalizeToMonitor() + set_default_incident_monitor(self.state, self._data) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +class StateNormalizeToMonitorBuilderLOQ(object): + @automatic_setters(StateNormalizeToMonitorLOQ) + def __init__(self, data_info): + super(StateNormalizeToMonitorBuilderLOQ, self).__init__() + self._data = data_info + self.state = StateNormalizeToMonitorLOQ() + set_default_incident_monitor(self.state, self._data) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +def get_normalize_to_monitor_builder(data_info): + instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.SANS2D: + return StateNormalizeToMonitorBuilder(data_info) + elif instrument is SANSInstrument.LOQ: + return StateNormalizeToMonitorBuilderLOQ(data_info) + else: + raise NotImplementedError("StateNormalizeToMonitorBuilder: Could not find any valid normalize to monitor " + "builder for the specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/reduction_mode.py b/scripts/SANS/sans/state/reduction_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..917973e9b7008996f5dca9b34f50917e7ce7adda --- /dev/null +++ b/scripts/SANS/sans/state/reduction_mode.py @@ -0,0 +1,131 @@ +# pylint: disable=too-few-public-methods + +""" Defines the state of the reduction.""" + +from abc import 
(ABCMeta, abstractmethod) +import copy + +from sans.state.state_base import (StateBase, ClassTypeParameter, FloatParameter, DictParameter, + FloatWithNoneParameter, rename_descriptor_names) +from sans.common.enums import (ReductionMode, ISISReductionMode, ReductionDimensionality, FitModeForMerge, + SANSInstrument, DetectorType) +from sans.common.file_information import (get_instrument_paths_for_sans_file) +from sans.common.xml_parsing import get_named_elements_from_ipf_file +from sans.state.automatic_setters import (automatic_setters) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateReductionBase(object): + __metaclass__ = ABCMeta + + @abstractmethod + def get_merge_strategy(self): + pass + + @abstractmethod + def get_detector_name_for_reduction_mode(self, reduction_mode): + pass + + @abstractmethod + def get_all_reduction_modes(self): + pass + + +@rename_descriptor_names +class StateReductionMode(StateReductionBase, StateBase): + reduction_mode = ClassTypeParameter(ReductionMode) + reduction_dimensionality = ClassTypeParameter(ReductionDimensionality) + + # Fitting + merge_fit_mode = ClassTypeParameter(FitModeForMerge) + merge_shift = FloatParameter() + merge_scale = FloatParameter() + merge_range_min = FloatWithNoneParameter() + merge_range_max = FloatWithNoneParameter() + + # Map from detector type to detector name + detector_names = DictParameter() + + def __init__(self): + super(StateReductionMode, self).__init__() + self.reduction_mode = ISISReductionMode.LAB + self.reduction_dimensionality = ReductionDimensionality.OneDim + + # Set the shifts to defaults which essentially don't do anything. 
+ self.merge_shift = 0.0 + self.merge_scale = 1.0 + self.merge_fit_mode = FitModeForMerge.NoFit + self.merge_range_min = None + self.merge_range_max = None + + # Set the detector names to empty strings + self.detector_names = {DetectorType.to_string(DetectorType.LAB): "", + DetectorType.to_string(DetectorType.HAB): ""} + + def get_merge_strategy(self): + return [ISISReductionMode.LAB, ISISReductionMode.HAB] + + def get_all_reduction_modes(self): + return [ISISReductionMode.LAB, ISISReductionMode.HAB] + + def get_detector_name_for_reduction_mode(self, reduction_mode): + if reduction_mode is ISISReductionMode.LAB: + bank_type = DetectorType.to_string(DetectorType.LAB) + elif reduction_mode is ISISReductionMode.HAB: + bank_type = DetectorType.to_string(DetectorType.HAB) + else: + raise RuntimeError("SANStateReductionISIS: There is no detector available for the" + " reduction mode {0}.".format(reduction_mode)) + return self.detector_names[bank_type] + + def validate(self): + pass + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +def setup_detectors_from_ipf(reduction_info, data_info): + file_name = data_info.sample_scatter + _, ipf_path = get_instrument_paths_for_sans_file(file_name) + + detector_names = {DetectorType.to_string(DetectorType.LAB): "low-angle-detector-name", + DetectorType.to_string(DetectorType.HAB): "high-angle-detector-name"} + + names_to_search = [] + names_to_search.extend(detector_names.values()) + + found_detector_names = get_named_elements_from_ipf_file(ipf_path, names_to_search, str) + + for detector_type in reduction_info.detector_names.keys(): + try: + detector_name_tag = detector_names[detector_type] + detector_name = found_detector_names[detector_name_tag] + except KeyError: + continue + reduction_info.detector_names[detector_type] = detector_name + + +class StateReductionModeBuilder(object): + @automatic_setters(StateReductionMode, exclusions=["detector_names"]) + def __init__(self, data_info): + super(StateReductionModeBuilder, self).__init__() + self.state = StateReductionMode() + setup_detectors_from_ipf(self.state, data_info) + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +def get_reduction_mode_builder(data_info): + # The data state has most of the information that we require to define the move. For the factory method, only + # the instrument is of relevance. 
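+    # Illustrative sketch of how the returned builder is typically used (an editorial aside, not part of the
+    # reduction code; it assumes that the automatic_setters decorator exposes a set_<parameter> method on the
+    # builder for every descriptor of StateReductionMode):
+    #   builder = get_reduction_mode_builder(data_info)
+    #   builder.set_reduction_dimensionality(ReductionDimensionality.OneDim)
+    #   reduction_info = builder.build()  # validates the state and returns a copy of it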
+ instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D: + return StateReductionModeBuilder(data_info) + else: + raise NotImplementedError("StateReductionBuilder: Could not find any valid reduction builder for the " + "specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/save.py b/scripts/SANS/sans/state/save.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1f03ef6f07de16ac109612c7050566a919d039 --- /dev/null +++ b/scripts/SANS/sans/state/save.py @@ -0,0 +1,50 @@ +# pylint: disable=too-few-public-methods + +""" Defines the state of saving.""" +import copy +from sans.state.state_base import (StateBase, BoolParameter, StringParameter, + ClassTypeListParameter, rename_descriptor_names) +from sans.common.enums import (SaveType, SANSInstrument) +from sans.state.automatic_setters import (automatic_setters) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateSave(StateBase): + file_name = StringParameter() + zero_free_correction = BoolParameter() + file_format = ClassTypeListParameter(SaveType) + + def __init__(self): + super(StateSave, self).__init__() + self.zero_free_correction = True + + def validate(self): + pass + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateSaveBuilder(object): + @automatic_setters(StateSave) + def __init__(self): + super(StateSaveBuilder, self).__init__() + self.state = StateSave() + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +def get_save_builder(data_info): + # The data state has most of the information that we require to define the move. For the factory method, only + # the instrument is of relevance. 
+ instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D: + return StateSaveBuilder() + else: + raise NotImplementedError("StateSaveBuilder: Could not find any valid save builder for the " + "specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/scale.py b/scripts/SANS/sans/state/scale.py new file mode 100644 index 0000000000000000000000000000000000000000..cf888b6209011d0939e3df618648522f4fd264d4 --- /dev/null +++ b/scripts/SANS/sans/state/scale.py @@ -0,0 +1,51 @@ +""" Defines the state of the geometry and unit scaling.""" +import copy +from sans.state.state_base import (StateBase, rename_descriptor_names, PositiveFloatParameter, ClassTypeParameter) +from sans.common.enums import (SampleShape, SANSInstrument) +from sans.state.automatic_setters import (automatic_setters) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateScale(StateBase): + shape = ClassTypeParameter(SampleShape) + thickness = PositiveFloatParameter() + width = PositiveFloatParameter() + height = PositiveFloatParameter() + scale = PositiveFloatParameter() + + def __init__(self): + super(StateScale, self).__init__() + + def validate(self): + pass + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateScaleBuilder(object): + @automatic_setters(StateScale, exclusions=[]) + def __init__(self): + super(StateScaleBuilder, self).__init__() + self.state = StateScale() + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +# ------------------------------------------ +# Factory method for SANStateScaleBuilder +# ------------------------------------------ +def get_scale_builder(data_info): + # The data state has most of the information that we require to define the move. For the factory method, only + # the instrument is of relevance. 
+    instrument = data_info.instrument
+    if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D:
+        return StateScaleBuilder()
+    else:
+        raise NotImplementedError("StateScaleBuilder: Could not find any valid scale builder for the "
+                                  "specified StateData object {0}".format(str(data_info)))
diff --git a/scripts/SANS/sans/state/slice_event.py b/scripts/SANS/sans/state/slice_event.py
new file mode 100644
index 0000000000000000000000000000000000000000..744dd94e25523f2254acc74b80e2180d3cf6d7c7
--- /dev/null
+++ b/scripts/SANS/sans/state/slice_event.py
@@ -0,0 +1,110 @@
+""" Defines the state of the event slices which should be reduced."""
+
+import json
+import copy
+from sans.state.state_base import (StateBase, rename_descriptor_names, FloatListParameter)
+from sans.state.state_functions import (is_pure_none_or_not_none, validation_message)
+from sans.common.enums import SANSInstrument
+from sans.state.automatic_setters import (automatic_setters)
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# State
+# ----------------------------------------------------------------------------------------------------------------------
+@rename_descriptor_names
+class StateSliceEvent(StateBase):
+    start_time = FloatListParameter()
+    end_time = FloatListParameter()
+
+    def __init__(self):
+        super(StateSliceEvent, self).__init__()
+
+    def validate(self):
+        is_invalid = dict()
+
+        if not is_pure_none_or_not_none([self.start_time, self.end_time]):
+            entry = validation_message("Missing slice times",
+                                       "Make sure that either both or none are set.",
+                                       {"start_time": self.start_time,
+                                        "end_time": self.end_time})
+            is_invalid.update(entry)
+
+        if self.start_time and self.end_time:
+            # The length of start_time and end_time needs to be identical
+            if len(self.start_time) != len(self.end_time):
+                entry = validation_message("Bad relation of start and end",
+                                           "Make sure that the start and end time lists have the same length.",
+                                           {"start_time": self.start_time,
+                                            "end_time": self.end_time})
+                is_invalid.update(entry)
+
+            # Each entry in start_time and end_time must be a float
+            if len(self.start_time) == len(self.end_time) and len(self.start_time) > 0:
+                for element1, element2 in zip(self.start_time, self.end_time):
+                    if not isinstance(element1, float) or not isinstance(element2, float):
+                        entry = validation_message("Bad relation of start and end time entries",
+                                                   "The elements need to be floats.",
+                                                   {"start_time": self.start_time,
+                                                    "end_time": self.end_time})
+                        is_invalid.update(entry)
+
+            # Check that the entries are monotonically increasing. We don't want 12, 24, 22
+            if len(self.start_time) > 1 and not monotonically_increasing(self.start_time):
+                entry = validation_message("Not monotonically increasing start time list",
+                                           "Make sure that the start times increase monotonically.",
+                                           {"start_time": self.start_time})
+                is_invalid.update(entry)
+
+            if len(self.end_time) > 1 and not monotonically_increasing(self.end_time):
+                entry = validation_message("Not monotonically increasing end time list",
+                                           "Make sure that the end times increase monotonically.",
+                                           {"end_time": self.end_time})
+                is_invalid.update(entry)
+
+            # Check that end_time is not smaller than start_time
+            if not is_smaller(self.start_time, self.end_time):
+                entry = validation_message("Start time larger than end time.",
+                                           "Make sure that each start time is not larger than its end time.",
+                                           {"start_time": self.start_time,
+                                            "end_time": self.end_time})
+                is_invalid.update(entry)
+
+        if is_invalid:
+            raise ValueError("StateSliceEvent: The provided inputs are illegal. "
+                             "Please see: {}".format(json.dumps(is_invalid)))
+
+
+def monotonically_increasing(to_check):
+    return all(x <= y for x, y in zip(to_check, to_check[1:]))
+
+
+def is_smaller(smaller, larger):
+    return all(x <= y for x, y in zip(smaller, larger))
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Builder
+# ----------------------------------------------------------------------------------------------------------------------
+class StateSliceEventBuilder(object):
+    @automatic_setters(StateSliceEvent)
+    def __init__(self):
+        super(StateSliceEventBuilder, self).__init__()
+        self.state = StateSliceEvent()
+
+    def build(self):
+        # Make sure that the product is in a valid state, ie not incomplete
+        self.state.validate()
+        return copy.copy(self.state)
+
+
+# ------------------------------------------
+# Factory method for StateSliceEventBuilder
+# ------------------------------------------
+def get_slice_event_builder(data_info):
+    instrument = data_info.instrument
+    if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D:
+        return StateSliceEventBuilder()
+    else:
+        raise NotImplementedError("StateSliceEventBuilder: Could not find any valid slice builder for the "
+                                  "specified StateData object {0}".format(str(data_info)))
diff --git a/scripts/SANS/sans/state/state.py b/scripts/SANS/sans/state/state.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c5344caf767237946c3bdc54eb8fb713a88ec4b
--- /dev/null
+++ b/scripts/SANS/sans/state/state.py
@@ -0,0 +1,110 @@
+""" Defines the main State object."""
+
+# pylint: disable=too-few-public-methods
+
+import json
+import pickle
+import inspect
+import copy
+from sans.common.enums import SANSInstrument
+from sans.state.state_base import (StateBase, TypedParameter,
+                                   rename_descriptor_names, validator_sub_state)
+from sans.state.data import StateData
+from sans.state.move import StateMove
+from sans.state.reduction_mode import StateReductionMode
+from sans.state.slice_event import StateSliceEvent
+from sans.state.mask import StateMask
+from sans.state.wavelength import StateWavelength
+from sans.state.save import StateSave
+from sans.state.adjustment import StateAdjustment
+from sans.state.scale import StateScale
+from sans.state.convert_to_q import StateConvertToQ
+from sans.state.automatic_setters import (automatic_setters)
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# State
+# ----------------------------------------------------------------------------------------------------------------------
+@rename_descriptor_names
+class State(StateBase):
+    data = TypedParameter(StateData, validator_sub_state)
+    move = TypedParameter(StateMove, validator_sub_state)
+    reduction = TypedParameter(StateReductionMode, validator_sub_state)
+    slice = TypedParameter(StateSliceEvent, validator_sub_state)
+    mask = TypedParameter(StateMask, validator_sub_state)
+    wavelength = TypedParameter(StateWavelength, validator_sub_state)
+    save = TypedParameter(StateSave, validator_sub_state)
+    scale = TypedParameter(StateScale, validator_sub_state)
+    adjustment = TypedParameter(StateAdjustment, validator_sub_state)
+    convert_to_q = TypedParameter(StateConvertToQ, validator_sub_state)
+
+    def __init__(self):
+        super(State, self).__init__()
+
+    def validate(self):
+        is_invalid = dict()
+
+        # Make sure that the substates are contained
+        if not self.data:
+            is_invalid.update({"data": "State: The state object needs to include a StateData object."})
+        if not self.move:
+            is_invalid.update({"move": "State: The state object needs to include a StateMove object."})
+        if not self.reduction:
+            is_invalid.update({"reduction": "State: The state object needs to include a StateReductionMode object."})
+        if not self.slice:
+            is_invalid.update({"slice": "State: The state object needs to include a StateSliceEvent object."})
+        if not self.mask:
+            is_invalid.update({"mask": "State: The state object needs to include a StateMask object."})
+        if not self.wavelength:
+            is_invalid.update({"wavelength": "State: The state object needs to include a StateWavelength object."})
+        if not self.save:
+            is_invalid.update({"save": "State: The state object needs to include a StateSave object."})
+        if not self.scale:
+            is_invalid.update({"scale": "State: The state object needs to include a StateScale object."})
+        if not self.adjustment:
+            is_invalid.update({"adjustment": "State: The state object needs to include a StateAdjustment object."})
+        if not self.convert_to_q:
+            is_invalid.update({"convert_to_q": "State: The state object needs to include a StateConvertToQ object."})
+
+        if is_invalid:
+            raise ValueError("State: There is an issue with your input. See: {0}".format(json.dumps(is_invalid)))
+
+        # Check the attributes themselves
+        is_invalid = {}
+        for descriptor_name, descriptor_object in inspect.getmembers(type(self)):
+            if inspect.isdatadescriptor(descriptor_object) and isinstance(descriptor_object, TypedParameter):
+                try:
+                    attr = getattr(self, descriptor_name)
+                    attr.validate()
+                except ValueError as err:
+                    is_invalid.update({descriptor_name: pickle.dumps(str(err))})
+
+        if is_invalid:
+            raise ValueError("State: There is an issue with your input. 
See: {0}".format(json.dumps(is_invalid))) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateBuilder(object): + @automatic_setters(State) + def __init__(self): + super(StateBuilder, self).__init__() + self.state = State() + + def build(self): + # Make sure that the product is in a valid state, ie not incomplete + self.state.validate() + return copy.copy(self.state) + + +# ------------------------------------------ +# Factory method for SANStateDataBuilder +# ------------------------------------------ +def get_state_builder(data_info): + instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D: + return StateBuilder() + else: + raise NotImplementedError("SANSStateBuilder: Could not find any valid state builder for the " + "specified SANSStateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/state_base.py b/scripts/SANS/sans/state/state_base.py new file mode 100644 index 0000000000000000000000000000000000000000..90063b75664af55b6153921db3f974c28c3eb301 --- /dev/null +++ b/scripts/SANS/sans/state/state_base.py @@ -0,0 +1,566 @@ + +# pylint: disable=too-few-public-methods, invalid-name + +""" Fundamental classes and Descriptors for the State mechanism.""" +from abc import (ABCMeta, abstractmethod) +import copy +import inspect +from functools import (partial) + +from mantid.kernel import (PropertyManager, std_vector_dbl, std_vector_str, std_vector_int) + + +# --------------------------------------------------------------- +# Validator functions +# --------------------------------------------------------------- +def is_not_none(value): + return value is not None + + +def is_positive(value): + return value >= 0 + + +def is_positive_or_none(value): + return value >= 0 or value is None + + +def all_list_elements_are_of_specific_type_and_not_empty(value, comparison_type, + additional_comparison=lambda x: True, type_check=isinstance): + """ + Ensures that all elements of a list are of a specific type and that the list is not empty + + @param value: the list to check + @param comparison_type: the expected type of the elements of the list. + @param additional_comparison: additional comparison lambda. + @param type_check: the method which performs type checking. + @return: True if the list is not empty and all types are as expected, else False. + """ + is_of_type = True + for element in value: + # Perform type check + if not type_check(element, comparison_type): + is_of_type = False + # Perform additional check + if not additional_comparison(element): + is_of_type = False + + if not value: + is_of_type = False + return is_of_type + + +def all_list_elements_are_of_instance_type_and_not_empty(value, comparison_type, additional_comparison=lambda x: True): + """ + Ensures that all elements of a list are of a certain INSTANCE type and that the list is not empty. + """ + return all_list_elements_are_of_specific_type_and_not_empty(value=value, comparison_type=comparison_type, + additional_comparison=additional_comparison, + type_check=isinstance) + + +def all_list_elements_are_of_class_type_and_not_empty(value, comparison_type, additional_comparison=lambda x: True): + """ + Ensures that all elements of a list are of a certain INSTANCE type and that the list is not empty. 
+ """ + return all_list_elements_are_of_specific_type_and_not_empty(value=value, comparison_type=comparison_type, + additional_comparison=additional_comparison, + type_check=issubclass) + + +def all_list_elements_are_float_and_not_empty(value): + typed_comparison = partial(all_list_elements_are_of_instance_type_and_not_empty, comparison_type=float) + return typed_comparison(value) + + +def all_list_elements_are_string_and_not_empty(value): + typed_comparison = partial(all_list_elements_are_of_instance_type_and_not_empty, comparison_type=str) + return typed_comparison(value) + + +def all_list_elements_are_int_and_not_empty(value): + typed_comparison = partial(all_list_elements_are_of_instance_type_and_not_empty, comparison_type=int) + return typed_comparison(value) + + +def all_list_elements_are_int_and_positive_and_not_empty(value): + typed_comparison = partial(all_list_elements_are_of_instance_type_and_not_empty, comparison_type=int, + additional_comparison=lambda x: x >= 0) + return typed_comparison(value) + + +def validator_sub_state(sub_state): + is_valid = True + try: + sub_state.validate() + except ValueError: + is_valid = False + return is_valid + + +# ------------------------------------------------------- +# Parameters +# ------------------------------------------------------- +class TypedParameter(object): + """ + The TypedParameter descriptor allows the user to store/handle a type-checked value with an additional + validator option, e.g. one can restrict the held parameter to be only a positive value. + """ + __counter = 0 + + def __init__(self, parameter_type, validator=lambda x: True): + cls = self.__class__ + prefix = cls.__name__ + # pylint: disable=protected-access + index = cls.__counter + cls.__counter += 1 + # Name which is used to store value in the instance. This will be unique and not accessible via the standard + # attribute access, since the developer/user cannot apply the hash symbol in their code (it is valid though + # when writing into the __dict__). Note that the name which we generate here will be altered (via a + # class decorator) in the classes which actually use the TypedParameter descriptor, to make it more readable. + self.name = '_{0}#{1}'.format(prefix, index) + self.parameter_type = parameter_type + self.value = None + self.validator = validator + + def __get__(self, instance, owner): + if instance is None: + return self + else: + if hasattr(instance, self.name): + return getattr(instance, self.name) + else: + return None + + def __set__(self, instance, value): + # Perform a type check + self._type_check(value) + if self.validator(value): + # The descriptor should be holding onto its own data and return a deepcopy of the data. + copied_value = copy.deepcopy(value) + setattr(instance, self.name, copied_value) + else: + raise ValueError("Trying to set {0} with an invalid value of {1}".format(self.name, str(value))) + + def __delete__(self): + raise AttributeError("Cannot delete the attribute {0}".format(self.name)) + + def _type_check(self, value): + if not isinstance(value, self.parameter_type): + raise TypeError("Trying to set {0} which expects a value of type {1}." 
+ " Got a value of {2} which is of type: {3}".format(self.name, str(self.parameter_type), + str(value), str(type(value)))) + + +# --------------------------------------------------- +# Various standard cases of the TypedParameter +# --------------------------------------------------- +class StringParameter(TypedParameter): + def __init__(self): + super(StringParameter, self).__init__(str, is_not_none) + + +class BoolParameter(TypedParameter): + def __init__(self): + super(BoolParameter, self).__init__(bool, is_not_none) + + +class FloatParameter(TypedParameter): + def __init__(self): + super(FloatParameter, self).__init__(float, is_not_none) + + +class PositiveFloatParameter(TypedParameter): + def __init__(self): + super(PositiveFloatParameter, self).__init__(float, is_positive) + + +class PositiveIntegerParameter(TypedParameter): + def __init__(self): + super(PositiveIntegerParameter, self).__init__(int, is_positive) + + +class DictParameter(TypedParameter): + def __init__(self): + super(DictParameter, self).__init__(dict, is_not_none) + + +class ClassTypeParameter(TypedParameter): + """ + This TypedParameter variant allows for storing a class type. + + This could be for example something from the SANSType module, e.g. CanonicalCoordinates.X + It is something that is used frequently with the main of moving away from using strings where types + should be used instead. + """ + def __init__(self, class_type): + super(ClassTypeParameter, self).__init__(class_type, is_not_none) + + def _type_check(self, value): + if not issubclass(value, self.parameter_type): + raise TypeError("Trying to set {0} which expects a value of type {1}." + " Got a value of {2} which is of type: {3}".format(self.name, str(self.parameter_type), + str(value), type(value))) + + +class FloatWithNoneParameter(TypedParameter): + def __init__(self): + super(FloatWithNoneParameter, self).__init__(float) + + def _type_check(self, value): + if not isinstance(value, self.parameter_type) and value is not None: + raise TypeError("Trying to set {0} which expects a value of type {1}." + " Got a value of {2} which is of type: {3}".format(self.name, str(self.parameter_type), + str(value), type(value))) + + +class PositiveFloatWithNoneParameter(TypedParameter): + def __init__(self): + super(PositiveFloatWithNoneParameter, self).__init__(float, is_positive_or_none) + + def _type_check(self, value): + if not isinstance(value, self.parameter_type) and value is not None: + raise TypeError("Trying to set {0} which expects a value of type {1}." + " Got a value of {2} which is of type: {3}".format(self.name, str(self.parameter_type), + str(value), type(value))) + + +class FloatListParameter(TypedParameter): + def __init__(self): + super(FloatListParameter, self).__init__(list) + + def _type_check(self, value): + if not isinstance(value, self.parameter_type) or not all_list_elements_are_float_and_not_empty(value): + raise TypeError("Trying to set {0} which expects a value of type {1}." + " Got a value of {2} which is of type: {3}".format(self.name, str(self.parameter_type), + str(value), type(value))) + + +class StringListParameter(TypedParameter): + def __init__(self): + super(StringListParameter, self).__init__(list, all_list_elements_are_string_and_not_empty) + + def _type_check(self, value): + if not isinstance(value, self.parameter_type) or not all_list_elements_are_string_and_not_empty(value): + raise TypeError("Trying to set {0} which expects a value of type {1}." 
+ " Got a value of {2} which is of type: {3}".format(self.name, str(self.parameter_type), + str(value), type(value))) + + +class PositiveIntegerListParameter(TypedParameter): + def __init__(self): + super(PositiveIntegerListParameter, self).__init__(list, + all_list_elements_are_int_and_positive_and_not_empty) + + def _type_check(self, value): + if not isinstance(value, self.parameter_type) or not all_list_elements_are_int_and_not_empty(value): + raise TypeError("Trying to set {0} which expects a value of type {1}." + " Got a value of {2} which is of type: {3}".format(self.name, str(self.parameter_type), + str(value), type(value))) + + +class ClassTypeListParameter(TypedParameter): + def __init__(self, class_type): + typed_comparison = partial(all_list_elements_are_of_class_type_and_not_empty, comparison_type=class_type) + super(ClassTypeListParameter, self).__init__(list, typed_comparison) + + +# ------------------------------------------------ +# StateBase +# ------------------------------------------------ +class StateBase(object): + """ The fundamental base of the SANS State""" + __metaclass__ = ABCMeta + + @property + def property_manager(self): + return convert_state_to_dict(self) + + @property_manager.setter + def property_manager(self, value): + set_state_from_property_manager(self, value) + + @abstractmethod + def validate(self): + pass + + +def rename_descriptor_names(cls): + """ + Class decorator which changes the names of TypedParameters in a class instance in order to make it more readable. + + This is especially helpful for debugging. And also in order to find attributes in the dictionaries. + :param cls: The class with the TypedParameters + :return: The class with the TypedParameters + """ + for attribute_name, attribute_value in cls.__dict__.items(): + if isinstance(attribute_value, TypedParameter): + attribute_value.name = '_{0}#{1}'.format(type(attribute_value).__name__, attribute_name) + return cls + + +# ------------------------------------------------ +# Serialization of the State +# ------------------------------------------------ +# Serialization of the state object is currently done via generating a dict object. Reversely, we can generate a +# State object from a property manager object, not a dict object. This quirk results from the way Mantid +# treats property manager inputs and outputs (it reads in dicts and converts them to property manager objects). +# We might have to live with that for now. +# +# During serialization we place identifier tags into the serialized object, e.g. we add a specifier if the item +# is a State type at all and if so which state it is. 
+ + +STATE_NAME = "state_name" +STATE_MODULE = "state_module" +SEPARATOR_SERIAL = "#" +class_type_parameter_id = "ClassTypeParameterID#" +MODULE = "__module__" + + +def is_state(property_manager): + return property_manager.existsProperty(STATE_NAME) and property_manager.existsProperty(STATE_MODULE) + + +def is_float_vector(value): + return isinstance(value, std_vector_dbl) + + +def is_string_vector(value): + return isinstance(value, std_vector_str) + + +def is_int_vector(value): + return isinstance(value, std_vector_int) + + +def get_module_and_class_name(instance): + if inspect.isclass(instance): + module_name, class_name = str(instance.__dict__[MODULE]), str(instance.__name__) + else: + module_name, class_name = str(type(instance).__dict__[MODULE]), str(type(instance).__name__) + return module_name, class_name + + +def provide_class_from_module_and_class_name(module_name, class_name): + # Importlib seems to be missing on RHEL6, hence we resort to __import__ + try: + from importlib import import_module + module = import_module(module_name) + except ImportError: + if "." in module_name: + _, mod_name = module_name.rsplit(".", 1) + else: + mod_name = None + if not mod_name: + module = __import__(module_name) + else: + module = __import__(module_name, fromlist=[mod_name]) + return getattr(module, class_name) + + +def provide_class(instance): + module_name = instance.getProperty(STATE_MODULE).value + class_name = instance.getProperty(STATE_NAME).value + return provide_class_from_module_and_class_name(module_name, class_name) + + +def is_class_type_parameter(value): + return isinstance(value, basestring) and class_type_parameter_id in value + + +def is_vector_with_class_type_parameter(value): + is_vector_with_class_type = True + contains_str = is_string_vector(value) + if contains_str: + for element in value: + if not is_class_type_parameter(element): + is_vector_with_class_type = False + else: + is_vector_with_class_type = False + return is_vector_with_class_type + + +def get_module_and_class_name_from_encoded_string(encoder, value): + without_encoder = value.replace(encoder, "") + return without_encoder.split(SEPARATOR_SERIAL) + + +def create_module_and_class_name_from_encoded_string(class_type_id, module_name, class_name): + return class_type_id + module_name + SEPARATOR_SERIAL + class_name + + +def create_sub_state(value): + # We are dealing with a sub state. 
We first have to create it and then populate it + sub_state_class = provide_class(value) + # Create the sub state, populate it and set it on the super state + sub_state = sub_state_class() + sub_state.property_manager = value + return sub_state + + +def get_descriptor_values(instance): + # Get all descriptor names which are TypedParameter of instance's type + descriptor_names = [] + descriptor_types = {} + for descriptor_name, descriptor_object in inspect.getmembers(type(instance)): + if inspect.isdatadescriptor(descriptor_object) and isinstance(descriptor_object, TypedParameter): + descriptor_names.append(descriptor_name) + descriptor_types.update({descriptor_name: descriptor_object}) + + # Get the descriptor values from the instance + descriptor_values = {} + for key in descriptor_names: + if hasattr(instance, key): + value = getattr(instance, key) + if value is not None: + descriptor_values.update({key: value}) + return descriptor_values, descriptor_types + + +def get_class_descriptor_types(instance): + # Get all descriptor names which are TypedParameter of instance's type + descriptors = {} + for descriptor_name, descriptor_object in inspect.getmembers(type(instance)): + if inspect.isdatadescriptor(descriptor_object) and isinstance(descriptor_object, TypedParameter): + descriptors.update({descriptor_name: type(descriptor_object)}) + return descriptors + + +def convert_state_to_dict(instance): + """ + Converts the state object to a dictionary. + + @param instance: the instance which is to be converted + @return: a serialized state object in the form of a dict + """ + descriptor_values, descriptor_types = get_descriptor_values(instance) + # Add the descriptors to a dict + state_dict = dict() + for key, value in descriptor_values.items(): + # If the value is a SANSBaseState then create a dict from it + # If the value is a dict, then we need to check what the sub types are + # If the value is a ClassTypeParameter, then we need to encode it + # If the value is a list of ClassTypeParameters, then we need to encode each element in the list + if isinstance(value, StateBase): + sub_state_dict = value.property_manager + value = sub_state_dict + elif isinstance(value, dict): + # If we have a dict, then we need to watch out since a value in the dict might be a State + sub_dictionary = {} + for key_sub, val_sub in value.items(): + if isinstance(val_sub, StateBase): + sub_dictionary_value = val_sub.property_manager + else: + sub_dictionary_value = val_sub + sub_dictionary.update({key_sub: sub_dictionary_value}) + value = sub_dictionary + elif isinstance(descriptor_types[key], ClassTypeParameter): + value = get_serialized_class_type_parameter(value) + elif isinstance(descriptor_types[key], ClassTypeListParameter): + if value: + # If there are entries in the list, then convert them individually and place them into a list. 
+                # The list will contain a sequence of serialized ClassTypeParameters
+                serialized_value = []
+                for element in value:
+                    serialized_element = get_serialized_class_type_parameter(element)
+                    serialized_value.append(serialized_element)
+                value = serialized_value
+
+        state_dict.update({key: value})
+    # Add information about the current state object, such as in which module it lives and what its name is
+    module_name, class_name = get_module_and_class_name(instance)
+    state_dict.update({STATE_MODULE: module_name})
+    state_dict.update({STATE_NAME: class_name})
+    return state_dict
+
+
+def set_state_from_property_manager(instance, property_manager):
+    """
+    Set the State object from the information stored on a property manager object. This is the deserialization step.
+
+    @param instance: the instance which is to be set with the values of the property manager
+    @param property_manager: the property manager with the stored settings
+    """
+    def _set_element(inst, k_element, v_element):
+        if k_element != STATE_NAME and k_element != STATE_MODULE:
+            setattr(inst, k_element, v_element)
+
+    keys = property_manager.keys()
+    for key in keys:
+        value = property_manager.getProperty(key).value
+        # There are eight scenarios that need to be considered
+        # 1. PropertyManager 1: This indicates (most often) that we are dealing with a new state -> create it and
+        #                       apply recursion
+        # 2. PropertyManager 2: In some cases the PropertyManager object is actually a map rather than a state ->
+        #                       populate the state
+        # 3. String with special meaning: Admittedly this is a hack, but we are limited by the input property types
+        #                                 of Mantid algorithms, which can be string, int, float and containers of these
+        #                                 types (and PropertyManagerProperties). We need a wider range of types, such
+        #                                 as ClassTypeParameters. These are encoded (as good as possible) in a string
+        # 4. Vector of strings with special meaning: See point 3)
+        # 5. Vector for float: This needs to handle Mantid's float array
+        # 6. Vector for string: This needs to handle Mantid's string array
+        # 7. Vector for int: This needs to handle Mantid's integer array
+        # 8. 
Normal values: all is fine, just populate them + if type(value) is PropertyManager and is_state(value): + sub_state = create_sub_state(value) + setattr(instance, key, sub_state) + elif type(value) is PropertyManager: + # We must be dealing with an actual dict descriptor + sub_dict_keys = value.keys() + dict_element = {} + # We need to watch out if a value of the dictionary is a sub state + for sub_dict_key in sub_dict_keys: + sub_dict_value = value.getProperty(sub_dict_key).value + if type(sub_dict_value) == PropertyManager and is_state(sub_dict_value): + sub_state = create_sub_state(sub_dict_value) + sub_dict_value_to_insert = sub_state + else: + sub_dict_value_to_insert = sub_dict_value + dict_element.update({sub_dict_key: sub_dict_value_to_insert}) + setattr(instance, key, dict_element) + elif is_class_type_parameter(value): + class_type_parameter = get_deserialized_class_type_parameter(value) + _set_element(instance, key, class_type_parameter) + elif is_vector_with_class_type_parameter(value): + class_type_list = [] + for element in value: + class_type_parameter = get_deserialized_class_type_parameter(element) + class_type_list.append(class_type_parameter) + _set_element(instance, key, class_type_list) + elif is_float_vector(value): + float_list_value = list(value) + _set_element(instance, key, float_list_value) + elif is_string_vector(value): + string_list_value = list(value) + _set_element(instance, key, string_list_value) + elif is_int_vector(value): + int_list_value = list(value) + _set_element(instance, key, int_list_value) + else: + _set_element(instance, key, value) + + +def get_serialized_class_type_parameter(value): + # The module will only know about the outer class name, therefore we need + # 1. The module name + # 2. The name of the outer class + # 3. 
The name of the actual class + module_name, class_name = get_module_and_class_name(value) + outer_class_name = value.outer_class_name + class_name = outer_class_name + SEPARATOR_SERIAL + class_name + return create_module_and_class_name_from_encoded_string(class_type_parameter_id, module_name, class_name) + + +def get_deserialized_class_type_parameter(value): + # We need to first get the outer class from the module + module_name, outer_class_name, class_name = \ + get_module_and_class_name_from_encoded_string(class_type_parameter_id, value) + outer_class_type_parameter = provide_class_from_module_and_class_name(module_name, outer_class_name) + # From the outer class we can then retrieve the inner class which normally defines the users selection + return getattr(outer_class_type_parameter, class_name) + + +def create_deserialized_sans_state_from_property_manager(property_manager): + return create_sub_state(property_manager) diff --git a/scripts/SANS/sans/state/state_functions.py b/scripts/SANS/sans/state/state_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..25e1e0968efc98037c50e9db1234b7efdc09bbed --- /dev/null +++ b/scripts/SANS/sans/state/state_functions.py @@ -0,0 +1,274 @@ +"""Set of general purpose functions which are related to the SANSState approach.""" + +from copy import deepcopy +from sans.common.enums import (ReductionDimensionality, ISISReductionMode, OutputParts, DetectorType) +from sans.common.constants import (ALL_PERIODS, REDUCED_WORKSPACE_NAME_IN_LOGS, EMPTY_NAME, REDUCED_CAN_TAG) +from sans.common.general_functions import (add_to_sample_log, get_ads_workspace_references) +from sans.common.log_tagger import (has_hash, get_hash_value, set_hash) +from sans.common.xml_parsing import (get_monitor_names_from_idf_file, get_named_elements_from_ipf_file) + + +def add_workspace_name(workspace, state, reduction_mode): + """ + Adds the default reduced workspace name to the sample logs + + :param workspace: The output workspace + :param state: a SANSState object + :param reduction_mode: the reduction mode, i.e. LAB, HAB, MERGED + """ + reduced_workspace_name = get_output_workspace_name(state, reduction_mode) + add_to_sample_log(workspace, REDUCED_WORKSPACE_NAME_IN_LOGS, reduced_workspace_name, "String") + + +def get_output_workspace_name_from_workspace(workspace): + run = workspace.run() + if not run.hasProperty(REDUCED_WORKSPACE_NAME_IN_LOGS): + raise RuntimeError("The workspace does not seem to contain an entry for the output workspace name.") + return run.getProperty(REDUCED_WORKSPACE_NAME_IN_LOGS).value + + +def get_output_workspace_name(state, reduction_mode): + """ + Creates the name of the output workspace from a state object. + + The name of the output workspace is: + 1. The short run number + 2. If specific period is being reduced: 'p' + number + 3. Short detector name of the current reduction or "merged" + 4. The reduction dimensionality: "_" + dimensionality + 5. A wavelength range: wavelength_low + "_" + wavelength_high + 6. In case of a 1D reduction, then add phi limits + 7. If we are dealing with an actual slice limit, then specify it: "_tXX_TYY" Note that the time set to + two decimals + :param state: a SANSState object + :param reduction_mode: which reduction is being looked at + :return: the name of the reduced workspace + """ + # 1. Short run number + data = state.data + short_run_number = data.sample_scatter_run_number + short_run_number_as_string = str(short_run_number) + + # 2. 
Multiperiod + if state.data.sample_scatter_period != ALL_PERIODS: + period = data.sample_scatter_period + period_as_string = "p"+str(period) + else: + period_as_string = "" + + # 3. Detector name + move = state.move + detectors = move.detectors + if reduction_mode is ISISReductionMode.Merged: + detector_name_short = "merged" + elif reduction_mode is ISISReductionMode.HAB: + detector_name_short = detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short + elif reduction_mode is ISISReductionMode.LAB: + detector_name_short = detectors[DetectorType.to_string(DetectorType.LAB)].detector_name_short + else: + raise RuntimeError("SANSStateFunctions: Unknown reduction mode {0} cannot be used to " + "create an output name".format(reduction_mode)) + + # 4. Dimensionality + reduction = state.reduction + if reduction.reduction_dimensionality is ReductionDimensionality.OneDim: + dimensionality_as_string = "_1D" + else: + dimensionality_as_string = "_2D" + + # 5. Wavelength range + wavelength = state.wavelength + wavelength_range_string = str(wavelength.wavelength_low) + "_" + str(wavelength.wavelength_high) + + # 6. Phi Limits + mask = state.mask + if reduction.reduction_dimensionality is ReductionDimensionality.OneDim: + if mask.phi_min and mask.phi_max and (abs(mask.phi_max - mask.phi_min) != 180.0): + phi_limits_as_string = 'Phi' + str(mask.phi_min) + '_' + str(mask.phi_max) + else: + phi_limits_as_string = "" + else: + phi_limits_as_string = "" + + # 7. Slice limits + slice_state = state.slice + start_time = slice_state.start_time + end_time = slice_state.end_time + if start_time and end_time: + start_time_as_string = '_t%.2f' % start_time[0] + end_time_as_string = '_T%.2f' % end_time[0] + else: + start_time_as_string = "" + end_time_as_string = "" + + # Piece it all together + output_workspace_name = (short_run_number_as_string + period_as_string + detector_name_short + + dimensionality_as_string + wavelength_range_string + phi_limits_as_string + + start_time_as_string + end_time_as_string) + return output_workspace_name + + +def is_pure_none_or_not_none(elements_to_check): + """ + Checks a list of elements contains None entries and non-None entries + + @param elements_to_check: a list with entries to check + @return: True if the list contains either only None or only non-None elements, else False + """ + are_all_none_or_all_not_none = True + + if len(elements_to_check) == 0: + return are_all_none_or_all_not_none + return all(element is not None for element in elements_to_check) or \ + all(element is None for element in elements_to_check) # noqa + + +def is_not_none_and_first_larger_than_second(elements_to_check): + """ + This function checks if both are not none and then checks if the first element is smaller than the second element. + + @param elements_to_check: a list with two entries. The first is the lower bound and the second entry is the upper + bound + @return: False if at least one input is None or if both are not None and the first element is smaller than the + second else True + """ + is_invalid = True + if len(elements_to_check) != 2: + return is_invalid + if any(element is None for element in elements_to_check): + is_invalid = False + return is_invalid + if elements_to_check[0] < elements_to_check[1]: + is_invalid = False + return is_invalid + + +def one_is_none(elements_to_check): + return any(element is None for element in elements_to_check) + + +def validation_message(error_message, instruction, variables): + """ + Generates a validation message for the SANSState. 
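+    As an illustrative example (editorial): validation_message("Incorrect bounds",
+    "Make sure that low is smaller than high.", {"low": 2.0, "high": 1.0}) returns a dict which maps
+    "Incorrect bounds" to the lines "low: 2.0", "high: 1.0" and "Make sure that low is smaller than high."
+    joined by newlines (the dictionary ordering decides the order of the variable lines).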
+ + @param error_message: A message describing the error. + @param instruction: A message describing what to do to fix the error + @param variables: A dictionary which contains the variable names and values which are involved in the error. + @return: a formatted validation message string. + """ + message = "" + for key, value in variables.items(): + message += "{0}: {1}\n".format(key, value) + message += instruction + return {error_message: message} + + +def get_state_hash_for_can_reduction(state, partial_type=None): + """ + Creates a hash for a (modified) state object. + + Note that we need to modify the state object to exclude elements which are not relevant for the can reduction. + This is primarily the setting of the sample workspaces. This is the only place where we directly alter the value + of a state object + @param state: a SANSState object. + @param partial_type: if it is a partial type, then it needs to be specified here. + @return: the hash of the state + """ + def remove_sample_related_information(full_state): + state_to_hash = deepcopy(full_state) + state_to_hash.data.sample_scatter = EMPTY_NAME + state_to_hash.data.sample_scatter_period = ALL_PERIODS + state_to_hash.data.sample_transmission = EMPTY_NAME + state_to_hash.data.sample_transmission_period = ALL_PERIODS + state_to_hash.data.sample_direct = EMPTY_NAME + state_to_hash.data.sample_direct_period = ALL_PERIODS + state_to_hash.data.sample_scatter_run_number = 1 + return state_to_hash + new_state = remove_sample_related_information(state) + new_state_serialized = new_state.property_manager + + # If we are dealing with a partial output workspace, then mark it as such + if partial_type is OutputParts.Count: + state_string = str(new_state_serialized) + "counts" + elif partial_type is OutputParts.Norm: + state_string = str(new_state_serialized) + "norm" + else: + state_string = str(new_state_serialized) + return str(get_hash_value(state_string)) + + +def get_workspace_from_ads_based_on_hash(hash_value): + for workspace in get_ads_workspace_references(): + if has_hash(REDUCED_CAN_TAG, hash_value, workspace): + return workspace + + +def get_reduced_can_workspace_from_ads(state, output_parts): + """ + Get the reduced can workspace from the ADS if it exists else nothing + + @param state: a SANSState object. + @param output_parts: if true then search also for the partial workspaces + @return: a reduced can object or None. + """ + # Get the standard reduced can workspace + hashed_state = get_state_hash_for_can_reduction(state) + reduced_can = get_workspace_from_ads_based_on_hash(hashed_state) + reduced_can_count = None + reduced_can_norm = None + if output_parts: + hashed_state_count = get_state_hash_for_can_reduction(state, OutputParts.Count) + reduced_can_count = get_workspace_from_ads_based_on_hash(hashed_state_count) + hashed_state_norm = get_state_hash_for_can_reduction(state, OutputParts.Norm) + reduced_can_norm = get_workspace_from_ads_based_on_hash(hashed_state_norm) + return reduced_can, reduced_can_count, reduced_can_norm + + +def write_hash_into_reduced_can_workspace(state, workspace, partial_type=None): + """ + Writes the state hash into a reduced can workspace. + + @param state: a SANSState object. + @param workspace: a reduced can workspace + @param partial_type: if it is a partial type, then it needs to be specified here. 
+ """ + hashed_state = get_state_hash_for_can_reduction(state, partial_type=partial_type) + set_hash(REDUCED_CAN_TAG, hashed_state, workspace) + + +def set_detector_names(state, ipf_path): + """ + Sets the detectors names on a State object which has a `detector` map entry, e.g. StateMask + + @param state: the state object + @param ipf_path: the path to the Instrument Parameter File + """ + lab_keyword = DetectorType.to_string(DetectorType.LAB) + hab_keyword = DetectorType.to_string(DetectorType.HAB) + detector_names = {lab_keyword: "low-angle-detector-name", + hab_keyword: "high-angle-detector-name"} + detector_names_short = {lab_keyword: "low-angle-detector-short-name", + hab_keyword: "high-angle-detector-short-name"} + + names_to_search = [] + names_to_search.extend(detector_names.values()) + names_to_search.extend(detector_names_short.values()) + + found_detector_names = get_named_elements_from_ipf_file(ipf_path, names_to_search, str) + + for detector_type in state.detectors: + try: + detector_name_tag = detector_names[detector_type] + detector_name_short_tag = detector_names_short[detector_type] + detector_name = found_detector_names[detector_name_tag] + detector_name_short = found_detector_names[detector_name_short_tag] + except KeyError: + continue + + state.detectors[detector_type].detector_name = detector_name + state.detectors[detector_type].detector_name_short = detector_name_short + + +def set_monitor_names(state, idf_path): + monitor_names = get_monitor_names_from_idf_file(idf_path) + state.monitor_names = monitor_names diff --git a/scripts/SANS/sans/state/wavelength.py b/scripts/SANS/sans/state/wavelength.py new file mode 100644 index 0000000000000000000000000000000000000000..963e65262819daec6daabe6b02d59cb5dfbcd55e --- /dev/null +++ b/scripts/SANS/sans/state/wavelength.py @@ -0,0 +1,70 @@ +""" Defines the state of the event slices which should be reduced.""" + +import json +import copy + +from sans.state.state_base import (StateBase, PositiveFloatParameter, ClassTypeParameter, rename_descriptor_names) +from sans.common.enums import (RebinType, RangeStepType, SANSInstrument) +from sans.state.state_functions import (is_not_none_and_first_larger_than_second, one_is_none, validation_message) +from sans.state.automatic_setters import (automatic_setters) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateWavelength(StateBase): + rebin_type = ClassTypeParameter(RebinType) + wavelength_low = PositiveFloatParameter() + wavelength_high = PositiveFloatParameter() + wavelength_step = PositiveFloatParameter() + wavelength_step_type = ClassTypeParameter(RangeStepType) + + def __init__(self): + super(StateWavelength, self).__init__() + self.rebin_type = RebinType.Rebin + + def validate(self): + is_invalid = dict() + if one_is_none([self.wavelength_low, self.wavelength_high, self.wavelength_step]): + entry = validation_message("A wavelength entry has not been set.", + "Make sure that all entries for the wavelength are set.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high, + "wavelength_step": self.wavelength_step}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.wavelength_low, self.wavelength_high]): + entry = validation_message("Incorrect wavelength bounds.", + "Make sure that 
lower wavelength bound is smaller then upper bound.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high}) + is_invalid.update(entry) + + if is_invalid: + raise ValueError("StateWavelength: The provided inputs are illegal. " + "Please see: {0}".format(json.dumps(is_invalid, indent=4))) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateWavelengthBuilder(object): + @automatic_setters(StateWavelength) + def __init__(self): + super(StateWavelengthBuilder, self).__init__() + self.state = StateWavelength() + + def build(self): + # Make sure that the product is in a valid state, ie not incomplete + self.state.validate() + return copy.copy(self.state) + + +def get_wavelength_builder(data_info): + instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or instrument is SANSInstrument.SANS2D: + return StateWavelengthBuilder() + else: + raise NotImplementedError("StateWavelengthBuilder: Could not find any valid wavelength builder for the " + "specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/wavelength_and_pixel_adjustment.py b/scripts/SANS/sans/state/wavelength_and_pixel_adjustment.py new file mode 100644 index 0000000000000000000000000000000000000000..2a363bc43d9801609a99d19053db1c13dd20d279 --- /dev/null +++ b/scripts/SANS/sans/state/wavelength_and_pixel_adjustment.py @@ -0,0 +1,99 @@ +# pylint: disable=too-few-public-methods + +"""State describing the creation of pixel and wavelength adjustment workspaces for SANS reduction.""" + +import json +import copy +from sans.state.state_base import (StateBase, rename_descriptor_names, StringParameter, + ClassTypeParameter, PositiveFloatParameter, DictParameter) +from sans.state.state_functions import (is_not_none_and_first_larger_than_second, one_is_none, validation_message) +from sans.common.enums import (RangeStepType, DetectorType, SANSInstrument) +from sans.state.automatic_setters import (automatic_setters) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateAdjustmentFiles(StateBase): + pixel_adjustment_file = StringParameter() + wavelength_adjustment_file = StringParameter() + + def __init__(self): + super(StateAdjustmentFiles, self).__init__() + + def validate(self): + is_invalid = {} + # TODO if a file was specified then make sure that its existence is checked. + + if is_invalid: + raise ValueError("StateAdjustmentFiles: The provided inputs are illegal. 
" + "Please see: {0}".format(json.dumps(is_invalid))) + + +@rename_descriptor_names +class StateWavelengthAndPixelAdjustment(StateBase): + wavelength_low = PositiveFloatParameter() + wavelength_high = PositiveFloatParameter() + wavelength_step = PositiveFloatParameter() + wavelength_step_type = ClassTypeParameter(RangeStepType) + + adjustment_files = DictParameter() + + def __init__(self): + super(StateWavelengthAndPixelAdjustment, self).__init__() + self.adjustment_files = {DetectorType.to_string(DetectorType.LAB): StateAdjustmentFiles(), + DetectorType.to_string(DetectorType.HAB): StateAdjustmentFiles()} + + def validate(self): + is_invalid = {} + + if one_is_none([self.wavelength_low, self.wavelength_high, self.wavelength_step, self.wavelength_step_type]): + entry = validation_message("A wavelength entry has not been set.", + "Make sure that all entries are set.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high, + "wavelength_step": self.wavelength_step, + "wavelength_step_type": self.wavelength_step_type}) + is_invalid.update(entry) + + if is_not_none_and_first_larger_than_second([self.wavelength_low, self.wavelength_high]): + entry = validation_message("Incorrect wavelength bounds.", + "Make sure that lower wavelength bound is smaller then upper bound.", + {"wavelength_low": self.wavelength_low, + "wavelength_high": self.wavelength_high}) + is_invalid.update(entry) + + try: + self.adjustment_files[DetectorType.to_string(DetectorType.LAB)].validate() + self.adjustment_files[DetectorType.to_string(DetectorType.HAB)].validate() + except ValueError as e: + is_invalid.update({"adjustment_files": str(e)}) + + if is_invalid: + raise ValueError("StateWavelengthAndPixelAdjustment: The provided inputs are illegal. " + "Please see: {0}".format(json.dumps(is_invalid))) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateWavelengthAndPixelAdjustmentBuilder(object): + @automatic_setters(StateWavelengthAndPixelAdjustment) + def __init__(self): + super(StateWavelengthAndPixelAdjustmentBuilder, self).__init__() + self.state = StateWavelengthAndPixelAdjustment() + + def build(self): + self.state.validate() + return copy.copy(self.state) + + +def get_wavelength_and_pixel_adjustment_builder(data_info): + instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.SANS2D or instrument is SANSInstrument.LOQ: + return StateWavelengthAndPixelAdjustmentBuilder() + else: + raise NotImplementedError("StateWavelengthAndPixelAdjustmentBuilder: Could not find any valid " + "wavelength and pixel adjustment builder for the specified " + "StateData object {0}".format(str(data_info))) diff --git a/scripts/test/SANS/CMakeLists.txt b/scripts/test/SANS/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..54e6432619097844410a4a521a8e4c60d535ebf3 --- /dev/null +++ b/scripts/test/SANS/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(common) +add_subdirectory(state) + + diff --git a/scripts/test/SANS/common/CMakeLists.txt b/scripts/test/SANS/common/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a701f70af35ec385310db87ef17a4aabfcf58896 --- /dev/null +++ b/scripts/test/SANS/common/CMakeLists.txt @@ -0,0 +1,17 @@ +## +## Tests for SANS +## + +set ( TEST_PY_FILES + 
enums_test.py + file_information_test.py + log_tagger_test.py + general_functions_test.py + xml_parsing_test.py +) + +check_tests_valid ( ${CMAKE_CURRENT_SOURCE_DIR} ${TEST_PY_FILES} ) + +# Prefix for test name=Python +pyunittest_add_test ( ${CMAKE_CURRENT_SOURCE_DIR} PythonSANS ${TEST_PY_FILES} ) + diff --git a/scripts/test/SANS/common/enums_test.py b/scripts/test/SANS/common/enums_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0db5728abc48344a3aa7be702ea305cb64c1f90f --- /dev/null +++ b/scripts/test/SANS/common/enums_test.py @@ -0,0 +1,33 @@ +import unittest +import mantid + +from sans.common.enums import serializable_enum, string_convertible + + +# ----Create a test class +@string_convertible +@serializable_enum("TypeA", "TypeB", "TypeC") +class DummyClass(object): + pass + + +class SANSFileInformationTest(unittest.TestCase): + def test_that_can_create_enum_value_and_is_sub_class_of_base_type(self): + type_a = DummyClass.TypeA + self.assertTrue(issubclass(type_a, DummyClass)) + + def test_that_can_convert_to_string(self): + type_b = DummyClass.TypeB + self.assertTrue(DummyClass.to_string(type_b) == "TypeB") + + def test_that_raises_run_time_error_if_enum_value_is_not_known(self): + self.assertRaises(RuntimeError, DummyClass.to_string, DummyClass) + + def test_that_can_convert_from_string(self): + self.assertTrue(DummyClass.from_string("TypeC") is DummyClass.TypeC) + + def test_that_raises_run_time_error_if_string_is_not_known(self): + self.assertRaises(RuntimeError, DummyClass.from_string, "TypeD") + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/common/file_information_test.py b/scripts/test/SANS/common/file_information_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bc6d20115d76ee7eaec249b0bf5f8708c7e327 --- /dev/null +++ b/scripts/test/SANS/common/file_information_test.py @@ -0,0 +1,58 @@ +import unittest +import mantid + +from sans.common.file_information import (SANSFileInformationFactory, SANSFileInformation, FileType, + SANSInstrument, get_instrument_paths_for_sans_file) +from mantid.kernel import DateAndTime + + +class SANSFileInformationTest(unittest.TestCase): + def test_that_can_extract_information_from_file_for_SANS2D_single_period_and_ISISNexus(self): + # Arrange + # The file is a single period + file_name = "SANS2D00022024" + factory = SANSFileInformationFactory() + + # Act + file_information = factory.create_sans_file_information(file_name) + + # Assert + self.assertTrue(file_information.get_number_of_periods() == 1) + self.assertTrue(file_information.get_date() == DateAndTime("2013-10-25T14:21:19")) + self.assertTrue(file_information.get_instrument() == SANSInstrument.SANS2D) + self.assertTrue(file_information.get_type() == FileType.ISISNexus) + self.assertTrue(file_information.get_run_number() == 22024) + self.assertFalse(file_information.is_event_mode()) + + def test_that_can_extract_information_from_file_for_LOQ_single_period_and_raw_format(self): + # Arrange + # The file is a single period + file_name = "LOQ48094" + factory = SANSFileInformationFactory() + + # Act + file_information = factory.create_sans_file_information(file_name) + + # Assert + self.assertTrue(file_information.get_number_of_periods() == 1) + self.assertTrue(file_information.get_date() == DateAndTime("2008-12-18T11:20:58")) + self.assertTrue(file_information.get_instrument() == SANSInstrument.LOQ) + self.assertTrue(file_information.get_type() == FileType.ISISRaw) + 
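For reference, the enum decorators exercised by enums_test.py above can be summarised in a short sketch; the Colour class is purely hypothetical, and only the behaviour checked by the tests is assumed.

# Illustrative sketch only -- Colour is a hypothetical example class.
from sans.common.enums import serializable_enum, string_convertible


@string_convertible
@serializable_enum("Red", "Green", "Blue")
class Colour(object):
    pass

assert issubclass(Colour.Red, Colour)              # each value is a subclass of the host class
assert Colour.to_string(Colour.Green) == "Green"   # added by @string_convertible
assert Colour.from_string("Blue") is Colour.Blue
# Unknown values or names raise RuntimeError, as the tests above verify.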
self.assertTrue(file_information.get_run_number() == 48094) + + +class SANSFileInformationGeneralFunctionsTest(unittest.TestCase): + def test_that_finds_idf_and_ipf_paths(self): + # Arrange + file_name = "SANS2D00022024" + # Act + idf_path, ipf_path = get_instrument_paths_for_sans_file(file_name) + # Assert + self.assertTrue(idf_path is not None) + self.assertTrue(ipf_path is not None) + self.assertTrue("Definition" in idf_path) + self.assertTrue("Parameters" in ipf_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/common/general_functions_test.py b/scripts/test/SANS/common/general_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bce83417d581cc553e727e60771b77a43b4d29de --- /dev/null +++ b/scripts/test/SANS/common/general_functions_test.py @@ -0,0 +1,98 @@ +import unittest +import mantid + +from mantid.kernel import (V3D, Quat) +from sans.common.general_functions import (quaternion_to_angle_and_axis, create_unmanaged_algorithm, add_to_sample_log) + + +class SANSFunctionsTest(unittest.TestCase): + @staticmethod + def _create_sample_workspace(): + sample_name = "CreateSampleWorkspace" + sample_options = {"OutputWorkspace": "dummy"} + sample_alg = create_unmanaged_algorithm(sample_name, **sample_options) + sample_alg.execute() + return sample_alg.getProperty("OutputWorkspace").value + + def _do_test_quaternion(self, angle, axis, expected_axis=None): + # Act + quaternion = Quat(angle, axis) + converted_angle, converted_axis = quaternion_to_angle_and_axis(quaternion) + + # Assert + if expected_axis is not None: + axis = expected_axis + self.assertAlmostEqual(angle, converted_angle) + self.assertAlmostEqual(axis[0], converted_axis[0]) + self.assertAlmostEqual(axis[1], converted_axis[1]) + self.assertAlmostEqual(axis[2], converted_axis[2]) + + def test_that_quaternion_can_be_converted_to_axis_and_angle_for_regular(self): + # Arrange + angle = 23.0 + axis = V3D(0.0, 1.0, 0.0) + self._do_test_quaternion(angle, axis) + + def test_that_quaternion_can_be_converted_to_axis_and_angle_for_0_degree(self): + # Arrange + angle = 0.0 + axis = V3D(1.0, 0.0, 0.0) + # There shouldn't be an axis for angle 0 + expected_axis = V3D(0.0, 0.0, 0.0) + self._do_test_quaternion(angle, axis, expected_axis) + + def test_that_quaternion_can_be_converted_to_axis_and_angle_for_180_degree(self): + # Arrange + angle = 180.0 + axis = V3D(0.0, 1.0, 0.0) + # There shouldn't be an axis for angle 0 + self._do_test_quaternion(angle, axis) + + def test_that_sample_log_is_added(self): + # Arrange + workspace = SANSFunctionsTest._create_sample_workspace() + log_name = "TestName" + log_value = "TestValue" + log_type = "String" + + # Act + add_to_sample_log(workspace, log_name, log_value, log_type) + + # Assert + run = workspace.run() + self.assertTrue(run.hasProperty(log_name)) + self.assertTrue(run.getProperty(log_name).value == log_value) + + def test_that_sample_log_raises_for_non_string_type_arguments(self): + # Arrange + workspace = SANSFunctionsTest._create_sample_workspace() + log_name = "TestName" + log_value = 123 + log_type = "String" + + # Act + Assert + try: + add_to_sample_log(workspace, log_name, log_value, log_type) + did_raise = False + except TypeError: + did_raise = True + self.assertTrue(did_raise) + + def test_that_sample_log_raises_for_wrong_type_selection(self): + # Arrange + workspace = SANSFunctionsTest._create_sample_workspace() + log_name = "TestName" + log_value = "test" + log_type = "sdfsdfsdf" + + # Act + Assert + try: + 
add_to_sample_log(workspace, log_name, log_value, log_type) + did_raise = False + except ValueError: + did_raise = True + self.assertTrue(did_raise) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/common/log_tagger_test.py b/scripts/test/SANS/common/log_tagger_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6610a4f5ca163c6910329eccb8ae2f8618ed2e --- /dev/null +++ b/scripts/test/SANS/common/log_tagger_test.py @@ -0,0 +1,44 @@ +import unittest +import mantid +from mantid.api import AlgorithmManager +from sans.common.log_tagger import (has_tag, set_tag, get_tag, has_hash, set_hash) + + +class SANSLogTaggerTest(unittest.TestCase): + @staticmethod + def _provide_sample_workspace(): + alg = AlgorithmManager.createUnmanaged("CreateSampleWorkspace") + alg.setChild(True) + alg.initialize() + alg.setProperty("OutputWorkspace", "dummy") + alg.execute() + return alg.getProperty("OutputWorkspace").value + + def test_that_can_read_and_write_tag_in_sample_logs(self): + # Arrange + ws1 = SANSLogTaggerTest._provide_sample_workspace() + tag1 = "test" + value1 = 123 + + # Act + Assert + self.assertFalse(has_tag(tag1, ws1)) + set_tag(tag1, value1, ws1) + self.assertTrue(has_tag(tag1, ws1)) + self.assertTrue(get_tag(tag1, ws1) == value1) + + def test_that_can_read_and_write_hash_in_sample_log(self): + # Arrange + ws1 = self._provide_sample_workspace() + tag1 = "test" + value1 = "tested" + value2 = "tested2" + + # Act + Assert + self.assertFalse(has_hash(tag1, value1, ws1)) + set_hash(tag1, value1, ws1) + self.assertTrue(has_hash(tag1, value1, ws1)) + self.assertFalse(has_hash(tag1, value2, ws1)) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/common/xml_parsing_test.py b/scripts/test/SANS/common/xml_parsing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d71aaedb314dd4c3bdcd5004da0b9f6d30f391d7 --- /dev/null +++ b/scripts/test/SANS/common/xml_parsing_test.py @@ -0,0 +1,64 @@ +import unittest +import mantid + +from sans.common.file_information import (SANSFileInformationFactory, get_instrument_paths_for_sans_file) +from sans.common.xml_parsing import (get_named_elements_from_ipf_file, get_monitor_names_from_idf_file) + + +class XMLParsingTest(unittest.TestCase): + def test_that_named_entries_in_instrument_parameter_file_can_be_retrieved(self): + # Arrange + test_file = "LARMOR00003368" + file_information_factory = SANSFileInformationFactory() + file_information = file_information_factory.create_sans_file_information(test_file) + full_file_path = file_information.get_file_name() + + _, ipf = get_instrument_paths_for_sans_file(full_file_path) + to_search = ["low-angle-detector-name", "high-angle-detector-short-name"] + + # Act + results = get_named_elements_from_ipf_file(ipf, to_search, str) + + # Assert + self.assertTrue(len(results) == 2) + + self.assertTrue(results["low-angle-detector-name"] == "DetectorBench") + self.assertTrue(results["high-angle-detector-short-name"] == "front") + + def test_that_monitors_can_be_found(self): + # Arrange + test_file = "LARMOR00003368" + file_information_factory = SANSFileInformationFactory() + file_information = file_information_factory.create_sans_file_information(test_file) + full_file_path = file_information.get_file_name() + + idf, _ = get_instrument_paths_for_sans_file(full_file_path) + + # Act + results = get_monitor_names_from_idf_file(idf) + + # Assert + self.assertTrue(len(results) == 10) + for key, value in results.items(): + 
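The two general-purpose helpers exercised above combine naturally; a short sketch follows (illustrative only, the log name and value are made up), assuming the same Mantid environment the tests use.

# Illustrative sketch only; the log name and value are hypothetical.
from sans.common.general_functions import (create_unmanaged_algorithm, add_to_sample_log)

sample_options = {"OutputWorkspace": "dummy"}
create_alg = create_unmanaged_algorithm("CreateSampleWorkspace", **sample_options)
create_alg.execute()
workspace = create_alg.getProperty("OutputWorkspace").value

# Both the value and the declared type must be strings; add_to_sample_log raises
# TypeError for non-string values and ValueError for an unknown log type.
add_to_sample_log(workspace, "ExampleTag", "ExampleValue", "String")
print(workspace.run().getProperty("ExampleTag").value)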
self.assertTrue(value == ("monitor"+str(key))) + + def test_that_monitors_can_be_found_v2(self): + # Arrange + test_file = "LOQ74044" + file_information_factory = SANSFileInformationFactory() + file_information = file_information_factory.create_sans_file_information(test_file) + full_file_path = file_information.get_file_name() + + idf, _ = get_instrument_paths_for_sans_file(full_file_path) + + # Act + results = get_monitor_names_from_idf_file(idf) + + # Assert + self.assertTrue(len(results) == 2) + for key, value in results.items(): + self.assertTrue(value == ("monitor"+str(key))) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/CMakeLists.txt b/scripts/test/SANS/state/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d16f314ce579ef44ceafc1807e494d732c08b7bb --- /dev/null +++ b/scripts/test/SANS/state/CMakeLists.txt @@ -0,0 +1,26 @@ +# +## Tests for SANSState +## + +set ( TEST_PY_FILES + adjustment_test.py + calculate_transmission_test.py + convert_to_q_test.py + data_test.py + mask_test.py + move_test.py + normalize_to_monitor_test.py + reduction_mode_test.py + save_test.py + scale_test.py + slice_event_test.py + state_base_test.py + state_functions_test.py + state_test.py + wavelength_and_pixel_adjustment_test.py + wavelength_test.py +) +check_tests_valid ( ${CMAKE_CURRENT_SOURCE_DIR} ${TEST_PY_FILES} ) + +# Prefix for test name=Python +pyunittest_add_test ( ${CMAKE_CURRENT_SOURCE_DIR} PythonSANS ${TEST_PY_FILES} ) diff --git a/scripts/test/SANS/state/adjustment_test.py b/scripts/test/SANS/state/adjustment_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa0c7de8c0f8b5e6361db0c9b9f66697f9ec247 --- /dev/null +++ b/scripts/test/SANS/state/adjustment_test.py @@ -0,0 +1,98 @@ +import unittest +import mantid + +from sans.state.adjustment import (StateAdjustment, get_adjustment_builder) +from sans.state.data import (get_data_builder) +from sans.state.calculate_transmission import StateCalculateTransmission +from sans.state.normalize_to_monitor import StateNormalizeToMonitor +from sans.state.wavelength_and_pixel_adjustment import StateWavelengthAndPixelAdjustment +from sans.common.enums import (SANSFacility, SANSInstrument, FitType) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class MockStateNormalizeToMonitor(StateNormalizeToMonitor): + def validate(self): + pass + + +class MockStateCalculateTransmission(StateCalculateTransmission): + def validate(self): + pass + + +class MockStateWavelengthAndPixelAdjustment(StateWavelengthAndPixelAdjustment): + def validate(self): + pass + + +class StateReductionTest(unittest.TestCase): + def test_that_raises_when_calculate_transmission_is_not_set(self): + state = StateAdjustment() + state.normalize_to_monitor = MockStateNormalizeToMonitor() + state.wavelength_and_pixel_adjustment = MockStateWavelengthAndPixelAdjustment() + self.assertRaises(ValueError, state.validate) + state.calculate_transmission = MockStateCalculateTransmission() + try: + state.validate() + except ValueError: + self.fail() + + def test_that_raises_when_normalize_to_monitor_is_not_set(self): + state = StateAdjustment() + state.calculate_transmission = MockStateCalculateTransmission() + state.wavelength_and_pixel_adjustment = MockStateWavelengthAndPixelAdjustment() + 
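As a compact illustration of the two XML helpers tested above (illustrative only; it reuses the LARMOR run from xml_parsing_test.py and assumes the same access to the test data):

# Illustrative sketch only, reusing the LARMOR run from xml_parsing_test.py.
from sans.common.file_information import (SANSFileInformationFactory,
                                          get_instrument_paths_for_sans_file)
from sans.common.xml_parsing import (get_named_elements_from_ipf_file,
                                     get_monitor_names_from_idf_file)

file_information = SANSFileInformationFactory().create_sans_file_information("LARMOR00003368")
idf_path, ipf_path = get_instrument_paths_for_sans_file(file_information.get_file_name())

# The IDF yields a spectrum-number -> monitor-name map (10 entries for this LARMOR run) ...
monitor_names = get_monitor_names_from_idf_file(idf_path)

# ... and named IPF parameters can be fetched with an explicit target type.
detector_names = get_named_elements_from_ipf_file(ipf_path, ["low-angle-detector-name",
                                                             "high-angle-detector-short-name"], str)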
self.assertRaises(ValueError, state.validate) + state.normalize_to_monitor = MockStateNormalizeToMonitor() + try: + state.validate() + except ValueError: + self.fail() + + def test_that_raises_when_wavelength_and_pixel_adjustment_is_not_set(self): + state = StateAdjustment() + state.calculate_transmission = MockStateCalculateTransmission() + state.normalize_to_monitor = MockStateNormalizeToMonitor() + self.assertRaises(ValueError, state.validate) + state.wavelength_and_pixel_adjustment = MockStateWavelengthAndPixelAdjustment() + try: + state.validate() + except ValueError: + self.fail() + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateAdjustmentBuilderTest(unittest.TestCase): + def test_that_reduction_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_adjustment_builder(data_info) + self.assertTrue(builder) + + builder.set_calculate_transmission(MockStateCalculateTransmission()) + builder.set_normalize_to_monitor(MockStateNormalizeToMonitor()) + builder.set_wavelength_and_pixel_adjustment(MockStateWavelengthAndPixelAdjustment()) + builder.set_wide_angle_correction(False) + state = builder.build() + + # # Assert + self.assertTrue(not state.wide_angle_correction) + + try: + state.validate() + is_valid = True + except ValueError: + is_valid = False + self.assertTrue(is_valid) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/calculate_transmission_test.py b/scripts/test/SANS/state/calculate_transmission_test.py new file mode 100644 index 0000000000000000000000000000000000000000..05bfc162aea4833a1e159ca12d7e97fd341eb45c --- /dev/null +++ b/scripts/test/SANS/state/calculate_transmission_test.py @@ -0,0 +1,283 @@ +import unittest +import mantid + +from sans.state.calculate_transmission import (StateCalculateTransmission, StateCalculateTransmissionLOQ, + get_calculate_transmission_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (RebinType, RangeStepType, FitType, DataType, SANSFacility) +from state_test_helper import assert_validate_error, assert_raises_nothing + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateCalculateTransmissionTest(unittest.TestCase): + @staticmethod + def _set_fit(state, default_settings, custom_settings, fit_key): + fit = state.fit[fit_key] + for key, value in default_settings.items(): + if key in custom_settings: + value = custom_settings[key] + if value is not None: # If the value is None, then don't set it + setattr(fit, key, value) + state.fit[fit_key] = fit + + @staticmethod + def _get_calculate_transmission_state(trans_entries, fit_entries): + state = StateCalculateTransmission() + if trans_entries is None: + trans_entries = {} + trans_settings = {"transmission_radius_on_detector": 12., "transmission_roi_files": ["test.xml"], + "transmission_mask_files": ["test.xml"], "default_transmission_monitor": 3, + "transmission_monitor": 4, "default_incident_monitor": 1, "incident_monitor": 2, + 
"prompt_peak_correction_min": 123., "prompt_peak_correction_max": 1234., + "rebin_type": RebinType.Rebin, "wavelength_low": 1., "wavelength_high": 2.7, + "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.Lin, + "use_full_wavelength_range": True, "wavelength_full_range_low": 12., + "wavelength_full_range_high": 434., "background_TOF_general_start": 1.4, + "background_TOF_general_stop": 24.5, "background_TOF_monitor_start": {"1": 123, "2": 123}, + "background_TOF_monitor_stop": {"1": 234, "2": 2323}, "background_TOF_roi_start": 12., + "background_TOF_roi_stop": 123.} + + for key, value in trans_settings.items(): + if key in trans_entries: + value = trans_entries[key] + if value is not None: # If the value is None, then don't set it + setattr(state, key, value) + + fit_settings = {"fit_type": FitType.Polynomial, "polynomial_order": 1, "wavelength_low": 12., + "wavelength_high": 232.} + if fit_entries is None: + fit_entries = {} + StateCalculateTransmissionTest._set_fit(state, fit_settings, fit_entries, + DataType.to_string(DataType.Sample)) + StateCalculateTransmissionTest._set_fit(state, fit_settings, fit_entries, + DataType.to_string(DataType.Can)) + return state + + @staticmethod + def _get_dict(entry_name, value): + output = {} + if value is not None: + output.update({entry_name: value}) + return output + + def check_bad_and_good_values(self, bad_trans=None, bad_fit=None, good_trans=None, good_fit=None): + # Bad values + state = self._get_calculate_transmission_state(bad_trans, bad_fit) + assert_validate_error(self, ValueError, state) + + # Good values + state = self._get_calculate_transmission_state(good_trans, good_fit) + assert_raises_nothing(self, state) + + def test_that_is_sans_state_data_object(self): + state = StateCalculateTransmissionLOQ() + self.assertTrue(isinstance(state, StateCalculateTransmission)) + + def test_that_raises_when_no_incident_monitor_is_available(self): + self.check_bad_and_good_values(bad_trans={"incident_monitor": None, "default_incident_monitor": None}, + good_trans={"incident_monitor": 1, "default_incident_monitor": None}) + self.check_bad_and_good_values(bad_trans={"incident_monitor": None, "default_incident_monitor": None}, + good_trans={"incident_monitor": None, "default_incident_monitor": 1}) + + def test_that_raises_when_no_transmission_is_specified(self): + self.check_bad_and_good_values(bad_trans={"transmission_monitor": None, "default_transmission_monitor": None, + "transmission_radius_on_detector": None, + "transmission_roi_files": None}, + good_trans={"transmission_monitor": 4, "default_transmission_monitor": None, + "transmission_radius_on_detector": None, + "transmission_roi_files": None}) + + def test_that_raises_for_inconsistent_prompt_peak(self): + self.check_bad_and_good_values(bad_trans={"prompt_peak_correction_min": 1., "prompt_peak_correction_max": None}, + good_trans={"prompt_peak_correction_min": None, + "prompt_peak_correction_max": None}) + self.check_bad_and_good_values(bad_trans={"prompt_peak_correction_min": 1., + "prompt_peak_correction_max": None}, + good_trans={"prompt_peak_correction_min": 1., "prompt_peak_correction_max": 2.}) + + def test_that_raises_for_lower_bound_larger_than_upper_bound_for_prompt_peak(self): + self.check_bad_and_good_values(bad_trans={"prompt_peak_correction_min": 2., "prompt_peak_correction_max": 1.}, + good_trans={"prompt_peak_correction_min": 1., "prompt_peak_correction_max": 2.}) + + def test_that_raises_when_not_all_elements_are_set_for_wavelength(self): + 
self.check_bad_and_good_values(bad_trans={"wavelength_low": 1., "wavelength_high": 2., + "wavelength_step": 0.5, "wavelength_step_type": None}, + good_trans={"wavelength_low": 1., "wavelength_high": 2., + "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.Lin}) + + def test_that_raises_for_lower_bound_larger_than_upper_bound_for_wavelength(self): + self.check_bad_and_good_values(bad_trans={"wavelength_low": 2., "wavelength_high": 1., + "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.Lin}, + good_trans={"wavelength_low": 1., "wavelength_high": 2., + "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.Lin}) + + def test_that_raises_for_missing_full_wavelength_entry(self): + self.check_bad_and_good_values(bad_trans={"use_full_wavelength_range": True, "wavelength_full_range_low": None, + "wavelength_full_range_high": 12.}, + good_trans={"use_full_wavelength_range": True, "wavelength_full_range_low": 11., + "wavelength_full_range_high": 12.}) + + def test_that_raises_for_lower_bound_larger_than_upper_bound_for_full_wavelength(self): + self.check_bad_and_good_values(bad_trans={"use_full_wavelength_range": True, "wavelength_full_range_low": 2., + "wavelength_full_range_high": 1.}, + good_trans={"use_full_wavelength_range": True, "wavelength_full_range_low": 1., + "wavelength_full_range_high": 2.}) + + def test_that_raises_for_inconsistent_general_background(self): + self.check_bad_and_good_values(bad_trans={"background_TOF_general_start": 1., + "background_TOF_general_stop": None}, + good_trans={"background_TOF_general_start": None, + "background_TOF_general_stop": None}) + + def test_that_raises_for_lower_bound_larger_than_upper_bound_for_general_background(self): + self.check_bad_and_good_values(bad_trans={"background_TOF_general_start": 2., + "background_TOF_general_stop": 1.}, + good_trans={"background_TOF_general_start": 1., + "background_TOF_general_stop": 2.}) + + def test_that_raises_for_inconsistent_roi_background(self): + self.check_bad_and_good_values(bad_trans={"background_TOF_roi_start": 1., + "background_TOF_roi_stop": None}, + good_trans={"background_TOF_roi_start": None, + "background_TOF_roi_stop": None}) + + def test_that_raises_for_lower_bound_larger_than_upper_bound_for_roi_background(self): + self.check_bad_and_good_values(bad_trans={"background_TOF_roi_start": 2., + "background_TOF_roi_stop": 1.}, + good_trans={"background_TOF_roi_start": 1., + "background_TOF_roi_stop": 2.}) + + def test_that_raises_for_inconsistent_monitor_background(self): + self.check_bad_and_good_values(bad_trans={"background_TOF_monitor_start": {"1": 12., "2": 1.}, + "background_TOF_monitor_stop": None}, + good_trans={"background_TOF_monitor_start": None, + "background_TOF_monitor_stop": None}) + + def test_that_raises_when_lengths_of_monitor_backgrounds_are_different(self): + self.check_bad_and_good_values(bad_trans={"background_TOF_monitor_start": {"1": 1., "2": 1.}, + "background_TOF_monitor_stop": {"1": 2.}}, + good_trans={"background_TOF_monitor_start": {"1": 1., "2": 1.}, + "background_TOF_monitor_stop": {"1": 2., "2": 2.}}) + + def test_that_raises_when_monitor_name_mismatch_exists_for_monitor_backgrounds(self): + self.check_bad_and_good_values(bad_trans={"background_TOF_monitor_start": {"1": 1., "2": 1.}, + "background_TOF_monitor_stop": {"1": 2., "3": 2.}}, + good_trans={"background_TOF_monitor_start": {"1": 1., "2": 1.}, + "background_TOF_monitor_stop": {"1": 2., "2": 2.}}) + + def 
test_that_raises_lower_bound_larger_than_upper_bound_for_monitor_backgrounds(self): + self.check_bad_and_good_values(bad_trans={"background_TOF_monitor_start": {"1": 1., "2": 2.}, + "background_TOF_monitor_stop": {"1": 2., "2": 1.}}, + good_trans={"background_TOF_monitor_start": {"1": 1., "2": 1.}, + "background_TOF_monitor_stop": {"1": 2., "2": 2.}}) + + def test_that_polynomial_order_can_only_be_set_with_polynomial_setting(self): + self.check_bad_and_good_values(bad_fit={"fit_type": FitType.Log, "polynomial_order": 4}, + good_fit={"fit_type": FitType.Polynomial, "polynomial_order": 4}) + + def test_that_raises_for_inconsistent_wavelength_in_fit(self): + self.check_bad_and_good_values(bad_trans={"wavelength_low": None, "wavelength_high": 2.}, + good_trans={"wavelength_low": 1., "wavelength_high": 2.}) + + def test_that_raises_for_lower_bound_larger_than_upper_bound_for_wavelength_in_fit(self): + self.check_bad_and_good_values(bad_trans={"wavelength_low": 2., "wavelength_high": 1.}, + good_trans={"wavelength_low": 1., "wavelength_high": 2.}) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateCalculateTransmissionBuilderTest(unittest.TestCase): + def test_that_reduction_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_calculate_transmission_builder(data_info) + self.assertTrue(builder) + + builder.set_prompt_peak_correction_min(12.0) + builder.set_prompt_peak_correction_max(17.0) + + builder.set_incident_monitor(1) + builder.set_default_incident_monitor(2) + builder.set_transmission_monitor(3) + builder.set_default_transmission_monitor(4) + builder.set_transmission_radius_on_detector(1.) + builder.set_transmission_roi_files(["sdfs", "sddfsdf"]) + builder.set_transmission_mask_files(["sdfs", "bbbbbb"]) + + builder.set_rebin_type(RebinType.Rebin) + builder.set_wavelength_low(1.5) + builder.set_wavelength_high(2.7) + builder.set_wavelength_step(0.5) + builder.set_wavelength_step_type(RangeStepType.Lin) + builder.set_use_full_wavelength_range(True) + builder.set_wavelength_full_range_low(12.) + builder.set_wavelength_full_range_high(24.) + + builder.set_background_TOF_general_start(1.4) + builder.set_background_TOF_general_stop(34.4) + builder.set_background_TOF_monitor_start({"1": 123, "2": 123}) + builder.set_background_TOF_monitor_stop({"1": 234, "2": 2323}) + builder.set_background_TOF_roi_start(1.4) + builder.set_background_TOF_roi_stop(34.4) + + builder.set_Sample_fit_type(FitType.Linear) + builder.set_Sample_polynomial_order(0) + builder.set_Sample_wavelength_low(10.0) + builder.set_Sample_wavelength_high(20.0) + + builder.set_Can_fit_type(FitType.Polynomial) + builder.set_Can_polynomial_order(3) + builder.set_Can_wavelength_low(10.0) + builder.set_Can_wavelength_high(20.0) + + state = builder.build() + + # Assert + self.assertTrue(state.prompt_peak_correction_min == 12.0) + self.assertTrue(state.prompt_peak_correction_max == 17.0) + + self.assertTrue(state.incident_monitor == 1) + self.assertTrue(state.default_incident_monitor == 2) + self.assertTrue(state.transmission_monitor == 3) + self.assertTrue(state.default_transmission_monitor == 4) + self.assertTrue(state.transmission_radius_on_detector == 1.) 
+ self.assertTrue(state.transmission_roi_files == ["sdfs", "sddfsdf"]) + self.assertTrue(state.transmission_mask_files == ["sdfs", "bbbbbb"]) + + self.assertTrue(state.rebin_type is RebinType.Rebin) + self.assertTrue(state.wavelength_low == 1.5) + self.assertTrue(state.wavelength_high == 2.7) + self.assertTrue(state.wavelength_step == 0.5) + self.assertTrue(state.wavelength_step_type is RangeStepType.Lin) + self.assertTrue(state.use_full_wavelength_range is True) + self.assertTrue(state.wavelength_full_range_low == 12.) + self.assertTrue(state.wavelength_full_range_high == 24.) + + self.assertTrue(state.background_TOF_general_start == 1.4) + self.assertTrue(state.background_TOF_general_stop == 34.4) + self.assertTrue(len(set(state.background_TOF_monitor_start.items()) & set({"1": 123, "2": 123}.items())) == 2) + self.assertTrue(len(set(state.background_TOF_monitor_stop.items()) & set({"1": 234, "2": 2323}.items())) == 2) + self.assertTrue(state.background_TOF_roi_start == 1.4) + self.assertTrue(state.background_TOF_roi_stop == 34.4) + + self.assertTrue(state.fit[DataType.to_string(DataType.Sample)].fit_type is FitType.Linear) + self.assertTrue(state.fit[DataType.to_string(DataType.Sample)].polynomial_order == 0) + self.assertTrue(state.fit[DataType.to_string(DataType.Sample)].wavelength_low == 10.) + self.assertTrue(state.fit[DataType.to_string(DataType.Sample)].wavelength_high == 20.) + + self.assertTrue(state.fit[DataType.to_string(DataType.Can)].fit_type is + FitType.Polynomial) + self.assertTrue(state.fit[DataType.to_string(DataType.Can)].polynomial_order == 3) + self.assertTrue(state.fit[DataType.to_string(DataType.Can)].wavelength_low == 10.) + self.assertTrue(state.fit[DataType.to_string(DataType.Can)].wavelength_high == 20.) + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/convert_to_q_test.py b/scripts/test/SANS/state/convert_to_q_test.py new file mode 100644 index 0000000000000000000000000000000000000000..67713b964472f12911abbe041c78791f0912d30c --- /dev/null +++ b/scripts/test/SANS/state/convert_to_q_test.py @@ -0,0 +1,117 @@ +import unittest +import mantid + +from sans.state.convert_to_q import (StateConvertToQ, get_convert_to_q_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (RangeStepType, ReductionDimensionality, SANSFacility) +from state_test_helper import (assert_validate_error, assert_raises_nothing) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateConvertToQTest(unittest.TestCase): + @staticmethod + def _get_convert_to_q_state(convert_to_q_entries): + state = StateConvertToQ() + default_entries = {"reduction_dimensionality": ReductionDimensionality.OneDim, "use_gravity": True, + "gravity_extra_length": 12., "radius_cutoff": 1.5, "wavelength_cutoff": 2.7, + "q_min": 0.5, "q_max": 1., "q_step": 1., "q_step_type": RangeStepType.Lin, + "q_step2": 1., "q_step_type2": RangeStepType.Lin, "q_mid": 1., + "q_xy_max": 1.4, "q_xy_step": 24.5, "q_xy_step_type": RangeStepType.Lin, + "use_q_resolution": True, "q_resolution_collimation_length": 12., + "q_resolution_delta_r": 12., "moderator_file": "test.txt", "q_resolution_a1": 1., + "q_resolution_a2": 2., "q_resolution_h1": 1., "q_resolution_h2": 2., "q_resolution_w1": 1., + "q_resolution_w2": 2.} + + for key, value in default_entries.items(): + 
if key in convert_to_q_entries: + value = convert_to_q_entries[key] + if value is not None: # If the value is None, then don't set it + setattr(state, key, value) + return state + + def check_bad_and_good_value(self, bad_convert_to_q, good_convert_to_q): + # Bad values + state = self._get_convert_to_q_state(bad_convert_to_q) + assert_validate_error(self, ValueError, state) + + # Good values + state = self._get_convert_to_q_state(good_convert_to_q) + assert_raises_nothing(self, state) + + def test_that_raises_with_inconsistent_1D_q_values(self): + self.check_bad_and_good_value({"q_min": None, "q_max": 2.}, {"q_min": 1., "q_max": 2.}) + + def test_that_raises_when_the_lower_bound_is_larger_than_the_upper_bound_for_q_1D(self): + self.check_bad_and_good_value({"q_min": 2., "q_max": 1.}, {"q_min": 1., "q_max": 2.}) + + def test_that_raises_when_no_q_bounds_are_set_for_explicit_1D_reduction(self): + self.check_bad_and_good_value({"q_min": None, "q_max": None, + "reduction_dimensionality": ReductionDimensionality.OneDim}, + {"q_min": 1., "q_max": 2., + "reduction_dimensionality": ReductionDimensionality.OneDim}) + + def test_that_raises_when_no_q_bounds_are_set_for_explicit_2D_reduction(self): + self.check_bad_and_good_value({"q_xy_max": None, "q_xy_step": None, + "reduction_dimensionality": ReductionDimensionality.TwoDim}, + {"q_xy_max": 1., "q_xy_step": 2., + "reduction_dimensionality": ReductionDimensionality.TwoDim}) + + def test_that_raises_when_inconsistent_circular_values_for_q_resolution_are_specified(self): + self.check_bad_and_good_value({"use_q_resolution": True, "q_resolution_a1": None, + "q_resolution_a2": 12.}, + {"use_q_resolution": True, "q_resolution_a1": 11., + "q_resolution_a2": 12.}) + + def test_that_raises_when_inconsistent_rectangular_values_for_q_resolution_are_specified(self): + self.check_bad_and_good_value({"use_q_resolution": True, "q_resolution_h1": None, + "q_resolution_h2": 12., "q_resolution_w1": 1., "q_resolution_w2": 2.}, + {"use_q_resolution": True, "q_resolution_h1": 1., + "q_resolution_h2": 12., "q_resolution_w1": 1., "q_resolution_w2": 2.}) + + def test_that_raises_when_no_geometry_for_q_resolution_was_specified(self): + self.check_bad_and_good_value({"use_q_resolution": True, "q_resolution_h1": None, "q_resolution_a1": None, + "q_resolution_a2": None, "q_resolution_h2": None, "q_resolution_w1": None, + "q_resolution_w2": None}, + {"use_q_resolution": True, "q_resolution_h1": 1., "q_resolution_a1": 1., + "q_resolution_a2": 2., "q_resolution_h2": 12., "q_resolution_w1": 1., + "q_resolution_w2": 2.}) + + def test_that_raises_when_moderator_file_has_not_been_set(self): + self.check_bad_and_good_value({"moderator_file": None}, {"moderator_file": "test"}) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateConvertToQBuilderTest(unittest.TestCase): + def test_that_reduction_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_convert_to_q_builder(data_info) + self.assertTrue(builder) + + builder.set_q_min(12.0) + builder.set_q_max(17.0) + builder.set_q_step(1.) 
+ builder.set_q_step_type(RangeStepType.Lin) + builder.set_reduction_dimensionality(ReductionDimensionality.OneDim) + + state = builder.build() + + # Assert + self.assertTrue(state.q_min == 12.0) + self.assertTrue(state.q_max == 17.0) + self.assertTrue(state.q_step == 1.) + self.assertTrue(state.q_step_type is RangeStepType.Lin) + self.assertTrue(state.reduction_dimensionality is ReductionDimensionality.OneDim) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/data_test.py b/scripts/test/SANS/state/data_test.py new file mode 100644 index 0000000000000000000000000000000000000000..39ead90f26597895d96fc394cf5e1dc9c45e4e3a --- /dev/null +++ b/scripts/test/SANS/state/data_test.py @@ -0,0 +1,90 @@ +import unittest +import mantid + +from sans.state.data import (StateData, get_data_builder) +from state_test_helper import (assert_validate_error, assert_raises_nothing) +from sans.common.enums import (SANSFacility, SANSInstrument) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State test +# ---------------------------------------------------------------------------------------------------------------------- +class StateDataTest(unittest.TestCase): + @staticmethod + def _get_data_state(**data_entries): + state = StateData() + data_settings = {"sample_scatter": "test", "sample_transmission": "test", + "sample_direct": "test", "can_scatter": "test", + "can_transmission": "test", "can_direct": "test"} + + for key, value in data_settings.items(): + if key in data_entries: + value = data_entries[key] + if value is not None: # If the value is None, then don't set it + setattr(state, key, value) + return state + + def assert_raises_for_bad_value_and_raises_nothing_for_good_value(self, data_entries_bad, + data_entries_good): + # Bad values + state = StateDataTest._get_data_state(**data_entries_bad) + assert_validate_error(self, ValueError, state) + + # Good values + state = StateDataTest._get_data_state(**data_entries_good) + assert_raises_nothing(self, state) + + def test_that_raises_when_sample_scatter_is_missing(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value({"sample_scatter": None}, + {"sample_scatter": "test"}) + + def test_that_raises_when_transmission_and_direct_are_inconsistently_specified_for_sample(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value({"sample_transmission": None, + "sample_direct": "test", + "can_transmission": None, + "can_direct": None}, + {"sample_transmission": "test", + "sample_direct": "test", + "can_transmission": None, + "can_direct": None}) + + def test_that_raises_when_transmission_and_direct_are_inconsistently_specified_for_can(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value({"can_transmission": "test", + "can_direct": None}, + {"can_transmission": "test", + "can_direct": "test"}) + + def test_that_raises_when_transmission_but_not_scatter_was_specified_for_can(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value({"can_scatter": None, + "can_transmission": "test", + "can_direct": "test"}, + {"can_scatter": "test", + "can_transmission": "test", + "can_direct": "test"}) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder test +# ---------------------------------------------------------------------------------------------------------------------- +class 
StateDataBuilderTest(unittest.TestCase): + def test_that_data_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + + # Act + data_builder = get_data_builder(facility) + + data_builder.set_sample_scatter("LOQ74044") + data_builder.set_sample_scatter_period(3) + data_state = data_builder.build() + + # # Assert + self.assertTrue(data_state.sample_scatter == "LOQ74044") + self.assertTrue(data_state.sample_scatter_period == 3) + self.assertTrue(data_state.sample_direct_period == 0) + self.assertTrue(data_state.instrument is SANSInstrument.LOQ) + self.assertTrue(data_state.sample_scatter_run_number == 74044) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/mask_test.py b/scripts/test/SANS/state/mask_test.py new file mode 100644 index 0000000000000000000000000000000000000000..40c5561893055186619ebdc277a9d8aedc2baee0 --- /dev/null +++ b/scripts/test/SANS/state/mask_test.py @@ -0,0 +1,224 @@ +import unittest +import mantid + +from sans.state.mask import (StateMask, get_mask_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (SANSFacility, SANSInstrument, DetectorType) +from state_test_helper import (assert_validate_error, assert_raises_nothing) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +EXPLICIT_NONE = "explict_none" + + +class StateMaskTest(unittest.TestCase): + @staticmethod + def _set_detector(state, default_settings, custom_settings, detector_key): + detector = state.detectors[detector_key] + for key, value in default_settings.items(): + if key in custom_settings: + value = custom_settings[key] + if value is not None: # If the value is None, then don't set it + setattr(detector, key, value) + state.detectors[detector_key] = detector + + @staticmethod + def _get_mask_state(general_entries, detector_entries): + state = StateMask() + # Setup the general mask settings + mask_settings = {"radius_min": 12., "radius_max": 17., + "bin_mask_general_start": [1., 2., 3.], "bin_mask_general_stop": [2., 3., 4.], + "mask_files": None, + "phi_min": 0.5, "phi_max": 1., "use_mask_phi_mirror": True, + "beam_stop_arm_width": 1., "beam_stop_arm_angle": 24.5, "beam_stop_arm_pos1": 12., + "beam_stop_arm_pos2": 34., + "clear": False, "clear_time": False, "single_spectra": [1, 4, 6], + "spectrum_range_start": [1, 5, 7], "spectrum_range_stop": [2, 6, 8], + "idf_path": ""} + + for key, value in mask_settings.items(): + if key in general_entries: + value = general_entries[key] + if value is not None: # If the value is None, then don't set it + setattr(state, key, value) + + # Now setup the detector-specific settings + detector_settings = {"single_vertical_strip_mask": [1, 2, 4], "range_vertical_strip_start": [1, 2, 4], + "range_vertical_strip_stop": [2, 3, 5], "single_horizontal_strip_mask": [1, 2, 4], + "range_horizontal_strip_start": [1, 2, 4], "range_horizontal_strip_stop": [2, 3, 5], + "block_horizontal_start": [1, 2, 4], "block_horizontal_stop": [2, 3, 5], + "block_vertical_start": [1, 2, 4], "block_vertical_stop": [2, 3, 5], + "block_cross_horizontal": [1, 2, 4], "block_cross_vertical": [2, 3, 5], + "bin_mask_start": [1., 2., 4.], "bin_mask_stop": [2., 3., 5.], + "detector_name": "name", "detector_name_short": "name_short"} + + StateMaskTest._set_detector(state, detector_settings, detector_entries, + 
DetectorType.to_string(DetectorType.LAB)) + StateMaskTest._set_detector(state, detector_settings, detector_entries, + DetectorType.to_string(DetectorType.HAB)) + + return state + + @staticmethod + def _get_dict(entry_name, value): + is_explicit_none = value == EXPLICIT_NONE + output = {} + if value is not None: + value = None if is_explicit_none else value + output.update({entry_name: value}) + return output + + def assert_raises_for_bad_value_and_raises_nothing_for_good_value(self, entry_name=None, bad_value_general=None, + bad_value_detector=None, good_value_general=None, + good_value_detector=None): + # Bad values + bad_value_general_dict = StateMaskTest._get_dict(entry_name, bad_value_general) + bad_value_detector_dict = StateMaskTest._get_dict(entry_name, bad_value_detector) + state = self._get_mask_state(bad_value_general_dict, bad_value_detector_dict) + assert_validate_error(self, ValueError, state) + + # Good values + good_value_general_dict = StateMaskTest._get_dict(entry_name, good_value_general) + good_value_detector_dict = StateMaskTest._get_dict(entry_name, good_value_detector) + state = self._get_mask_state(good_value_general_dict, good_value_detector_dict) + assert_raises_nothing(self, state) + + def test_that_raises_when_lower_radius_bound_larger_than_upper_bound(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("radius_min", 500., None, 12., None) + + def test_that_raises_when_only_one_bin_mask_has_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("bin_mask_general_start", EXPLICIT_NONE, + None, [1., 2., 3.], None) + + def test_that_raises_when_bin_mask_lengths_are_mismatched(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("bin_mask_general_start", [1., 3.], + None, [1., 2., 3.], None) + + def test_that_raises_lower_bound_is_larger_than_upper_bound_for_bin_mask(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("bin_mask_general_start", [1., 10., 3.], + None, [1., 2., 3.], None) + + def test_that_raises_when_only_one_spectrum_range_has_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("spectrum_range_start", EXPLICIT_NONE, + None, [1, 5, 7], None) + + def test_that_raises_when_spectrum_range_lengths_are_mismatched(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("spectrum_range_start", [1, 3], + None, [1, 5, 7], None) + + def test_that_raises_lower_bound_is_larger_than_upper_bound_for_spectrum_range(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("spectrum_range_start", [1, 10, 3], + None, [1, 5, 7], None) + + def test_that_raises_when_only_one_vertical_strip_has_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("range_vertical_strip_start", None, + EXPLICIT_NONE, None, [1, 2, 4]) + + def test_that_raises_when_vertical_strip_lengths_are_mismatched(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("range_vertical_strip_start", None, [1, 2], + None, [1, 2, 4]) + + def test_that_raises_lower_bound_is_larger_than_upper_bound_for_vertical_strip(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("range_vertical_strip_start", None, + [1, 10, 3], None, [1, 2, 4]) + + def test_that_raises_when_only_one_horizontal_strip_has_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("range_horizontal_strip_start", None, + EXPLICIT_NONE, None, [1, 2, 4]) + + def 
test_that_raises_when_horizontal_strip_lengths_are_mismatched(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("range_horizontal_strip_start", None, [1, 2], + None, [1, 2, 4]) + + def test_that_raises_lower_bound_is_larger_than_upper_bound_for_horizontal_strip(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("range_horizontal_strip_start", None, + [1, 10, 3], None, [1, 2, 4]) + + def test_that_raises_when_only_one_horizontal_block_has_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("block_horizontal_start", None, + EXPLICIT_NONE, None, [1, 2, 4]) + + def test_that_raises_when_horizontal_block_lengths_are_mismatched(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("block_horizontal_start", None, [1, 2], + None, [1, 2, 4]) + + def test_that_raises_lower_bound_is_larger_than_upper_bound_for_horiztonal_block(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("block_horizontal_start", None, + [1, 10, 3], None, [1, 2, 4]) + + def test_that_raises_when_only_one_vertical_block_has_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("block_vertical_start", None, + EXPLICIT_NONE, None, [1, 2, 4]) + + def test_that_raises_when_vertical_block_lengths_are_mismatched(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("block_vertical_start", None, + [1, 2], + None, [1, 2, 4]) + + def test_that_raises_lower_bound_is_larger_than_upper_bound_for_vertical_block(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("block_vertical_start", None, + [1, 10, 3], None, [1, 2, 4]) + + def test_that_raises_when_only_one_time_mask_has_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("bin_mask_start", None, + EXPLICIT_NONE, None, [1., 2., 4.]) + + def test_that_raises_when_time_mask_lengths_are_mismatched(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("bin_mask_start", None, + [1., 2.], + None, [1., 2., 4.]) + + def test_that_raises_lower_bound_is_larger_than_upper_bound_for_time_mask(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("bin_mask_start", None, + [1., 10., 3.], None, [1., 2., 4.]) + + def test_that_raises_if_detector_names_have_not_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("detector_name", None, + EXPLICIT_NONE, None, "name") + + def test_that_raises_if_short_detector_names_have_not_been_set(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("detector_name_short", None, + EXPLICIT_NONE, None, "name") + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateMaskBuilderTest(unittest.TestCase): + def test_that_mask_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_builder.set_sample_scatter_period(3) + data_info = data_builder.build() + + # Act + builder = get_mask_builder(data_info) + self.assertTrue(builder) + + start_time = [0.1, 1.3] + end_time = [0.2, 1.6] + builder.set_bin_mask_general_start(start_time) + builder.set_bin_mask_general_stop(end_time) + builder.set_LAB_single_vertical_strip_mask([1, 2, 3]) + + # 
Assert + state = builder.build() + self.assertTrue(len(state.bin_mask_general_start) == 2) + self.assertTrue(state.bin_mask_general_start[0] == start_time[0]) + self.assertTrue(state.bin_mask_general_start[1] == start_time[1]) + + self.assertTrue(len(state.bin_mask_general_stop) == 2) + self.assertTrue(state.bin_mask_general_stop[0] == end_time[0]) + self.assertTrue(state.bin_mask_general_stop[1] == end_time[1]) + + strip_mask = state.detectors[DetectorType.to_string(DetectorType.LAB)].single_vertical_strip_mask + self.assertTrue(len(strip_mask) == 3) + self.assertTrue(strip_mask[2] == 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/move_test.py b/scripts/test/SANS/state/move_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a993884455d1a5865aace873e6d62bb21be39f28 --- /dev/null +++ b/scripts/test/SANS/state/move_test.py @@ -0,0 +1,166 @@ +import unittest +import mantid + +from sans.state.move import (StateMoveLOQ, StateMoveSANS2D, StateMoveLARMOR, StateMove, get_move_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (CanonicalCoordinates, SANSFacility, DetectorType) +from state_test_helper import assert_validate_error, assert_raises_nothing + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateMoveWorkspaceTest(unittest.TestCase): + def test_that_raises_if_the_detector_name_is_not_set_up(self): + state = StateMove() + state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name = "test" + state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short = "test" + state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name_short = "test" + assert_validate_error(self, ValueError, state) + state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name = "test" + assert_raises_nothing(self, state) + + def test_that_raises_if_the_short_detector_name_is_not_set_up(self): + state = StateMove() + state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name = "test" + state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name = "test" + state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short = "test" + assert_validate_error(self, ValueError, state) + state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name_short = "test" + assert_raises_nothing(self, state) + + def test_that_general_isis_default_values_are_set_up(self): + state = StateMove() + self.assertTrue(state.sample_offset == 0.0) + self.assertTrue(state.sample_offset_direction is CanonicalCoordinates.Z) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_translation_correction == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].y_translation_correction == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].z_translation_correction == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].rotation_correction == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].side_correction == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].radius_correction == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_tilt_correction == 
0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].y_tilt_correction == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].z_tilt_correction == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].sample_centre_pos1 == 0.0) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].sample_centre_pos2 == 0.0) + + +class StateMoveWorkspaceLOQTest(unittest.TestCase): + def test_that_is_sans_state_move_object(self): + state = StateMoveLOQ() + self.assertTrue(isinstance(state, StateMove)) + + def test_that_LOQ_has_centre_position_set_up(self): + state = StateMoveLOQ() + self.assertTrue(state.center_position == 317.5 / 1000.) + self.assertTrue(state.monitor_names == {}) + + +class StateMoveWorkspaceSANS2DTest(unittest.TestCase): + def test_that_is_sans_state_data_object(self): + state = StateMoveSANS2D() + self.assertTrue(isinstance(state, StateMove)) + + def test_that_sans2d_has_default_values_set_up(self): + # Arrange + state = StateMoveSANS2D() + self.assertTrue(state.hab_detector_radius == 306.0/1000.) + self.assertTrue(state.hab_detector_default_sd_m == 4.0) + self.assertTrue(state.hab_detector_default_x_m == 1.1) + self.assertTrue(state.lab_detector_default_sd_m == 4.0) + self.assertTrue(state.hab_detector_x == 0.0) + self.assertTrue(state.hab_detector_z == 0.0) + self.assertTrue(state.hab_detector_rotation == 0.0) + self.assertTrue(state.lab_detector_x == 0.0) + self.assertTrue(state.lab_detector_z == 0.0) + self.assertTrue(state.monitor_names == {}) + self.assertTrue(state.monitor_4_offset == 0.0) + + +class StateMoveWorkspaceLARMORTest(unittest.TestCase): + def test_that_is_sans_state_data_object(self): + state = StateMoveLARMOR() + self.assertTrue(isinstance(state, StateMove)) + + def test_that_can_set_and_get_values(self): + state = StateMoveLARMOR() + self.assertTrue(state.bench_rotation == 0.0) + self.assertTrue(state.monitor_names == {}) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateMoveBuilderTest(unittest.TestCase): + def test_that_state_for_loq_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_builder.set_sample_scatter_period(3) + data_info = data_builder.build() + + # Act + builder = get_move_builder(data_info) + self.assertTrue(builder) + value = 324.2 + builder.set_center_position(value) + builder.set_HAB_x_translation_correction(value) + builder.set_LAB_sample_centre_pos1(value) + + # Assert + state = builder.build() + self.assertTrue(state.center_position == value) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_translation_correction == value) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short == "HAB") + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name == "main-detector-bank") + self.assertTrue(state.monitor_names[str(2)] == "monitor2") + self.assertTrue(len(state.monitor_names) == 2) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.LAB)].sample_centre_pos1 == value) + + def test_that_state_for_sans2d_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) 
+ data_builder.set_sample_scatter("SANS2D00022048") + data_info = data_builder.build() + + # Act + builder = get_move_builder(data_info) + self.assertTrue(builder) + value = 324.2 + builder.set_HAB_x_translation_correction(value) + + # Assert + state = builder.build() + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_translation_correction == value) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short == "front") + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name == "rear-detector") + self.assertTrue(state.monitor_names[str(7)] == "monitor7") + self.assertTrue(len(state.monitor_names) == 8) + + def test_that_state_for_larmor_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LARMOR00002260") + data_info = data_builder.build() + + # Act + builder = get_move_builder(data_info) + self.assertTrue(builder) + value = 324.2 + builder.set_HAB_x_translation_correction(value) + + # Assert + state = builder.build() + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_translation_correction == value) + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short == "front") + self.assertTrue(state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name == "DetectorBench") + self.assertTrue(state.monitor_names[str(5)] == "monitor5") + self.assertTrue(len(state.monitor_names) == 10) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/normalize_to_monitor_test.py b/scripts/test/SANS/state/normalize_to_monitor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c902bf8f57e979c68844f748824b1b29cadb32 --- /dev/null +++ b/scripts/test/SANS/state/normalize_to_monitor_test.py @@ -0,0 +1,128 @@ +import unittest +import mantid + +from sans.state.normalize_to_monitor import (StateNormalizeToMonitor, StateNormalizeToMonitorLOQ, + get_normalize_to_monitor_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (RebinType, RangeStepType, SANSFacility) +from state_test_helper import assert_validate_error, assert_raises_nothing + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateNormalizeToMonitorTest(unittest.TestCase): + @staticmethod + def _get_normalize_to_monitor_state(**kwargs): + state = StateNormalizeToMonitor() + default_entries = {"prompt_peak_correction_min": 12., "prompt_peak_correction_max": 17., + "rebin_type": RebinType.Rebin, "wavelength_low": 1.5, "wavelength_high": 2.7, + "wavelength_step": 0.5, "incident_monitor": 1, "wavelength_step_type": RangeStepType.Lin, + "background_TOF_general_start": 1.4, "background_TOF_general_stop": 24.5, + "background_TOF_monitor_start": {"1": 123, "2": 123}, + "background_TOF_monitor_stop": {"1": 234, "2": 2323}} + + for key, value in default_entries.items(): + if key in kwargs: + value = kwargs[key] + if value is not None: # If the value is None, then don't set it + setattr(state, key, value) + return state + + def assert_raises_for_bad_value_and_raises_nothing_for_good_value(self, entry_name, bad_value, good_value): + kwargs = {entry_name: bad_value} + state = 
self._get_normalize_to_monitor_state(**kwargs) + assert_validate_error(self, ValueError, state) + setattr(state, entry_name, good_value) + assert_raises_nothing(self, state) + + def test_that_is_sans_state_normalize_to_monitor_object(self): + state = StateNormalizeToMonitorLOQ() + self.assertTrue(isinstance(state, StateNormalizeToMonitor)) + + def test_that_normalize_to_monitor_for_loq_has_default_prompt_peak(self): + state = StateNormalizeToMonitorLOQ() + self.assertTrue(state.prompt_peak_correction_max == 20500.) + self.assertTrue(state.prompt_peak_correction_min == 19000.) + + def test_that_raises_for_partially_set_prompt_peak(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("prompt_peak_correction_min", None, 1.) + + def test_that_raises_for_inconsistent_prompt_peak(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("prompt_peak_correction_max", 1., 30.) + + def test_that_raises_for_missing_incident_monitor(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("incident_monitor", None, 1) + + def test_that_raises_for_partially_set_general_background_tof(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("background_TOF_general_start", None, 1.) + + def test_that_raises_for_inconsistent_general_background_tof(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("background_TOF_general_start", 100., 1.) + + def test_that_raises_for_partially_set_monitor_background_tof(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("background_TOF_monitor_start", None, + {"1": 123, "2": 123}) + + def test_that_raises_for_monitor_background_tof_with_different_lengths(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("background_TOF_monitor_start", {"1": 123}, + {"1": 123, "2": 123}) + + def test_that_raises_for_monitor_background_tof_with_differing_spectrum_numbers(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("background_TOF_monitor_start", + {"1": 123, "5": 123}, + {"1": 123, "2": 123}) + + def test_that_raises_for_monitor_background_tof_with_inconsistent_bounds(self): + self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("background_TOF_monitor_start", + {"1": 123, "2": 191123}, + {"1": 123, "2": 123}) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateReductionBuilderTest(unittest.TestCase): + def test_that_reduction_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_normalize_to_monitor_builder(data_info) + self.assertTrue(builder) + + builder.set_prompt_peak_correction_min(12.0) + builder.set_prompt_peak_correction_max(17.0) + builder.set_rebin_type(RebinType.Rebin) + builder.set_wavelength_low(1.5) + builder.set_wavelength_high(2.7) + builder.set_wavelength_step(0.5) + builder.set_wavelength_step_type(RangeStepType.Lin) + builder.set_incident_monitor(1) + builder.set_background_TOF_general_start(1.4) + builder.set_background_TOF_general_stop(34.4) + builder.set_background_TOF_monitor_start({"1": 123, "2": 123}) + builder.set_background_TOF_monitor_stop({"1": 234, "2": 2323}) + + state = 
builder.build() + + # Assert + self.assertTrue(state.prompt_peak_correction_min == 12.0) + self.assertTrue(state.prompt_peak_correction_max == 17.0) + self.assertTrue(state.rebin_type is RebinType.Rebin) + self.assertTrue(state.wavelength_low == 1.5) + self.assertTrue(state.wavelength_high == 2.7) + self.assertTrue(state.wavelength_step == 0.5) + self.assertTrue(state.wavelength_step_type is RangeStepType.Lin) + self.assertTrue(state.background_TOF_general_start == 1.4) + self.assertTrue(state.background_TOF_general_stop == 34.4) + self.assertTrue(len(set(state.background_TOF_monitor_start.items()) & set({"1": 123, "2": 123}.items())) == 2) + self.assertTrue(len(set(state.background_TOF_monitor_stop.items()) & set({"1": 234, "2": 2323}.items())) == 2) + self.assertTrue(state.incident_monitor == 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/reduction_mode_test.py b/scripts/test/SANS/state/reduction_mode_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f87b9644cfa41625282a266297811e34a98b1b --- /dev/null +++ b/scripts/test/SANS/state/reduction_mode_test.py @@ -0,0 +1,85 @@ +import unittest +import mantid + +from sans.state.reduction_mode import (StateReductionMode, get_reduction_mode_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (ISISReductionMode, ReductionDimensionality, FitModeForMerge, + SANSFacility, SANSInstrument, DetectorType) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateReductionModeTest(unittest.TestCase): + def test_that_converter_methods_work(self): + # Arrange + state = StateReductionMode() + + state.reduction_mode = ISISReductionMode.Merged + state.dimensionality = ReductionDimensionality.TwoDim + state.merge_shift = 12.65 + state.merge_scale = 34.6 + state.merge_fit_mode = FitModeForMerge.ShiftOnly + + state.detector_names[DetectorType.to_string(DetectorType.LAB)] = "Test1" + state.detector_names[DetectorType.to_string(DetectorType.HAB)] = "Test2" + + # Assert + merge_strategy = state.get_merge_strategy() + self.assertTrue(merge_strategy[0] is ISISReductionMode.LAB) + self.assertTrue(merge_strategy[1] is ISISReductionMode.HAB) + + all_reductions = state.get_all_reduction_modes() + self.assertTrue(len(all_reductions) == 2) + self.assertTrue(all_reductions[0] is ISISReductionMode.LAB) + self.assertTrue(all_reductions[1] is ISISReductionMode.HAB) + + result_lab = state.get_detector_name_for_reduction_mode(ISISReductionMode.LAB) + self.assertTrue(result_lab == "Test1") + result_hab = state.get_detector_name_for_reduction_mode(ISISReductionMode.HAB) + self.assertTrue(result_hab == "Test2") + + self.assertRaises(RuntimeError, state.get_detector_name_for_reduction_mode, "non_sense") + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateReductionModeBuilderTest(unittest.TestCase): + def test_that_reduction_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = 
get_reduction_mode_builder(data_info) + self.assertTrue(builder) + + mode = ISISReductionMode.Merged + dim = ReductionDimensionality.OneDim + builder.set_reduction_mode(mode) + builder.set_reduction_dimensionality(dim) + + merge_shift = 324.2 + merge_scale = 3420.98 + fit_mode = FitModeForMerge.Both + builder.set_merge_fit_mode(fit_mode) + builder.set_merge_shift(merge_shift) + builder.set_merge_scale(merge_scale) + + state = builder.build() + + # Assert + self.assertTrue(state.reduction_mode is mode) + self.assertTrue(state.reduction_dimensionality is dim) + self.assertTrue(state.merge_fit_mode == fit_mode) + self.assertTrue(state.merge_shift == merge_shift) + self.assertTrue(state.merge_scale == merge_scale) + detector_names = state.detector_names + self.assertTrue(detector_names[DetectorType.to_string(DetectorType.LAB)] == "main-detector-bank") + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/save_test.py b/scripts/test/SANS/state/save_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5cac8305027d05afcf15daae50b4d08b7e68001d --- /dev/null +++ b/scripts/test/SANS/state/save_test.py @@ -0,0 +1,46 @@ +import unittest +import mantid + +from sans.state.save import (get_save_builder) +from sans.state.data import (get_data_builder) +from sans.common.enums import (SANSFacility, SaveType) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# No tests required +# ---------------------------------------------------------------------------------------------------------------------- + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateReductionBuilderTest(unittest.TestCase): + def test_that_reduction_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_save_builder(data_info) + self.assertTrue(builder) + + file_name = "test_file_name" + zero_free_correction = True + file_format = [SaveType.Nexus, SaveType.CanSAS] + + builder.set_file_name(file_name) + builder.set_zero_free_correction(zero_free_correction) + builder.set_file_format(file_format) + state = builder.build() + + # Assert + self.assertTrue(state.file_name == file_name) + self.assertTrue(state.zero_free_correction == zero_free_correction) + self.assertTrue(state.file_format == file_format) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/scale_test.py b/scripts/test/SANS/state/scale_test.py new file mode 100644 index 0000000000000000000000000000000000000000..144ece7f2dc8192676b54359fa4247183ffa894c --- /dev/null +++ b/scripts/test/SANS/state/scale_test.py @@ -0,0 +1,45 @@ +import unittest +import mantid + +from sans.state.scale import get_scale_builder +from sans.state.data import get_data_builder +from sans.common.enums import (SANSFacility, SANSInstrument, SampleShape) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# No tests required for the current states +# ---------------------------------------------------------------------------------------------------------------------- + +# 
---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateSliceEventBuilderTest(unittest.TestCase): + def test_that_slice_event_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_scale_builder(data_info) + self.assertTrue(builder) + + builder.set_scale(1.0) + builder.set_shape(SampleShape.Cuboid) + builder.set_thickness(3.6) + builder.set_width(3.7) + builder.set_height(5.8) + + # Assert + state = builder.build() + self.assertTrue(state.shape is SampleShape.Cuboid) + self.assertTrue(state.scale == 1.0) + self.assertTrue(state.thickness == 3.6) + self.assertTrue(state.width == 3.7) + self.assertTrue(state.height == 5.8) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/slice_event_test.py b/scripts/test/SANS/state/slice_event_test.py new file mode 100644 index 0000000000000000000000000000000000000000..15f24eaf5fb328f0157853940551e571eb5ecd2a --- /dev/null +++ b/scripts/test/SANS/state/slice_event_test.py @@ -0,0 +1,80 @@ +import unittest +import mantid + +from sans.state.slice_event import (StateSliceEvent, get_slice_event_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (SANSFacility, SANSInstrument) +from state_test_helper import (assert_validate_error) + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateSliceEventTest(unittest.TestCase): + def test_that_raises_when_only_one_time_is_set(self): + state = StateSliceEvent() + state.start_time = [1.0, 2.0] + assert_validate_error(self, ValueError, state) + state.end_time = [2.0, 3.0] + + def test_validate_method_raises_value_error_for_mismatching_start_and_end_time_length(self): + # Arrange + state = StateSliceEvent() + state.start_time = [1.0, 2.0] + state.end_time = [5.0] + + # Act + Assert + self.assertRaises(ValueError, state.validate) + + def test_validate_method_raises_value_error_for_non_increasing_time(self): + # Arrange + state = StateSliceEvent() + state.start_time = [1.0, 2.0, 1.5] + state.end_time = [1.1, 2.1, 2.5] + + # Act + Assert + self.assertRaises(ValueError, state.validate) + + def test_validate_method_raises_value_error_for_end_time_smaller_than_start_time(self): + # Arrange + state = StateSliceEvent() + state.start_time = [1.0, 2.0, 4.6] + state.end_time = [1.1, 2.1, 2.5] + + # Act + Assert + self.assertRaises(ValueError, state.validate) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateSliceEventBuilderTest(unittest.TestCase): + def test_that_slice_event_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_slice_event_builder(data_info) + self.assertTrue(builder) + + start_time = [0.1, 
1.3] + end_time = [0.2, 1.6] + builder.set_start_time(start_time) + builder.set_end_time(end_time) + + # Assert + state = builder.build() + self.assertTrue(len(state.start_time) == 2) + self.assertTrue(state.start_time[0] == start_time[0]) + self.assertTrue(state.start_time[1] == start_time[1]) + + self.assertTrue(len(state.end_time) == 2) + self.assertTrue(state.end_time[0] == end_time[0]) + self.assertTrue(state.end_time[1] == end_time[1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/state_base_test.py b/scripts/test/SANS/state/state_base_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7163d81d4790f91ce810a15ea6bfa710481e9e --- /dev/null +++ b/scripts/test/SANS/state/state_base_test.py @@ -0,0 +1,288 @@ +import unittest +import mantid + +from mantid.kernel import (PropertyManagerProperty, PropertyManager) +from mantid.api import Algorithm + +from sans.state.state_base import (StringParameter, BoolParameter, FloatParameter, PositiveFloatParameter, + PositiveIntegerParameter, DictParameter, ClassTypeParameter, + FloatWithNoneParameter, PositiveFloatWithNoneParameter, FloatListParameter, + StringListParameter, PositiveIntegerListParameter, ClassTypeListParameter, + StateBase, rename_descriptor_names, TypedParameter, validator_sub_state, + create_deserialized_sans_state_from_property_manager) +from sans.common.enums import serializable_enum + + +@serializable_enum("TypeA", "TypeB") +class TestType(object): + pass + + +# ---------------------------------------------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------------------------------------------- +# Test the typed parameters +# ---------------------------------------------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------------------------------------------- +@rename_descriptor_names +class StateBaseTestClass(StateBase): + string_parameter = StringParameter() + bool_parameter = BoolParameter() + float_parameter = FloatParameter() + positive_float_parameter = PositiveFloatParameter() + positive_integer_parameter = PositiveIntegerParameter() + dict_parameter = DictParameter() + float_with_none_parameter = FloatWithNoneParameter() + positive_float_with_none_parameter = PositiveFloatWithNoneParameter() + float_list_parameter = FloatListParameter() + string_list_parameter = StringListParameter() + positive_integer_list_parameter = PositiveIntegerListParameter() + class_type_parameter = ClassTypeParameter(TestType) + class_type_list_parameter = ClassTypeListParameter(TestType) + + def __init__(self): + super(StateBaseTestClass, self).__init__() + + def validate(self): + pass + + +class TypedParameterTest(unittest.TestCase): + def _check_that_raises(self, error_type, obj, descriptor_name, value): + try: + setattr(obj, descriptor_name, value) + self.fail() + except error_type: + pass + except: # noqa + self.fail() + + def test_that_can_set_to_valid_value_of_correct_type(self): + test_class = StateBaseTestClass() + try: + test_class.string_parameter = "Test" + test_class.bool_parameter = True + test_class.float_parameter = -23.5768 + test_class.positive_float_parameter = 234.5 + test_class.positive_integer_parameter = 12 + test_class.dict_parameter = {} + test_class.dict_parameter = {"test": 12, "test2": 13} + test_class.float_with_none_parameter = None + 
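            # Both the None assignment above and the float assignment below are expected to pass:
            # the *WithNoneParameter descriptors presumably accept None in addition to their base
            # numeric type, which is exactly what this block exercises.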
test_class.float_with_none_parameter = -123.67 + test_class.positive_float_with_none_parameter = None + test_class.positive_float_with_none_parameter = 123.67 + test_class.float_list_parameter = [12., -123., 2355.] + test_class.string_list_parameter = ["test", "test"] + test_class.positive_integer_list_parameter = [1, 2, 4] + test_class.class_type_parameter = TestType.TypeA + test_class.class_type_list_parameter = [TestType.TypeA, TestType.TypeB] + + except ValueError: + self.fail() + + def test_that_will_raise_type_error_if_set_with_wrong_type(self): + test_class = StateBaseTestClass() + self._check_that_raises(TypeError, test_class, "string_parameter", 1.) + self._check_that_raises(TypeError, test_class, "bool_parameter", 1.) + self._check_that_raises(TypeError, test_class, "float_parameter", "test") + self._check_that_raises(TypeError, test_class, "positive_float_parameter", "test") + self._check_that_raises(TypeError, test_class, "positive_integer_parameter", "test") + self._check_that_raises(TypeError, test_class, "dict_parameter", "test") + self._check_that_raises(TypeError, test_class, "float_with_none_parameter", "test") + self._check_that_raises(TypeError, test_class, "positive_float_with_none_parameter", "test") + self._check_that_raises(TypeError, test_class, "float_list_parameter", [1.23, "test"]) + self._check_that_raises(TypeError, test_class, "string_list_parameter", ["test", "test", 123.]) + self._check_that_raises(TypeError, test_class, "positive_integer_list_parameter", [1, "test"]) + self._check_that_raises(TypeError, test_class, "class_type_parameter", "test") + self._check_that_raises(TypeError, test_class, "class_type_list_parameter", ["test", TestType.TypeA]) + + def test_that_will_raise_if_set_with_wrong_value(self): + # Note that this check does not apply to all parameter, it checks the validator + test_class = StateBaseTestClass() + self._check_that_raises(ValueError, test_class, "positive_float_parameter", -1.2) + self._check_that_raises(ValueError, test_class, "positive_integer_parameter", -1) + self._check_that_raises(ValueError, test_class, "positive_float_with_none_parameter", -234.) 
+ self._check_that_raises(ValueError, test_class, "positive_integer_list_parameter", [1, -2, 4]) + + +# ---------------------------------------------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------------------------------------------- +# Test the sans_parameters decorator +# ---------------------------------------------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------------------------------------------- + +class SANSParameterTest(unittest.TestCase): + @rename_descriptor_names + class SANSParameterTestClass(object): + my_string_parameter = StringParameter() + my_bool_parameter = BoolParameter() + + class SANSParameterTestClass2(object): + my_string_parameter = StringParameter() + my_bool_parameter = BoolParameter() + + def test_that_name_is_in_readable_format_in_instance_dictionary(self): + test_class = SANSParameterTest.SANSParameterTestClass() + test_class.my_string_parameter = "test" + test_class.my_bool_parameter = True + keys = test_class.__dict__.keys() + # We don't have a sensible name in the instance dictionary + self.assertTrue("_BoolParameter#my_bool_parameter" in keys) + self.assertTrue("_StringParameter#my_string_parameter" in keys) + + def test_that_name_cannot_be_found_in_instance_dictionary_when_sans_parameters_decorator_is_not_applied(self): + test_class = SANSParameterTest.SANSParameterTestClass2() + test_class.my_string_parameter = "test" + test_class.my_bool_parameter = True + keys = test_class.__dict__.keys() + # We don't have a sensible name in the instance dictionary. + # It will be rather stored as something like: _BoolParameter#2 etc. 
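        # A rough sketch of what @rename_descriptor_names presumably does (an assumption for
        # illustration, not the actual Mantid implementation): it walks the class dictionary and
        # copies each descriptor's attribute name into the descriptor, e.g.
        #
        #     for attribute_name, descriptor in cls.__dict__.items():
        #         if isinstance(descriptor, TypedParameter):
        #             descriptor.name = "_{0}#{1}".format(type(descriptor).__name__, attribute_name)
        #
        # which would produce the readable "_BoolParameter#my_bool_parameter" keys checked in the
        # previous test, while this undecorated class keeps anonymous counter-based keys.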
+ self.assertTrue("_BoolParameter#my_bool_parameter" not in keys) + self.assertTrue("_StringParameter#my_string_parameter" not in keys) + + +# ---------------------------------------------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------------------------------------------- +# StateBase +# This will mainly test serialization +# ---------------------------------------------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------------------------------------------- + +@rename_descriptor_names +class VerySimpleState(StateBase): + string_parameter = StringParameter() + + def __init__(self): + super(VerySimpleState, self).__init__() + self.string_parameter = "test_in_very_simple" + + def validate(self): + pass + + +@rename_descriptor_names +class SimpleState(StateBase): + string_parameter = StringParameter() + bool_parameter = BoolParameter() + float_parameter = FloatParameter() + positive_float_parameter = PositiveFloatParameter() + positive_integer_parameter = PositiveIntegerParameter() + dict_parameter = DictParameter() + float_with_none_parameter = FloatWithNoneParameter() + positive_float_with_none_parameter = PositiveFloatWithNoneParameter() + float_list_parameter = FloatListParameter() + string_list_parameter = StringListParameter() + positive_integer_list_parameter = PositiveIntegerListParameter() + class_type_parameter = ClassTypeParameter(TestType) + class_type_list_parameter = ClassTypeListParameter(TestType) + + sub_state_very_simple = TypedParameter(VerySimpleState, validator_sub_state) + + def __init__(self): + super(SimpleState, self).__init__() + self.string_parameter = "String_in_SimpleState" + self.bool_parameter = False + # We explicitly leave out the float_parameter + self.positive_float_parameter = 1. + self.positive_integer_parameter = 6 + self.dict_parameter = {"1": 123, "2": "test"} + self.float_with_none_parameter = 325. + # We expliclty leave out the positive_float_with_none_parameter + self.float_list_parameter = [123., 234.] + self.string_list_parameter = ["test1", "test2"] + self.positive_integer_list_parameter = [1, 2, 3] + self.class_type_parameter = TestType.TypeA + self.class_type_list_parameter = [TestType.TypeA, TestType.TypeB] + self.sub_state_very_simple = VerySimpleState() + + def validate(self): + pass + + +@rename_descriptor_names +class ComplexState(StateBase): + float_parameter = FloatParameter() + positive_float_with_none_parameter = PositiveFloatWithNoneParameter() + sub_state_1 = TypedParameter(SimpleState, validator_sub_state) + dict_parameter = DictParameter() + + def __init__(self): + super(ComplexState, self).__init__() + self.float_parameter = 23. + self.positive_float_with_none_parameter = 234. + self.sub_state_1 = SimpleState() + self.dict_parameter = {"A": SimpleState(), "B": SimpleState()} + + def validate(self): + pass + + +class TestStateBase(unittest.TestCase): + def _assert_simple_state(self, state): + self.assertTrue(state.string_parameter == "String_in_SimpleState") + self.assertFalse(state.bool_parameter) + self.assertTrue(state.float_parameter is None) # We did not set it on the instance + self.assertTrue(state.positive_float_parameter == 1.) 
+ self.assertTrue(state.positive_integer_parameter == 6) + self.assertTrue(state.dict_parameter["1"] == 123) + self.assertTrue(state.dict_parameter["2"] == "test") + self.assertTrue(state.float_with_none_parameter == 325.) + self.assertTrue(state.positive_float_with_none_parameter is None) + + self.assertTrue(len(state.float_list_parameter) == 2) + self.assertTrue(state.float_list_parameter[0] == 123.) + self.assertTrue(state.float_list_parameter[1] == 234.) + + self.assertTrue(len(state.string_list_parameter) == 2) + self.assertTrue(state.string_list_parameter[0] == "test1") + self.assertTrue(state.string_list_parameter[1] == "test2") + + self.assertTrue(len(state.positive_integer_list_parameter) == 3) + self.assertTrue(state.positive_integer_list_parameter[0] == 1) + self.assertTrue(state.positive_integer_list_parameter[1] == 2) + self.assertTrue(state.positive_integer_list_parameter[2] == 3) + + self.assertTrue(state.class_type_parameter is TestType.TypeA) + self.assertTrue(len(state.class_type_list_parameter) == 2) + self.assertTrue(state.class_type_list_parameter[0] == TestType.TypeA) + self.assertTrue(state.class_type_list_parameter[1] == TestType.TypeB) + + self.assertTrue(state.sub_state_very_simple.string_parameter == "test_in_very_simple") + + def test_that_sans_state_can_be_serialized_and_deserialized_when_going_through_an_algorithm(self): + class FakeAlgorithm(Algorithm): + def PyInit(self): + self.declareProperty(PropertyManagerProperty("Args")) + + def PyExec(self): + pass + + # Arrange + state = ComplexState() + + # Act + serialized = state.property_manager + fake = FakeAlgorithm() + fake.initialize() + fake.setProperty("Args", serialized) + property_manager = fake.getProperty("Args").value + + # Assert + self.assertTrue(type(serialized) == dict) + self.assertTrue(type(property_manager) == PropertyManager) + state_2 = create_deserialized_sans_state_from_property_manager(property_manager) + state_2.property_manager = property_manager + + # The direct sub state + self._assert_simple_state(state_2.sub_state_1) + + # The two states in the dictionary + self._assert_simple_state(state_2.dict_parameter["A"]) + self._assert_simple_state(state_2.dict_parameter["B"]) + + # The regular parameters + self.assertTrue(state_2.float_parameter == 23.) + self.assertTrue(state_2.positive_float_with_none_parameter == 234.) 
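        # Summary of the round trip exercised above: State.property_manager serializes the whole
        # state tree (nested sub-states, dict-valued parameters and class-type parameters included)
        # into a plain dict, PropertyManagerProperty converts that dict into a PropertyManager as it
        # passes through the algorithm property, and
        # create_deserialized_sans_state_from_property_manager rebuilds an equivalent State on the
        # other side, which the assertions above verify field by field.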
+ + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/state_functions_test.py b/scripts/test/SANS/state/state_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5547e721a49491a98b537b7d2eabfa218b26b224 --- /dev/null +++ b/scripts/test/SANS/state/state_functions_test.py @@ -0,0 +1,157 @@ +import unittest +import mantid + +from mantid.api import AnalysisDataService +from sans.state.state_functions import (get_output_workspace_name, is_pure_none_or_not_none, one_is_none, + validation_message, is_not_none_and_first_larger_than_second, + write_hash_into_reduced_can_workspace, get_reduced_can_workspace_from_ads) +from test_director import TestDirector +from sans.state.data import StateData +from sans.common.enums import (ReductionDimensionality, ISISReductionMode, OutputParts) +from sans.common.general_functions import create_unmanaged_algorithm + + +class StateFunctionsTest(unittest.TestCase): + @staticmethod + def _get_state(): + test_director = TestDirector() + state = test_director.construct() + + state.data.sample_scatter_run_number = 12345 + state.data.sample_scatter_period = StateData.ALL_PERIODS + + state.reduction.dimensionality = ReductionDimensionality.OneDim + + state.wavelength.wavelength_low = 12.0 + state.wavelength.wavelength_high = 34.0 + + state.mask.phi_min = 12.0 + state.mask.phi_max = 56.0 + + state.slice.start_time = [4.56778] + state.slice.end_time = [12.373938] + return state + + @staticmethod + def _prepare_workspaces(number_of_workspaces, tagged_workspace_names=None, state=None): + create_name = "CreateSampleWorkspace" + create_options = {"OutputWorkspace": "test", + "NumBanks": 1, + "BankPixelWidth": 2, + "XMin": 1, + "XMax": 10, + "BinWidth": 2} + create_alg = create_unmanaged_algorithm(create_name, **create_options) + + for index in range(number_of_workspaces): + create_alg.execute() + workspace = create_alg.getProperty("OutputWorkspace").value + workspace_name = "test" + "_" + str(index) + AnalysisDataService.addOrReplace(workspace_name, workspace) + + if tagged_workspace_names is not None: + for key, value in tagged_workspace_names.items(): + create_alg.execute() + workspace = create_alg.getProperty("OutputWorkspace").value + AnalysisDataService.addOrReplace(value, workspace) + write_hash_into_reduced_can_workspace(state, workspace, key) + + @staticmethod + def _remove_workspaces(): + for element in AnalysisDataService.getObjectNames(): + AnalysisDataService.remove(element) + + def test_that_unknown_reduction_mode_raises(self): + # Arrange + state = StateFunctionsTest._get_state() + + # Act + Assert + try: + get_output_workspace_name(state, ISISReductionMode.All) + did_raise = False + except RuntimeError: + did_raise = True + self.assertTrue(did_raise) + + def test_that_creates_correct_workspace_name_for_1D(self): + # Arrange + state = StateFunctionsTest._get_state() + # Act + output_workspace = get_output_workspace_name(state, ISISReductionMode.LAB) + # Assert + self.assertTrue("12345rear_1D12.0_34.0Phi12.0_56.0_t4.57_T12.37" == output_workspace) + + def test_that_detects_if_all_entries_are_none_or_not_none_as_true(self): + self.assertFalse(is_pure_none_or_not_none(["test", None, "test"])) + self.assertTrue(is_pure_none_or_not_none([None, None, None])) + self.assertTrue(is_pure_none_or_not_none(["test", "test", "test"])) + self.assertTrue(is_pure_none_or_not_none([])) + + def test_that_detects_if_one_is_none(self): + self.assertTrue(one_is_none(["test", None, "test"])) + 
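        # An empty list contains no None entries, so one_is_none is expected to return False here.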
self.assertFalse(one_is_none([])) + self.assertFalse(one_is_none(["test", "test", "test"])) + + def test_test_that_can_detect_when_first_is_larger_than_second(self): + self.assertTrue(is_not_none_and_first_larger_than_second([1, 2, 3])) + self.assertTrue(is_not_none_and_first_larger_than_second([2, 1])) + self.assertFalse(is_not_none_and_first_larger_than_second([1, 2])) + + def test_that_produces_correct_validation_message(self): + # Arrange + error_message = "test message." + instruction = "do this." + variables = {"var1": 12, + "var2": "test"} + # Act + val_message = validation_message(error_message, instruction, variables) + # Assert + expected_text = "var1: 12\n" \ + "var2: test\n" \ + "" + instruction + self.assertTrue(val_message.keys()[0] == error_message) + self.assertTrue(val_message[error_message] == expected_text) + + def test_that_can_find_can_reduction_if_it_exists(self): + # Arrange + test_director = TestDirector() + state = test_director.construct() + tagged_workspace_names = {None: "test_ws", + OutputParts.Count: "test_ws_count", + OutputParts.Norm: "test_ws_norm"} + StateFunctionsTest._prepare_workspaces(number_of_workspaces=4, + tagged_workspace_names=tagged_workspace_names, + state=state) + # Act + workspace, workspace_count, workspace_norm = get_reduced_can_workspace_from_ads(state, output_parts=True) + + # Assert + self.assertTrue(workspace is not None) + self.assertTrue(workspace.name() == AnalysisDataService.retrieve("test_ws").name()) + self.assertTrue(workspace_count is not None) + self.assertTrue(workspace_count.name() == AnalysisDataService.retrieve("test_ws_count").name()) + self.assertTrue(workspace_norm is not None) + self.assertTrue(workspace_norm.name() == AnalysisDataService.retrieve("test_ws_norm").name()) + + # Clean up + StateFunctionsTest._remove_workspaces() + + def test_that_returns_none_if_it_does_not_exist(self): + # Arrange + test_director = TestDirector() + state = test_director.construct() + StateFunctionsTest._prepare_workspaces(number_of_workspaces=4, tagged_workspace_names=None, state=state) + + # Act + workspace, workspace_count, workspace_norm = get_reduced_can_workspace_from_ads(state, output_parts=False) + + # Assert + self.assertTrue(workspace is None) + self.assertTrue(workspace_count is None) + self.assertTrue(workspace_norm is None) + + # Clean up + StateFunctionsTest._remove_workspaces() + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/state_test.py b/scripts/test/SANS/state/state_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7dea7471e46784143f7511df974db694f65581cb --- /dev/null +++ b/scripts/test/SANS/state/state_test.py @@ -0,0 +1,153 @@ +import unittest +import mantid + +from sans.state.state import (State) +from sans.state.data import (StateData) +from sans.state.move import (StateMove) +from sans.state.reduction_mode import (StateReductionMode) +from sans.state.slice_event import (StateSliceEvent) +from sans.state.mask import (StateMask) +from sans.state.wavelength import (StateWavelength) +from sans.state.save import (StateSave) +from sans.state.normalize_to_monitor import (StateNormalizeToMonitor) +from sans.state.scale import (StateScale) +from sans.state.calculate_transmission import (StateCalculateTransmission) +from sans.state.wavelength_and_pixel_adjustment import (StateWavelengthAndPixelAdjustment) +from sans.state.adjustment import (StateAdjustment) +from sans.state.convert_to_q import (StateConvertToQ) + +from state_test_helper import 
assert_validate_error, assert_raises_nothing + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class MockStateData(StateData): + def validate(self): + pass + + +class MockStateMove(StateMove): + def validate(self): + pass + + +class MockStateReduction(StateReductionMode): + def get_merge_strategy(self): + pass + + def get_detector_name_for_reduction_mode(self, reduction_mode): + pass + + def get_all_reduction_modes(self): + pass + + def validate(self): + pass + + +class MockStateSliceEvent(StateSliceEvent): + def validate(self): + pass + + +class MockStateMask(StateMask): + def validate(self): + pass + + +class MockStateWavelength(StateWavelength): + def validate(self): + pass + + +class MockStateSave(StateSave): + def validate(self): + pass + + +class MockStateNormalizeToMonitor(StateNormalizeToMonitor): + def validate(self): + pass + + +class MockStateScale(StateScale): + def validate(self): + pass + + +class MockStateCalculateTransmission(StateCalculateTransmission): + def validate(self): + pass + + +class MockStateWavelengthAndPixelAdjustment(StateWavelengthAndPixelAdjustment): + def validate(self): + pass + + +class MockStateAdjustment(StateAdjustment): + def validate(self): + pass + + +class MockStateConvertToQ(StateConvertToQ): + def validate(self): + pass + + +class StateTest(unittest.TestCase): + @staticmethod + def _get_state(entries): + state = State() + default_entries = {"data": MockStateData(), "move": MockStateMove(), "reduction": MockStateReduction(), + "slice": MockStateSliceEvent(), "mask": MockStateMask(), "wavelength": MockStateWavelength(), + "save": MockStateSave(), "scale": MockStateScale(), "adjustment": MockStateAdjustment(), + "convert_to_q": MockStateConvertToQ()} + + for key, value in default_entries.items(): + if key in entries: + value = entries[key] + if value is not None: # If the value is None, then don't set it + setattr(state, key, value) + return state + + def check_bad_and_good_values(self, bad_state, good_state): + # Bad values + state = self._get_state(bad_state) + assert_validate_error(self, ValueError, state) + + # Good values + state = self._get_state(good_state) + assert_raises_nothing(self, state) + + def test_that_raises_when_move_has_not_been_set(self): + self.check_bad_and_good_values({"move": None}, {"move": MockStateMove()}) + + def test_that_raises_when_reduction_has_not_been_set(self): + self.check_bad_and_good_values({"reduction": None}, {"reduction": MockStateReduction()}) + + def test_that_raises_when_slice_has_not_been_set(self): + self.check_bad_and_good_values({"slice": None}, {"slice": MockStateSliceEvent()}) + + def test_that_raises_when_mask_has_not_been_set(self): + self.check_bad_and_good_values({"mask": None}, {"mask": MockStateMask()}) + + def test_that_raises_when_wavelength_has_not_been_set(self): + self.check_bad_and_good_values({"wavelength": None}, {"wavelength": MockStateWavelength()}) + + def test_that_raises_when_save_has_not_been_set(self): + self.check_bad_and_good_values({"save": None}, {"save": MockStateSave()}) + + def test_that_raises_when_scale_has_not_been_set(self): + self.check_bad_and_good_values({"scale": None}, {"scale": MockStateScale()}) + + def test_that_raises_when_adjustment_has_not_been_set(self): + self.check_bad_and_good_values({"adjustment": None}, {"adjustment": MockStateAdjustment()}) + + 
def test_that_raises_when_convert_to_q_has_not_been_set(self): + self.check_bad_and_good_values({"convert_to_q": None}, {"convert_to_q": MockStateConvertToQ()}) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/state_test_helper.py b/scripts/test/SANS/state/state_test_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..b8752e37bf56d0debb1f9b3f2439f5f450a06bb0 --- /dev/null +++ b/scripts/test/SANS/state/state_test_helper.py @@ -0,0 +1,17 @@ +def assert_validate_error(caller, error_type, obj): + try: + obj.validate() + raised_correct = False + except error_type: + raised_correct = True + except: # noqa + raised_correct = False + caller.assertTrue(raised_correct) + + +def assert_raises_nothing(caller, obj): + obj.validate() + try: + obj.validate() + except: # noqa + caller.fail() diff --git a/scripts/test/SANS/state/test_director.py b/scripts/test/SANS/state/test_director.py new file mode 100644 index 0000000000000000000000000000000000000000..8f17df8b1502a491055d62c256ce8de125931dcc --- /dev/null +++ b/scripts/test/SANS/state/test_director.py @@ -0,0 +1,196 @@ +""" A Test director """ +from sans.state.state import get_state_builder +from sans.state.data import get_data_builder +from sans.state.move import get_move_builder +from sans.state.reduction_mode import get_reduction_mode_builder +from sans.state.slice_event import get_slice_event_builder +from sans.state.mask import get_mask_builder +from sans.state.wavelength import get_wavelength_builder +from sans.state.save import get_save_builder +from sans.state.normalize_to_monitor import get_normalize_to_monitor_builder +from sans.state.scale import get_scale_builder +from sans.state.calculate_transmission import get_calculate_transmission_builder +from sans.state.wavelength_and_pixel_adjustment import get_wavelength_and_pixel_adjustment_builder +from sans.state.adjustment import get_adjustment_builder +from sans.state.convert_to_q import get_convert_to_q_builder + +from sans.common.enums import (SANSFacility, ISISReductionMode, ReductionDimensionality, + FitModeForMerge, RebinType, RangeStepType, SaveType, FitType, SampleShape) + + +class TestDirector(object): + """ The purpose of this builder is to create a valid state object for tests""" + def __init__(self): + super(TestDirector, self).__init__() + self.data_state = None + self.move_state = None + self.reduction_state = None + self.slice_state = None + self.mask_state = None + self.wavelength_state = None + self.save_state = None + self.scale_state = None + self.adjustment_state = None + self.convert_to_q_state = None + + def set_states(self, data_state=None, move_state=None, reduction_state=None, slice_state=None, + mask_state=None, wavelength_state=None, save_state=None, scale_state=None, adjustment_state=None, + convert_to_q_state=None): + self.data_state = data_state + self.data_state = data_state + self.move_state = move_state + self.reduction_state = reduction_state + self.slice_state = slice_state + self.mask_state = mask_state + self.wavelength_state = wavelength_state + self.save_state = save_state + self.scale_state = scale_state + self.adjustment_state = adjustment_state + self.convert_to_q_state = convert_to_q_state + + def construct(self): + facility = SANSFacility.ISIS + + # Build the SANSStateData + if self.data_state is None: + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("SANS2D00022024") + data_builder.set_can_scatter("SANS2D00022024") + self.data_state = data_builder.build() + 
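        # Every sub-state below follows the same pattern as the data state above: fetch the
        # instrument-aware builder via get_*_builder(self.data_state), set a handful of values and
        # call build(). Only sub-states that were not injected through set_states() are constructed
        # here, so individual tests can override any single piece of the final State.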
+ # Build the SANSStateMove + if self.move_state is None: + move_builder = get_move_builder(self.data_state) + move_builder.set_HAB_x_translation_correction(21.2) + move_builder.set_LAB_x_translation_correction(12.1) + self.move_state = move_builder.build() + + # Build the SANSStateReduction + if self.reduction_state is None: + reduction_builder = get_reduction_mode_builder(self.data_state) + reduction_builder.set_reduction_mode(ISISReductionMode.Merged) + reduction_builder.set_reduction_dimensionality(ReductionDimensionality.OneDim) + reduction_builder.set_merge_fit_mode(FitModeForMerge.Both) + reduction_builder.set_merge_shift(324.2) + reduction_builder.set_merge_scale(3420.98) + self.reduction_state = reduction_builder.build() + + # Build the SANSStateSliceEvent + if self.slice_state is None: + slice_builder = get_slice_event_builder(self.data_state) + slice_builder.set_start_time([0.1, 1.3]) + slice_builder.set_end_time([0.2, 1.6]) + self.slice_state = slice_builder.build() + + # Build the SANSStateMask + if self.mask_state is None: + mask_builder = get_mask_builder(self.data_state) + mask_builder.set_radius_min(10.0) + mask_builder.set_radius_max(20.0) + self.mask_state = mask_builder.build() + + # Build the SANSStateWavelength + if self.wavelength_state is None: + wavelength_builder = get_wavelength_builder(self.data_state) + wavelength_builder.set_wavelength_low(1.0) + wavelength_builder.set_wavelength_high(10.0) + wavelength_builder.set_wavelength_step(2.0) + wavelength_builder.set_wavelength_step_type(RangeStepType.Lin) + wavelength_builder.set_rebin_type(RebinType.Rebin) + self.wavelength_state = wavelength_builder.build() + + # Build the SANSStateSave + if self.save_state is None: + save_builder = get_save_builder(self.data_state) + save_builder.set_file_name("test_file_name") + save_builder.set_file_format([SaveType.Nexus]) + self.save_state = save_builder.build() + + # Build the SANSStateScale + if self.scale_state is None: + scale_builder = get_scale_builder(self.data_state) + scale_builder.set_shape(SampleShape.Cuboid) + scale_builder.set_width(1.0) + scale_builder.set_height(2.0) + scale_builder.set_thickness(3.0) + scale_builder.set_scale(4.0) + self.scale_state = scale_builder.build() + + # Build the SANSAdjustmentState + if self.adjustment_state is None: + # NormalizeToMonitor + normalize_to_monitor_builder = get_normalize_to_monitor_builder(self.data_state) + normalize_to_monitor_builder.set_wavelength_low(1.0) + normalize_to_monitor_builder.set_wavelength_high(10.0) + normalize_to_monitor_builder.set_wavelength_step(2.0) + normalize_to_monitor_builder.set_wavelength_step_type(RangeStepType.Lin) + normalize_to_monitor_builder.set_rebin_type(RebinType.Rebin) + normalize_to_monitor_builder.set_background_TOF_general_start(1000.) + normalize_to_monitor_builder.set_background_TOF_general_stop(2000.) 
+ normalize_to_monitor_builder.set_incident_monitor(1) + normalize_to_monitor = normalize_to_monitor_builder.build() + + # CalculateTransmission + calculate_transmission_builder = get_calculate_transmission_builder(self.data_state) + calculate_transmission_builder.set_transmission_monitor(3) + calculate_transmission_builder.set_incident_monitor(2) + calculate_transmission_builder.set_wavelength_low(1.0) + calculate_transmission_builder.set_wavelength_high(10.0) + calculate_transmission_builder.set_wavelength_step(2.0) + calculate_transmission_builder.set_wavelength_step_type(RangeStepType.Lin) + calculate_transmission_builder.set_rebin_type(RebinType.Rebin) + calculate_transmission_builder.set_background_TOF_general_start(1000.) + calculate_transmission_builder.set_background_TOF_general_stop(2000.) + + calculate_transmission_builder.set_Sample_fit_type(FitType.Linear) + calculate_transmission_builder.set_Sample_polynomial_order(0) + calculate_transmission_builder.set_Sample_wavelength_low(1.0) + calculate_transmission_builder.set_Sample_wavelength_high(10.0) + calculate_transmission_builder.set_Can_fit_type(FitType.Polynomial) + calculate_transmission_builder.set_Can_polynomial_order(3) + calculate_transmission_builder.set_Can_wavelength_low(10.0) + calculate_transmission_builder.set_Can_wavelength_high(20.0) + calculate_transmission = calculate_transmission_builder.build() + + # Wavelength and pixel adjustment + wavelength_and_pixel_builder = get_wavelength_and_pixel_adjustment_builder(self.data_state) + wavelength_and_pixel_builder.set_wavelength_low(1.0) + wavelength_and_pixel_builder.set_wavelength_high(10.0) + wavelength_and_pixel_builder.set_wavelength_step(2.0) + wavelength_and_pixel_builder.set_wavelength_step_type(RangeStepType.Lin) + wavelength_and_pixel = wavelength_and_pixel_builder.build() + + # Adjustment + adjustment_builder = get_adjustment_builder(self.data_state) + adjustment_builder.set_normalize_to_monitor(normalize_to_monitor) + adjustment_builder.set_calculate_transmission(calculate_transmission) + adjustment_builder.set_wavelength_and_pixel_adjustment(wavelength_and_pixel) + self.adjustment_state = adjustment_builder.build() + + # SANSStateConvertToQ + if self.convert_to_q_state is None: + convert_to_q_builder = get_convert_to_q_builder(self.data_state) + convert_to_q_builder.set_reduction_dimensionality(ReductionDimensionality.OneDim) + convert_to_q_builder.set_use_gravity(False) + convert_to_q_builder.set_radius_cutoff(0.002) + convert_to_q_builder.set_wavelength_cutoff(12.) 
+ convert_to_q_builder.set_q_min(0.1) + convert_to_q_builder.set_q_max(0.8) + convert_to_q_builder.set_q_step(0.01) + convert_to_q_builder.set_q_step_type(RangeStepType.Lin) + convert_to_q_builder.set_use_q_resolution(False) + self.convert_to_q_state = convert_to_q_builder.build() + + # Set the sub states on the SANSState + state_builder = get_state_builder(self.data_state) + state_builder.set_data(self.data_state) + state_builder.set_move(self.move_state) + state_builder.set_reduction(self.reduction_state) + state_builder.set_slice(self.slice_state) + state_builder.set_mask(self.mask_state) + state_builder.set_wavelength(self.wavelength_state) + state_builder.set_save(self.save_state) + state_builder.set_scale(self.scale_state) + state_builder.set_adjustment(self.adjustment_state) + state_builder.set_convert_to_q(self.convert_to_q_state) + return state_builder.build() diff --git a/scripts/test/SANS/state/wavelength_and_pixel_adjustment_test.py b/scripts/test/SANS/state/wavelength_and_pixel_adjustment_test.py new file mode 100644 index 0000000000000000000000000000000000000000..15b54ab808adb5c0714930cd1c94d7a9364cf10a --- /dev/null +++ b/scripts/test/SANS/state/wavelength_and_pixel_adjustment_test.py @@ -0,0 +1,73 @@ +import unittest +import mantid + +from sans.state.wavelength_and_pixel_adjustment import (StateWavelengthAndPixelAdjustment, + get_wavelength_and_pixel_adjustment_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (RebinType, RangeStepType, DetectorType, SANSFacility, SANSInstrument) +from state_test_helper import assert_validate_error, assert_raises_nothing + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateWavelengthAndPixelAdjustmentTest(unittest.TestCase): + def test_that_raises_when_wavelength_entry_is_missing(self): + # Arrange + state = StateWavelengthAndPixelAdjustment() + assert_validate_error(self, ValueError, state) + state.wavelength_low = 1. + assert_validate_error(self, ValueError, state) + state.wavelength_high = 2. + assert_validate_error(self, ValueError, state) + state.wavelength_step = 2. + assert_validate_error(self, ValueError, state) + state.wavelength_step_type = RangeStepType.Lin + assert_raises_nothing(self, state) + + def test_that_raises_when_lower_wavelength_is_smaller_than_high_wavelength(self): + state = StateWavelengthAndPixelAdjustment() + state.wavelength_low = 2. + state.wavelength_high = 1. + state.wavelength_step = 2. 
+ state.wavelength_step_type = RangeStepType.Lin + assert_validate_error(self, ValueError, state) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- +class StateWavelengthAndPixelAdjustmentBuilderTest(unittest.TestCase): + def test_that_wavelength_and_pixel_adjustment_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_wavelength_and_pixel_adjustment_builder(data_info) + self.assertTrue(builder) + + builder.set_HAB_pixel_adjustment_file("test") + builder.set_HAB_wavelength_adjustment_file("test2") + builder.set_wavelength_low(1.5) + builder.set_wavelength_high(2.7) + builder.set_wavelength_step(0.5) + builder.set_wavelength_step_type(RangeStepType.Lin) + + state = builder.build() + + # Assert + self.assertTrue(state.adjustment_files[DetectorType.to_string( + DetectorType.HAB)].pixel_adjustment_file == "test") + self.assertTrue(state.adjustment_files[DetectorType.to_string( + DetectorType.HAB)].wavelength_adjustment_file == "test2") + self.assertTrue(state.wavelength_low == 1.5) + self.assertTrue(state.wavelength_high == 2.7) + self.assertTrue(state.wavelength_step == 0.5) + self.assertTrue(state.wavelength_step_type is RangeStepType.Lin) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/test/SANS/state/wavelength_test.py b/scripts/test/SANS/state/wavelength_test.py new file mode 100644 index 0000000000000000000000000000000000000000..11dd84c6e3f49f6a4b241deeb8314db5253a8d6c --- /dev/null +++ b/scripts/test/SANS/state/wavelength_test.py @@ -0,0 +1,70 @@ +import unittest +import mantid + +from sans.state.wavelength import (StateWavelength, get_wavelength_builder) +from sans.state.data import get_data_builder +from sans.common.enums import (SANSFacility, SANSInstrument, RebinType, RangeStepType) +from state_test_helper import assert_validate_error, assert_raises_nothing + + +# ---------------------------------------------------------------------------------------------------------------------- +# State +# ---------------------------------------------------------------------------------------------------------------------- +class StateWavelengthTest(unittest.TestCase): + + def test_that_is_sans_state_data_object(self): + state = StateWavelength() + self.assertTrue(isinstance(state, StateWavelength)) + + def test_that_raises_when_wavelength_entry_is_missing(self): + # Arrange + state = StateWavelength() + assert_validate_error(self, ValueError, state) + state.wavelength_low = 1. + assert_validate_error(self, ValueError, state) + state.wavelength_high = 2. + assert_validate_error(self, ValueError, state) + state.wavelength_step = 2. + assert_raises_nothing(self, state) + + def test_that_raises_when_lower_wavelength_is_smaller_than_high_wavelength(self): + state = StateWavelength() + state.wavelength_low = 2. + state.wavelength_high = 1. + state.wavelength_step = 2. 
+ assert_validate_error(self, ValueError, state) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Builder +# ---------------------------------------------------------------------------------------------------------------------- + +class StateSliceEventBuilderTest(unittest.TestCase): + def test_that_slice_event_state_can_be_built(self): + # Arrange + facility = SANSFacility.ISIS + data_builder = get_data_builder(facility) + data_builder.set_sample_scatter("LOQ74044") + data_info = data_builder.build() + + # Act + builder = get_wavelength_builder(data_info) + self.assertTrue(builder) + + builder.set_wavelength_low(10.0) + builder.set_wavelength_high(20.0) + builder.set_wavelength_step(3.0) + builder.set_wavelength_step_type(RangeStepType.Lin) + builder.set_rebin_type(RebinType.Rebin) + + # Assert + state = builder.build() + + self.assertTrue(state.wavelength_low == 10.0) + self.assertTrue(state.wavelength_high == 20.0) + self.assertTrue(state.wavelength_step_type is RangeStepType.Lin) + self.assertTrue(state.rebin_type is RebinType.Rebin) + + +if __name__ == '__main__': + unittest.main()
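Taken together, these unit tests exercise one recurring pattern: every reduction sub-state is produced by an instrument-aware builder that is obtained from a previously built data state and checks itself via validate(). A minimal sketch of that flow, using only calls that appear in the tests above (the run number and wavelength values are illustrative, not prescribed):

    import mantid  # noqa  -- sets up the Mantid Python environment, as at the top of every test above

    from sans.common.enums import (SANSFacility, RangeStepType, RebinType)
    from sans.state.data import get_data_builder
    from sans.state.wavelength import get_wavelength_builder

    # 1. Build the data state first; it carries the instrument and run information ...
    data_builder = get_data_builder(SANSFacility.ISIS)
    data_builder.set_sample_scatter("LOQ74044")
    data_info = data_builder.build()

    # 2. ... which the sub-state builders use to pick instrument-specific defaults.
    wavelength_builder = get_wavelength_builder(data_info)
    wavelength_builder.set_wavelength_low(2.2)
    wavelength_builder.set_wavelength_high(10.0)
    wavelength_builder.set_wavelength_step(0.5)
    wavelength_builder.set_wavelength_step_type(RangeStepType.Lin)
    wavelength_builder.set_rebin_type(RebinType.Rebin)
    wavelength_state = wavelength_builder.build()

    # 3. Built states check themselves: validate() raises ValueError for inconsistent settings
    #    and is expected to pass here because all required wavelength entries are set.
    wavelength_state.validate()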