diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e6c192bb75157e4319215613ff70e2babe5d3d6..4f45d8c16ac91137b442bac6dc12c7c97762791b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -385,7 +385,7 @@ if(ENABLE_CPACK) "nexus >= 4.3.1,gsl,glibc,muParser,numpy,h5py >= 2.3.1,PyCifRW >= 4.2.1,tbb,librdkafka," "${CPACK_RPM_PACKAGE_REQUIRES},OCE-draw,OCE-foundation,OCE-modeling,OCE-ocaf,OCE-visualization," "poco-crypto,poco-data,poco-mysql,poco-sqlite,poco-odbc,poco-util,poco-xml,poco-zip,poco-net,poco-netssl,poco-foundation," - "sip >= 4.18," + "sip >= 4.18,python-enum34," "python-six,python-ipython >= 1.1.0,python-ipython-notebook,PyYAML," "python-requests," "scipy," diff --git a/Framework/API/inc/MantidAPI/IEventList.h b/Framework/API/inc/MantidAPI/IEventList.h index f58454a32197ea77cb6a9da449fc683024a2db45..084b067e443e429944663a7707adfe0c21fa88f8 100644 --- a/Framework/API/inc/MantidAPI/IEventList.h +++ b/Framework/API/inc/MantidAPI/IEventList.h @@ -75,6 +75,8 @@ public: virtual void addTof(const double offset) = 0; /// Add a value to the pulse time values virtual void addPulsetime(const double seconds) = 0; + /// Add a separate value to each of the pulse time values + virtual void addPulsetimes(const std::vector<double> &seconds) = 0; /// Mask a given TOF range virtual void maskTof(const double tofMin, const double tofMax) = 0; /// Mask the events by the condition vector diff --git a/Framework/API/src/Algorithm.cpp b/Framework/API/src/Algorithm.cpp index 5c8b770354daee24705989c8b74c1b9609222e9e..2908f919fbc538417de462b8c1159ef077d5f509 100644 --- a/Framework/API/src/Algorithm.cpp +++ b/Framework/API/src/Algorithm.cpp @@ -1429,7 +1429,21 @@ bool Algorithm::processGroups() { // Either: this is the single group // OR: all inputs are groups // ... 
so get then entry^th workspace in this group - ws = thisGroup[entry]; + if (entry < thisGroup.size()) { + ws = thisGroup[entry]; + } else { + // This can happen when one has more than one input group + // workspaces, having different sizes. For example one workspace + // group is the corrections which has N parts (e.g. weights for + // polarized measurement) while the other one is the actual input + // workspace group, where each item needs to be corrected together + // with all N inputs of the second group. In this case processGroup + // needs to be overridden, which is currently not possible in + // python. + throw std::runtime_error( + "Unable to process over groups; consider passing workspaces " + "one-by-one or override processGroup method of the algorithm."); + } } // Append the names together if (!outputBaseName.empty()) diff --git a/Framework/Algorithms/src/ConvertUnits.cpp b/Framework/Algorithms/src/ConvertUnits.cpp index 019203b910bbd3edd7e7637bf74c1b2eb7ca3691..d969d41817db7ea7a8aa46c6b31df4234899d264 100644 --- a/Framework/Algorithms/src/ConvertUnits.cpp +++ b/Framework/Algorithms/src/ConvertUnits.cpp @@ -49,7 +49,7 @@ void ConvertUnits::init() { // the Target property declareProperty("Target", "", boost::make_shared<StringListValidator>( - UnitFactory::Instance().getKeys()), + UnitFactory::Instance().getConvertibleUnits()), "The name of the units to convert to (must be one of those " "registered in\n" "the Unit Factory)"); diff --git a/Framework/Algorithms/src/ConvertUnitsUsingDetectorTable.cpp b/Framework/Algorithms/src/ConvertUnitsUsingDetectorTable.cpp index 7aae60984dce7c85bc84b3070e6409b672c4982c..0861c9b73c47da9f2964ae34321506b742fc5bea 100644 --- a/Framework/Algorithms/src/ConvertUnitsUsingDetectorTable.cpp +++ b/Framework/Algorithms/src/ConvertUnitsUsingDetectorTable.cpp @@ -70,7 +70,7 @@ void ConvertUnitsUsingDetectorTable::init() { "Name of the output workspace, can be the same as the input"); declareProperty("Target", "", 
boost::make_shared<StringListValidator>( - UnitFactory::Instance().getKeys()), + UnitFactory::Instance().getConvertibleUnits()), "The name of the units to convert to (must be one of those " "registered in\n" "the Unit Factory)"); diff --git a/Framework/Algorithms/src/CreateDetectorTable.cpp b/Framework/Algorithms/src/CreateDetectorTable.cpp index d0321409ab96455f3d456d325812e833dd332a88..c254aab6746ca088582031b502ce5e89af10d122 100644 --- a/Framework/Algorithms/src/CreateDetectorTable.cpp +++ b/Framework/Algorithms/src/CreateDetectorTable.cpp @@ -307,7 +307,7 @@ void populateTable(ITableWorkspace_sptr &t, const MatrixWorkspace_sptr &ws, colValues << phi // rtp << isMonitorDisplay; // monitor - } catch (std::exception) { + } catch (const std::exception &) { // spectrumNo=-1, detID=0 colValues << -1 << "0"; // Y/E diff --git a/Framework/Algorithms/src/GroupWorkspaces.cpp b/Framework/Algorithms/src/GroupWorkspaces.cpp index e05f8a09f6b882a50a917cf5fd7a280b884b8bcc..65e739af082f552b16b9424e1a4b76c06c1a4dfe 100644 --- a/Framework/Algorithms/src/GroupWorkspaces.cpp +++ b/Framework/Algorithms/src/GroupWorkspaces.cpp @@ -82,7 +82,7 @@ std::map<std::string, std::string> GroupWorkspaces::validateInputs() { for (auto it = globExpression.begin(); it < globExpression.end(); ++it) { if (*it == '\\') { - it = globExpression.erase(it, it + 2); + it = globExpression.erase(it, it + 1); } } diff --git a/Framework/Algorithms/src/HRPDSlabCanAbsorption.cpp b/Framework/Algorithms/src/HRPDSlabCanAbsorption.cpp index c8f5314a5a9e8e74dca104968a24555f2349ed3c..40d13f0701fde77e45977d375ed401ce71d52e29 100644 --- a/Framework/Algorithms/src/HRPDSlabCanAbsorption.cpp +++ b/Framework/Algorithms/src/HRPDSlabCanAbsorption.cpp @@ -42,14 +42,9 @@ void HRPDSlabCanAbsorption::init() { declareProperty("SampleNumberDensity", EMPTY_DBL(), mustBePositive, "The number density of the sample in number of atoms per " "cubic angstrom if not set with SetSampleMaterial"); - - std::vector<std::string> 
thicknesses(4); - thicknesses[0] = "0.2"; - thicknesses[1] = "0.5"; - thicknesses[2] = "1.0"; - thicknesses[3] = "1.5"; - declareProperty("Thickness", "0.2", - boost::make_shared<StringListValidator>(thicknesses)); + declareProperty("Thickness", 0.2, mustBePositive, + "The thickness of the sample in cm. Common values are 0.2, " + "0.5, 1.0, 1.5"); auto positiveInt = boost::make_shared<BoundedValidator<int64_t>>(); positiveInt->setLower(1); @@ -185,15 +180,14 @@ API::MatrixWorkspace_sptr HRPDSlabCanAbsorption::runFlatPlateAbsorption() { childAlg->setProperty<int64_t>("NumberOfWavelengthPoints", getProperty("NumberOfWavelengthPoints")); childAlg->setProperty<std::string>("ExpMethod", getProperty("ExpMethod")); + childAlg->setProperty<double>("ElementSize", getProperty("ElementSize")); // The height and width of the sample holder are standard for HRPD const double HRPDCanHeight = 2.3; const double HRPDCanWidth = 1.8; childAlg->setProperty("SampleHeight", HRPDCanHeight); childAlg->setProperty("SampleWidth", HRPDCanWidth); - // Valid values are 0.2,0.5,1.0 & 1.5 - would be nice to have a numeric list - // validator - const std::string thickness = getPropertyValue("Thickness"); - childAlg->setPropertyValue("SampleThickness", thickness); + const double thickness = getProperty("Thickness"); + childAlg->setProperty("SampleThickness", thickness); childAlg->executeAsChildAlg(); return childAlg->getProperty("OutputWorkspace"); } diff --git a/Framework/Algorithms/src/ReflectometryReductionOneAuto2.cpp b/Framework/Algorithms/src/ReflectometryReductionOneAuto2.cpp index 293713d47dd61df33288d66851117119fd8ea682..49bd47c340df7d34200a03b800245c9908d3040a 100644 --- a/Framework/Algorithms/src/ReflectometryReductionOneAuto2.cpp +++ b/Framework/Algorithms/src/ReflectometryReductionOneAuto2.cpp @@ -1016,8 +1016,9 @@ auto ReflectometryReductionOneAuto2::getOutputNamesForGroups( std::string informativeName = "TOF" + runNumber + "_"; WorkspaceNames outputNames; - if 
(equal(informativeName.begin(), informativeName.end(), - inputName.begin())) { + if ((informativeName.length() < inputName.length()) && + (equal(informativeName.begin(), informativeName.end(), + inputName.begin()))) { auto informativeTest = inputName.substr(informativeName.length()); outputNames.iVsQ = output.iVsQ + "_" + informativeTest; outputNames.iVsQBinned = output.iVsQBinned + "_" + informativeTest; diff --git a/Framework/Algorithms/test/CalculatePlaczekSelfScatteringTest.h b/Framework/Algorithms/test/CalculatePlaczekSelfScatteringTest.h index 2f6e471fe60bea813cdceafe8fbdd27424a2b6d8..6107115e1542584046fb884b751f9ce7c97a024f 100644 --- a/Framework/Algorithms/test/CalculatePlaczekSelfScatteringTest.h +++ b/Framework/Algorithms/test/CalculatePlaczekSelfScatteringTest.h @@ -156,7 +156,7 @@ public: alg->setProperty("IncidentSpecta", IncidentSpecta); alg->setProperty("InputWorkspace", InputWorkspace); alg->setProperty("OutputWorkspace", "correction_ws"); - TS_ASSERT_THROWS(alg->execute(), std::runtime_error) + TS_ASSERT_THROWS(alg->execute(), const std::runtime_error &) } void testCalculatePlaczekSelfScatteringDoesNotRunWithNoSample() { @@ -170,7 +170,7 @@ public: alg->setProperty("IncidentSpecta", IncidentSpecta); alg->setProperty("InputWorkspace", InputWorkspace); alg->setProperty("OutputWorkspace", "correction_ws"); - TS_ASSERT_THROWS(alg->execute(), std::runtime_error) + TS_ASSERT_THROWS(alg->execute(), const std::runtime_error &) } private: diff --git a/Framework/Algorithms/test/CreateDetectorTableTest.h b/Framework/Algorithms/test/CreateDetectorTableTest.h index c34302203fc6f5dc8fe5e9752102d250c304d0ed..89e75a01a67bc668d3c5ac54454e97d3f0388eb0 100644 --- a/Framework/Algorithms/test/CreateDetectorTableTest.h +++ b/Framework/Algorithms/test/CreateDetectorTableTest.h @@ -164,7 +164,7 @@ public: TS_ASSERT(alg.isInitialized()) TS_ASSERT_THROWS_NOTHING(alg.setProperty("InputWorkspace", inputWS)) - TS_ASSERT_THROWS(alg.executeAsChildAlg(), std::runtime_error); + 
TS_ASSERT_THROWS(alg.executeAsChildAlg(), const std::runtime_error &); } }; diff --git a/Framework/Algorithms/test/ReflectometryReductionOne2Test.h b/Framework/Algorithms/test/ReflectometryReductionOne2Test.h index 486600ca12d8a333cd8324baede18010ea3f6c37..2ab7a3aae98611417246ef4de6ca4cdae34fa837 100644 --- a/Framework/Algorithms/test/ReflectometryReductionOne2Test.h +++ b/Framework/Algorithms/test/ReflectometryReductionOne2Test.h @@ -465,7 +465,7 @@ public: alg.setProperty("ThetaIn", 25.0); MatrixWorkspace_sptr outLam = runAlgorithmLam(alg, 12); - TS_ASSERT_DELTA(outLam->x(0)[0], 0.934991, 1e-6); + TS_ASSERT_DELTA(outLam->x(0)[0], 0.934992, 1e-6); TS_ASSERT_DELTA(outLam->x(0)[3], 5.173599, 1e-6); TS_ASSERT_DELTA(outLam->x(0)[7], 10.825076, 1e-6); TS_ASSERT_DELTA(outLam->y(0)[0], 2.768185, 1e-6); @@ -490,7 +490,7 @@ public: TS_ASSERT_DELTA(outLam->x(0)[0], 0.825488, 1e-6); TS_ASSERT_DELTA(outLam->x(0)[3], 5.064095, 1e-6); TS_ASSERT_DELTA(outLam->x(0)[7], 10.715573, 1e-6); - TS_ASSERT_DELTA(outLam->y(0)[0], 3.141858, 1e-6); + TS_ASSERT_DELTA(outLam->y(0)[0], 3.141859, 1e-6); TS_ASSERT_DELTA(outLam->y(0)[3], 3.141885, 1e-6); TS_ASSERT_DELTA(outLam->y(0)[7], 3.141920, 1e-6); } @@ -522,11 +522,11 @@ public: alg.setProperty("ThetaIn", 25.0); MatrixWorkspace_sptr outLam = runAlgorithmLam(alg, 13); - TS_ASSERT_DELTA(outLam->x(0)[0], -0.748671, 1e-6); + TS_ASSERT_DELTA(outLam->x(0)[0], -0.748672, 1e-6); TS_ASSERT_DELTA(outLam->x(0)[5], 6.315674, 1e-6); TS_ASSERT_DELTA(outLam->x(0)[9], 11.967151, 1e-6); TS_ASSERT_DELTA(outLam->y(0)[0], 5.040302, 1e-6); - TS_ASSERT_DELTA(outLam->y(0)[5], 2.193649, 1e-6); + TS_ASSERT_DELTA(outLam->y(0)[5], 2.193650, 1e-6); TS_ASSERT_DELTA(outLam->y(0)[9], 2.255101, 1e-6); } @@ -541,7 +541,7 @@ public: alg.setProperty("ThetaIn", 25.0); MatrixWorkspace_sptr outLam = runAlgorithmLam(alg, 12); - TS_ASSERT_DELTA(outLam->x(0)[0], 0.934991, 1e-6); + TS_ASSERT_DELTA(outLam->x(0)[0], 0.934992, 1e-6); TS_ASSERT_DELTA(outLam->x(0)[3], 5.173599, 1e-6); 
TS_ASSERT_DELTA(outLam->x(0)[7], 10.825076, 1e-6); TS_ASSERT_DELTA(outLam->y(0)[0], 0.631775, 1e-6); @@ -566,8 +566,8 @@ public: TS_ASSERT_DELTA(outLam->x(0)[3], 5.159104, 1e-6); TS_ASSERT_DELTA(outLam->x(0)[7], 10.810581, 1e-6); TS_ASSERT_DELTA(outLam->y(0)[0], 16.351599, 1e-6); - TS_ASSERT_DELTA(outLam->y(0)[3], 23.963539, 1e-6); - TS_ASSERT_DELTA(outLam->y(0)[7], 39.756738, 1e-6); + TS_ASSERT_DELTA(outLam->y(0)[3], 23.963534, 1e-6); + TS_ASSERT_DELTA(outLam->y(0)[7], 39.756736, 1e-6); } void test_sum_in_q_IvsQ() { @@ -614,9 +614,9 @@ public: TS_ASSERT_DELTA(outQ->x(0)[3], 0.310524, 1e-6); TS_ASSERT_DELTA(outQ->x(0)[7], 0.363599, 1e-6); // Y counts - TS_ASSERT_DELTA(outQ->y(0)[0], 2.900305, 1e-6); - TS_ASSERT_DELTA(outQ->y(0)[3], 2.886947, 1e-6); - TS_ASSERT_DELTA(outQ->y(0)[7], 2.607359, 1e-6); + TS_ASSERT_DELTA(outQ->y(0)[0], 2.900303, 1e-6); + TS_ASSERT_DELTA(outQ->y(0)[3], 2.886945, 1e-6); + TS_ASSERT_DELTA(outQ->y(0)[7], 2.607357, 1e-6); } void test_sum_in_q_exclude_partial_bins() { diff --git a/Framework/Crystal/CMakeLists.txt b/Framework/Crystal/CMakeLists.txt index 12495d3b89a8779002232da865fb59878adbb39e..992738109e605c98cbab7c80d2b69f18a6006de1 100644 --- a/Framework/Crystal/CMakeLists.txt +++ b/Framework/Crystal/CMakeLists.txt @@ -39,6 +39,7 @@ set(SRC_FILES src/NormaliseVanadium.cpp src/OptimizeCrystalPlacement.cpp src/OptimizeLatticeForCellType.cpp + src/PeakAlgorithmHelpers.cpp src/PeakBackground.cpp src/PeakClusterProjection.cpp src/PeakHKLErrors.cpp @@ -115,6 +116,7 @@ set(INC_FILES inc/MantidCrystal/NormaliseVanadium.h inc/MantidCrystal/OptimizeCrystalPlacement.h inc/MantidCrystal/OptimizeLatticeForCellType.h + inc/MantidCrystal/PeakAlgorithmHelpers.h inc/MantidCrystal/PeakBackground.h inc/MantidCrystal/PeakClusterProjection.h inc/MantidCrystal/PeakHKLErrors.h diff --git a/Framework/Crystal/inc/MantidCrystal/PeakAlgorithmHelpers.h b/Framework/Crystal/inc/MantidCrystal/PeakAlgorithmHelpers.h new file mode 100644 index 
0000000000000000000000000000000000000000..f6460157dfcbcacda8fe5147477821f4e9830103 --- /dev/null +++ b/Framework/Crystal/inc/MantidCrystal/PeakAlgorithmHelpers.h @@ -0,0 +1,54 @@ +// Mantid Repository : https://github.com/mantidproject/mantid +// +// Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +// NScD Oak Ridge National Laboratory, European Spallation Source +// & Institut Laue - Langevin +// SPDX - License - Identifier: GPL - 3.0 + +#ifndef MANTID_CRYSTAL_PEAKALGORITHMHELPERS_H +#define MANTID_CRYSTAL_PEAKALGORITHMHELPERS_H +#include "MantidAPI/IAlgorithm.h" + +namespace Mantid::Kernel { +class V3D; +} + +namespace Mantid::Crystal { + +/// Tie together a modulated peak number with its offset +using MNPOffset = std::tuple<double, double, double, Kernel::V3D>; + +/// Tie together the names of the properties for the modulation vectors +struct ModulationProperties { + inline const static std::string ModVector1{"ModVector1"}; + inline const static std::string ModVector2{"ModVector2"}; + inline const static std::string ModVector3{"ModVector3"}; + inline const static std::string MaxOrder{"MaxOrder"}; + inline const static std::string CrossTerms{"CrossTerms"}; + + static void appendTo(API::IAlgorithm *alg); + static ModulationProperties create(const API::IAlgorithm &alg); + + std::vector<MNPOffset> offsets; + int maxOrder; + bool crossTerms; + bool saveOnLattice; +}; + +/// Create a list of valid modulation vectors from the input +std::vector<Kernel::V3D> +validModulationVectors(const std::vector<double> &modVector1, + const std::vector<double> &modVector2, + const std::vector<double> &modVector3); + +/// Calculate a list of HKL offsets from the given modulation vectors. 
+std::vector<MNPOffset> +generateOffsetVectors(const std::vector<Kernel::V3D> &modVectors, + const int maxOrder, const bool crossTerms); +/// Calculate a list of HKL offsets from the given lists of offsets +std::vector<MNPOffset> +generateOffsetVectors(const std::vector<double> &hOffsets, + const std::vector<double> &kOffsets, + const std::vector<double> &lOffsets); +} // namespace Mantid::Crystal + +#endif // MANTID_CRYSTAL_PEAKALGORITHMHELPERS_H diff --git a/Framework/Crystal/inc/MantidCrystal/PredictFractionalPeaks.h b/Framework/Crystal/inc/MantidCrystal/PredictFractionalPeaks.h index d4e80ab20181540927af23e3a475511c96047613..7270bcb00163b3d45e67958d8cd1a613ea1ceef4 100644 --- a/Framework/Crystal/inc/MantidCrystal/PredictFractionalPeaks.h +++ b/Framework/Crystal/inc/MantidCrystal/PredictFractionalPeaks.h @@ -6,46 +6,41 @@ // SPDX - License - Identifier: GPL - 3.0 + #ifndef MANTID_CRYSTAL_PREDICTFRACTIONALPEAKS_H_ #define MANTID_CRYSTAL_PREDICTFRACTIONALPEAKS_H_ - #include "MantidAPI/Algorithm.h" -#include "MantidKernel/System.h" +#include "MantidCrystal/PeakAlgorithmHelpers.h" +#include <tuple> + +namespace Mantid::Kernel { +class V3D; +} -namespace Mantid { -namespace Crystal { +namespace Mantid::Crystal { /** - * + * Using a set of offset vectors, either provided as separate lists or as a set + * of vectors, predict whether */ class DLLExport PredictFractionalPeaks : public API::Algorithm { public: - /// Algorithm's name for identification const std::string name() const override { return "PredictFractionalPeaks"; } - /// Summary of algorithms purpose const std::string summary() const override { return "The offsets can be from hkl values in a range of hkl values or " "from peaks in the input PeaksWorkspace"; } - - /// Algorithm's version for identification - int version() const override { return 1; }; + int version() const override { return 1; } const std::vector<std::string> seeAlso() const override { return {"PredictPeaks"}; } - - /// Algorithm's category for 
identification const std::string category() const override { return "Crystal\\Peaks"; } - /// Return any errors in cross-property validation std::map<std::string, std::string> validateInputs() override; private: - /// Initialise the properties void init() override; - - /// Run the algorithm void exec() override; + + ModulationProperties getModulationInfo(); }; -} // namespace Crystal -} // namespace Mantid +} // namespace Mantid::Crystal #endif /* MANTID_CRYSTAL_PREDICTFRACTIONALPEAKS */ diff --git a/Framework/Crystal/src/IndexPeaks.cpp b/Framework/Crystal/src/IndexPeaks.cpp index 438cdf1ad0211d581e18ab3919d23379733a3e79..716cd1a5053703cb9a64dae92f76170063016262 100644 --- a/Framework/Crystal/src/IndexPeaks.cpp +++ b/Framework/Crystal/src/IndexPeaks.cpp @@ -6,6 +6,7 @@ // SPDX - License - Identifier: GPL - 3.0 + #include "MantidCrystal/IndexPeaks.h" #include "MantidAPI/Sample.h" +#include "MantidCrystal/PeakAlgorithmHelpers.h" #include "MantidDataObjects/PeaksWorkspace.h" #include "MantidGeometry/Crystal/IndexingUtils.h" #include "MantidGeometry/Crystal/OrientedLattice.h" @@ -40,11 +41,6 @@ const std::string TOLERANCE{"Tolerance"}; const std::string SATE_TOLERANCE{"ToleranceForSatellite"}; const std::string ROUNDHKLS{"RoundHKLs"}; const std::string COMMONUB{"CommonUBForAll"}; -const std::string MAXORDER{"MaxOrder"}; -const std::string MODVECTOR1{"ModVector1"}; -const std::string MODVECTOR2{"ModVector2"}; -const std::string MODVECTOR3{"ModVector3"}; -const std::string CROSSTERMS{"CrossTerms"}; const std::string SAVEMODINFO{"SaveModulationInfo"}; const std::string AVERAGE_ERR{"AverageError"}; const std::string NUM_INDEXED{"NumIndexed"}; @@ -63,13 +59,7 @@ struct SatelliteIndexingArgs { struct IndexPeaksArgs { static IndexPeaksArgs parse(const API::Algorithm &alg) { const PeaksWorkspace_sptr peaksWS = alg.getProperty(PEAKSWORKSPACE); - const int maxOrderFromAlg = alg.getProperty(Prop::MAXORDER); - - auto addIfNonZero = [](const auto &modVec, std::vector<V3D> 
&modVectors) { - if (std::fabs(modVec[0]) > 0 || std::fabs(modVec[1]) > 0 || - std::fabs(modVec[2]) > 0) - modVectors.emplace_back(V3D(modVec[0], modVec[1], modVec[2])); - }; + const int maxOrderFromAlg = alg.getProperty(ModulationProperties::MaxOrder); int maxOrderToUse{0}; std::vector<V3D> modVectorsToUse; @@ -78,21 +68,18 @@ struct IndexPeaksArgs { if (maxOrderFromAlg > 0) { // Use inputs from algorithm maxOrderToUse = maxOrderFromAlg; - crossTermToUse = alg.getProperty(Prop::CROSSTERMS); - std::vector<double> modVector = alg.getProperty(Prop::MODVECTOR1); - addIfNonZero(std::move(modVector), modVectorsToUse); - modVector = alg.getProperty(Prop::MODVECTOR2); - addIfNonZero(std::move(modVector), modVectorsToUse); - modVector = alg.getProperty(Prop::MODVECTOR3); - addIfNonZero(std::move(modVector), modVectorsToUse); + crossTermToUse = alg.getProperty(ModulationProperties::CrossTerms); + modVectorsToUse = validModulationVectors( + alg.getProperty(ModulationProperties::ModVector1), + alg.getProperty(ModulationProperties::ModVector2), + alg.getProperty(ModulationProperties::ModVector3)); } else { // Use lattice definitions if they exist const auto &lattice = peaksWS->sample().getOrientedLattice(); maxOrderToUse = lattice.getMaxOrder(); if (maxOrderToUse > 0) { - for (auto i = 0; i < 3; ++i) { - addIfNonZero(lattice.getModVec(i), modVectorsToUse); - } + modVectorsToUse = validModulationVectors( + lattice.getModVec(0), lattice.getModVec(1), lattice.getModVec(2)); } crossTermToUse = lattice.getCrossTerm(); } @@ -207,81 +194,6 @@ DblMatrix optimizeUBMatrix(const DblMatrix &ubOrig, return optimizedUB; } -/// Tie together a modulated peak number with its offset -using MNPOffset = std::tuple<double, double, double, V3D>; - -/** - * Calculate a list of HKL offsets from the given modulation vectors. - * @param maxOrder Integer specifying the multiples of the modulation vector. 
- * @param modVectors A list of modulation vectors form the user - * @param crossTerms If true then compute products of the modulation vectors - * @return A list of (m, n, p, V3D) were m,n,p specifies the modulation - * structure number and V3D specifies the offset to be tested - */ -std::vector<MNPOffset> -calculateOffsetsToTest(const int maxOrder, const std::vector<V3D> &modVectors, - const bool crossTerms) { - assert(modVectors.size() <= 3); - - std::vector<MNPOffset> offsets; - if (crossTerms && modVectors.size() > 1) { - const auto &modVector0{modVectors[0]}, modVector1{modVectors[1]}; - if (modVectors.size() == 2) { - // Calculate m*mod_vec1 + n*mod_vec2 for combinations of m, n in - // [-maxOrder,maxOrder] - offsets.reserve(2 * maxOrder); - for (auto m = -maxOrder; m <= maxOrder; ++m) { - for (auto n = -maxOrder; n <= maxOrder; ++n) { - if (m == 0 && n == 0) - continue; - offsets.emplace_back( - std::make_tuple(m, n, 0, modVector0 * m + modVector1 * n)); - } - } - } else { - // Calculate m*mod_vec1 + n*mod_vec2 + p*mod_vec3 for combinations of m, - // n, p in - // [-maxOrder,maxOrder] - const auto &modVector2{modVectors[2]}; - offsets.reserve(3 * maxOrder); - for (auto m = -maxOrder; m <= maxOrder; ++m) { - for (auto n = -maxOrder; n <= maxOrder; ++n) { - for (auto p = -maxOrder; p <= maxOrder; ++p) { - if (m == 0 && n == 0 && p == 0) - continue; - offsets.emplace_back(std::make_tuple( - m, n, p, modVector0 * m + modVector1 * n + modVector2 * p)); - } - } - } - } - } else { - // No cross terms: Compute coeff*mod_vec_i for each modulation vector - // separately for coeff in [-maxOrder, maxOrder] - for (auto i = 0u; i < modVectors.size(); ++i) { - const auto &modVector = modVectors[i]; - for (int order = -maxOrder; order <= maxOrder; ++order) { - if (order == 0) - continue; - V3D offset{modVector * order}; - switch (i) { - case 0: - offsets.emplace_back(std::make_tuple(order, 0, 0, std::move(offset))); - break; - case 1: - 
offsets.emplace_back(std::make_tuple(0, order, 0, std::move(offset))); - break; - case 2: - offsets.emplace_back(std::make_tuple(0, 0, order, std::move(offset))); - break; - } - } - } - } - - return offsets; -} - /// <IntHKL, IntMNP, error> using IndexedSatelliteInfo = std::tuple<V3D, V3D, double>; @@ -312,7 +224,7 @@ boost::optional<IndexedSatelliteInfo> indexSatellite(const V3D &mainHKL, const int maxOrder, const std::vector<V3D> &modVectors, const double tolerance, const bool crossTerms) { - const auto offsets = calculateOffsetsToTest(maxOrder, modVectors, crossTerms); + const auto offsets = generateOffsetVectors(modVectors, maxOrder, crossTerms); bool foundSatellite{false}; V3D indexedIntHKL, indexedMNP; for (const auto &mnpOffset : offsets) { @@ -457,28 +369,7 @@ void IndexPeaks::init() { "Round H, K and L values to integers"); this->declareProperty(Prop::COMMONUB, false, "Index all orientations with a common UB"); - auto mustBeLengthThree = boost::make_shared<ArrayLengthValidator<double>>(3); - this->declareProperty(std::make_unique<ArrayProperty<double>>( - Prop::MODVECTOR1, "0.0,0.0,0.0", mustBeLengthThree), - "Modulation Vector 1: dh, dk, dl"); - this->declareProperty(std::make_unique<Kernel::ArrayProperty<double>>( - Prop::MODVECTOR2, "0.0,0.0,0.0", mustBeLengthThree), - "Modulation Vector 2: dh, dk, dl"); - this->declareProperty(std::make_unique<Kernel::ArrayProperty<double>>( - Prop::MODVECTOR3, "0.0,0.0,0.0", mustBeLengthThree), - "Modulation Vector 3: dh, dk, dl"); - auto mustBePositiveOrZero = boost::make_shared<BoundedValidator<int>>(); - mustBePositiveOrZero->setLower(0); - this->declareProperty( - Prop::MAXORDER, 0, mustBePositiveOrZero, - "Maximum order to apply Modulation Vectors. 
Default = 0", - Direction::Input); - - this->declareProperty( - Prop::CROSSTERMS, false, - "Include combinations of modulation vectors in satellite search", - Direction::Input); - + ModulationProperties::appendTo(this); this->declareProperty( Prop::SAVEMODINFO, false, "If true, update the OrientedLattice with the maxOrder, " diff --git a/Framework/Crystal/src/PeakAlgorithmHelpers.cpp b/Framework/Crystal/src/PeakAlgorithmHelpers.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3081069b74f07996487b9297a5fe171369867118 --- /dev/null +++ b/Framework/Crystal/src/PeakAlgorithmHelpers.cpp @@ -0,0 +1,195 @@ +// Mantid Repository : https://github.com/mantidproject/mantid +// +// Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +// NScD Oak Ridge National Laboratory, European Spallation Source +// & Institut Laue - Langevin +// SPDX - License - Identifier: GPL - 3.0 + +#include "MantidCrystal/PeakAlgorithmHelpers.h" +#include "MantidKernel/ArrayLengthValidator.h" +#include "MantidKernel/ArrayProperty.h" +#include "MantidKernel/BoundedValidator.h" +#include "MantidKernel/V3D.h" + +using Mantid::Kernel::ArrayLengthValidator; +using Mantid::Kernel::ArrayProperty; +using Mantid::Kernel::BoundedValidator; +using Mantid::Kernel::Direction; +using Mantid::Kernel::V3D; + +namespace Mantid::Crystal { + +/** + * Append the common set of properties that relate to modulation vectors + * to the given algorithm + * @param alg A pointer to the algorithm that will receive the properties + */ +void ModulationProperties::appendTo(API::IAlgorithm *alg) { + auto mustBeLengthThree = boost::make_shared<ArrayLengthValidator<double>>(3); + alg->declareProperty( + std::make_unique<ArrayProperty<double>>(ModulationProperties::ModVector1, + "0.0,0.0,0.0", mustBeLengthThree), + "Modulation Vector 1: dh, dk, dl"); + alg->declareProperty( + std::make_unique<ArrayProperty<double>>(ModulationProperties::ModVector2, + "0.0,0.0,0.0", mustBeLengthThree), + "Modulation Vector 
2: dh, dk, dl"); + alg->declareProperty( + std::make_unique<ArrayProperty<double>>(ModulationProperties::ModVector3, + "0.0,0.0,0.0", mustBeLengthThree), + "Modulation Vector 3: dh, dk, dl"); + auto mustBePositiveOrZero = boost::make_shared<BoundedValidator<int>>(); + mustBePositiveOrZero->setLower(0); + alg->declareProperty(ModulationProperties::MaxOrder, 0, mustBePositiveOrZero, + "Maximum order to apply Modulation Vectors. Default = 0", + Direction::Input); + alg->declareProperty( + ModulationProperties::CrossTerms, false, + "Include combinations of modulation vectors in satellite search", + Direction::Input); +} + +/** + * Create a ModulationProperties object from an algorithm + * @param alg An algorithm containing the user input + * @return A new ModulationProperties object + */ +ModulationProperties ModulationProperties::create(const API::IAlgorithm &alg) { + const int maxOrder{alg.getProperty(ModulationProperties::MaxOrder)}; + const bool crossTerms{alg.getProperty(ModulationProperties::CrossTerms)}; + auto offsets = generateOffsetVectors( + validModulationVectors(alg.getProperty(ModulationProperties::ModVector1), + alg.getProperty(ModulationProperties::ModVector2), + alg.getProperty(ModulationProperties::ModVector3)), + maxOrder, crossTerms); + const bool saveOnLattice{true}; + return {std::move(offsets), maxOrder, crossTerms, saveOnLattice}; +} + +/** + * Check each input is a valid modulation and add it to a list + * to return. 
+ * @param modVector1 List of 3 doubles specifying an offset + * @param modVector2 List of 3 doubles specifying an offset + * @param modVector3 List of 3 doubles specifying an offset + * @return A list of valid modulation vectors + */ +std::vector<Kernel::V3D> +validModulationVectors(const std::vector<double> &modVector1, + const std::vector<double> &modVector2, + const std::vector<double> &modVector3) { + std::vector<V3D> modVectors; + auto addIfNonZero = [&modVectors](const auto &modVec) { + if (std::fabs(modVec[0]) > 0 || std::fabs(modVec[1]) > 0 || + std::fabs(modVec[2]) > 0) + modVectors.emplace_back(V3D(modVec[0], modVec[1], modVec[2])); + }; + addIfNonZero(modVector1); + addIfNonZero(modVector2); + addIfNonZero(modVector3); + return modVectors; +} + +/** + * @param maxOrder Integer specifying the multiples of the + * modulation vector. + * @param modVectors A list of modulation vectors form the user + * @param crossTerms If true then compute products of the + * modulation vectors + * @return A list of (m, n, p, V3D) were m,n,p specifies the + * modulation structure number and V3D specifies the offset to + * be tested + */ +std::vector<MNPOffset> +generateOffsetVectors(const std::vector<Kernel::V3D> &modVectors, + const int maxOrder, const bool crossTerms) { + assert(modVectors.size() <= 3); + + std::vector<MNPOffset> offsets; + if (crossTerms && modVectors.size() > 1) { + const auto &modVector0{modVectors[0]}, modVector1{modVectors[1]}; + if (modVectors.size() == 2) { + // Calculate m*mod_vec1 + n*mod_vec2 for combinations of + // m, n in + // [-maxOrder,maxOrder] + offsets.reserve(2 * maxOrder); + for (auto m = -maxOrder; m <= maxOrder; ++m) { + for (auto n = -maxOrder; n <= maxOrder; ++n) { + if (m == 0 && n == 0) + continue; + offsets.emplace_back( + std::make_tuple(m, n, 0, modVector0 * m + modVector1 * n)); + } + } + } else { + // Calculate m*mod_vec1 + n*mod_vec2 + p*mod_vec3 for + // combinations of m, n, p in [-maxOrder,maxOrder] + const auto 
&modVector2{modVectors[2]}; + offsets.reserve(3 * maxOrder); + for (auto m = -maxOrder; m <= maxOrder; ++m) { + for (auto n = -maxOrder; n <= maxOrder; ++n) { + for (auto p = -maxOrder; p <= maxOrder; ++p) { + if (m == 0 && n == 0 && p == 0) + continue; + offsets.emplace_back(std::make_tuple( + m, n, p, modVector0 * m + modVector1 * n + modVector2 * p)); + } + } + } + } + } else { + // No cross terms: Compute coeff*mod_vec_i for each + // modulation vector separately for coeff in [-maxOrder, + // maxOrder] + for (auto i = 0u; i < modVectors.size(); ++i) { + const auto &modVector = modVectors[i]; + for (int order = -maxOrder; order <= maxOrder; ++order) { + if (order == 0) + continue; + V3D offset{modVector * order}; + switch (i) { + case 0: + offsets.emplace_back(std::make_tuple(order, 0, 0, std::move(offset))); + break; + case 1: + offsets.emplace_back(std::make_tuple(0, order, 0, std::move(offset))); + break; + case 2: + offsets.emplace_back(std::make_tuple(0, 0, order, std::move(offset))); + break; + } + } + } + } + + return offsets; +} + +/** + * The final offset vector is computed as + * (hoffsets[i],koffsets[j],loffsets[k]) for each combination of + * i,j,k. 
All offsets are considered order 1 + * @param hOffsets A list of offsets in the h direction + * @param kOffsets A list of offsets in the k direction + * @param lOffsets A list of offsets in the l direction + * @return A list of (1, 1, 1, V3D) were m,n,p specifies the + * modulation structure number and V3D specifies the offset to + * be tested */ +std::vector<MNPOffset> +generateOffsetVectors(const std::vector<double> &hOffsets, + const std::vector<double> &kOffsets, + const std::vector<double> &lOffsets) { + std::vector<MNPOffset> offsets; + for (double hOffset : hOffsets) { + for (double kOffset : kOffsets) { + for (double lOffset : lOffsets) { + // mnp = 0, 0, 0 as + // it's not quite clear how to interpret them as mnp indices + offsets.emplace_back( + std::make_tuple(0, 0, 0, V3D(hOffset, kOffset, lOffset))); + } + } + } + return offsets; +} + +} // namespace Mantid::Crystal diff --git a/Framework/Crystal/src/PredictFractionalPeaks.cpp b/Framework/Crystal/src/PredictFractionalPeaks.cpp index 3b285837b0f01489901134ba00b7f8061a605f5a..43619cb7e91214f0d5056eb4162118d1bf6bf512 100644 --- a/Framework/Crystal/src/PredictFractionalPeaks.cpp +++ b/Framework/Crystal/src/PredictFractionalPeaks.cpp @@ -21,6 +21,7 @@ #include "MantidKernel/ListValidator.h" #include "MantidKernel/WarningSuppressions.h" +#include <boost/iterator/iterator_facade.hpp> #include <boost/math/special_functions/round.hpp> using Mantid::API::Algorithm; @@ -32,6 +33,7 @@ using Mantid::Geometry::HKLFilter; using Mantid::Geometry::HKLFilter_uptr; using Mantid::Geometry::HKLGenerator; using Mantid::Geometry::Instrument_const_sptr; +using Mantid::Geometry::OrientedLattice; using Mantid::Geometry::ReflectionCondition_sptr; using Mantid::Kernel::DblMatrix; using Mantid::Kernel::V3D; @@ -151,92 +153,134 @@ private: int m_currentPeak; }; +/** + * Create the output PeaksWorkspace + * @param inputPeaks The input workspace provided by the user + * @param modulationProps The set of modulation properties + * 
@return A new PeaksWorkspace + */ +boost::shared_ptr<Mantid::API::IPeaksWorkspace> createOutputWorkspace( + const PeaksWorkspace &inputPeaks, + const Mantid::Crystal::ModulationProperties &modulationProps) { + using Mantid::API::WorkspaceFactory; + auto outPeaks = WorkspaceFactory::Instance().createPeaks(); + outPeaks->setInstrument(inputPeaks.getInstrument()); + if (modulationProps.saveOnLattice) { + auto lattice = std::make_unique<Mantid::Geometry::OrientedLattice>(); + lattice->setMaxOrder(modulationProps.maxOrder); + lattice->setCrossTerm(modulationProps.crossTerms); + const auto &offsets = modulationProps.offsets; + // there should be a maximum of 3 modulation vectors. Store the + // order=(1,0,0),(0,1,0), (0,0,1) vectors + for (const auto &offset : offsets) { + const V3D &modVector{std::get<3>(offset)}; + const double m{std::get<0>(offset)}, n{std::get<1>(offset)}, + p{std::get<2>(offset)}; + int modNum(-1); + if (m == 1 && n == 0 && p == 0) + modNum = 1; + else if (m == 0 && n == 1 && p == 0) + modNum = 2; + else if (m == 0 && n == 0 && p == 1) + modNum = 3; + switch (modNum) { + case 1: + lattice->setModVec1(modVector); + break; + case 2: + lattice->setModVec2(modVector); + break; + case 3: + lattice->setModVec3(modVector); + break; + } + } + outPeaks->mutableSample().setOrientedLattice(std::move(lattice)); + } + return outPeaks; +} + /** * Predict fractional peaks in the range specified by [hklMin, hklMax] and * add them to a new PeaksWorkspace * @param alg The host algorithm pointer - * @param hOffsets Offsets to apply to HKL in H direction - * @param kOffsets Offsets to apply to HKL in K direction - * @param lOffsets Offsets to apply to HKL in L direction - * @param requirePeaksOnDetector If true the peaks is required to hit a detector + * @param requirePeaksOnDetector If true the peaks is required to hit a + * detector * @param inputPeaks A peaks workspace used to created new peaks. 
Defines the * instrument and metadata for the search + * @param modVectors The set of modulation vectors to use in the search * @param strategy An object defining were to start the search and how to * advance to the next HKL * @return A new PeaksWorkspace containing the predicted fractional peaks */ template <typename SearchStrategy> -IPeaksWorkspace_sptr -predictPeaks(Algorithm *const alg, const std::vector<double> &hOffsets, - const std::vector<double> &kOffsets, - const std::vector<double> &lOffsets, - const bool requirePeaksOnDetector, - const PeaksWorkspace &inputPeaks, SearchStrategy strategy) { - using Mantid::API::WorkspaceFactory; - auto outPeaks = WorkspaceFactory::Instance().createPeaks(); - const auto instrument = inputPeaks.getInstrument(); - outPeaks->setInstrument(instrument); - +IPeaksWorkspace_sptr predictFractionalPeaks( + Algorithm *const alg, const bool requirePeaksOnDetector, + const PeaksWorkspace &inputPeaks, + const Mantid::Crystal::ModulationProperties &modulationProps, + SearchStrategy searchStrategy) { using Mantid::Geometry::InstrumentRayTracer; - const InstrumentRayTracer tracer(instrument); + + const InstrumentRayTracer tracer(inputPeaks.getInstrument()); const auto &UB = inputPeaks.sample().getOrientedLattice().getUB(); + const auto &offsets = modulationProps.offsets; + auto outPeaks = createOutputWorkspace(inputPeaks, modulationProps); using PeakHash = std::array<int, 4>; std::vector<PeakHash> alreadyDonePeaks; V3D currentHKL; DblMatrix gonioMatrix; int runNumber{0}; - strategy.initialHKL(¤tHKL, &gonioMatrix, &runNumber); - auto progressReporter = strategy.createProgressReporter(alg); + searchStrategy.initialHKL(¤tHKL, &gonioMatrix, &runNumber); + auto progressReporter = searchStrategy.createProgressReporter(alg); while (true) { - for (double hOffset : hOffsets) { - for (double kOffset : kOffsets) { - for (double lOffset : lOffsets) { - const V3D candidateHKL(currentHKL[0] + hOffset, - currentHKL[1] + kOffset, - currentHKL[2] + 
lOffset); - const V3D qLab = (gonioMatrix * UB * candidateHKL) * 2 * M_PI; - if (qLab[2] <= 0) - continue; - - using Mantid::Geometry::IPeak; - std::unique_ptr<IPeak> peak; - try { - peak = inputPeaks.createPeak(qLab); - } catch (...) { - // If we can't create a valid peak we have no choice but to skip it - continue; - } - - peak->setGoniometerMatrix(gonioMatrix); - if (requirePeaksOnDetector && peak->getDetectorID() < 0) - continue; - GNU_DIAG_OFF("missing-braces") - PeakHash savedPeak{runNumber, - boost::math::iround(1000.0 * candidateHKL[0]), - boost::math::iround(1000.0 * candidateHKL[1]), - boost::math::iround(1000.0 * candidateHKL[2])}; - GNU_DIAG_ON("missing-braces") - auto it = - find(alreadyDonePeaks.begin(), alreadyDonePeaks.end(), savedPeak); - if (it == alreadyDonePeaks.end()) - alreadyDonePeaks.emplace_back(std::move(savedPeak)); - else - continue; - - peak->setHKL(candidateHKL); - peak->setRunNumber(runNumber); - outPeaks->addPeak(*peak); - } + for (const auto &mnpOffset : offsets) { + const V3D candidateHKL{currentHKL + std::get<3>(mnpOffset)}; + const V3D qLab = (gonioMatrix * UB * candidateHKL) * 2 * M_PI; + if (qLab[2] <= 0) + continue; + + using Mantid::Geometry::IPeak; + std::unique_ptr<IPeak> peak; + try { + peak = inputPeaks.createPeak(qLab); + } catch (...) 
{ + // If we can't create a valid peak we have no choice but to skip + // it + continue; } + + peak->setGoniometerMatrix(gonioMatrix); + if (requirePeaksOnDetector && peak->getDetectorID() < 0) + continue; + GNU_DIAG_OFF("missing-braces") + PeakHash savedPeak{runNumber, + boost::math::iround(1000.0 * candidateHKL[0]), + boost::math::iround(1000.0 * candidateHKL[1]), + boost::math::iround(1000.0 * candidateHKL[2])}; + GNU_DIAG_ON("missing-braces") + auto it = + find(alreadyDonePeaks.begin(), alreadyDonePeaks.end(), savedPeak); + if (it == alreadyDonePeaks.end()) + alreadyDonePeaks.emplace_back(std::move(savedPeak)); + else + continue; + + peak->setHKL(candidateHKL); + const double m{std::get<0>(mnpOffset)}, n{std::get<1>(mnpOffset)}, + p{std::get<2>(mnpOffset)}; + if (fabs(m) > 0. || fabs(n) > 0. || fabs(p) > 0.) + peak->setIntMNP(V3D(m, n, p)); + peak->setRunNumber(runNumber); + outPeaks->addPeak(*peak); } progressReporter.report(); - if (!strategy.nextHKL(¤tHKL, &gonioMatrix, &runNumber)) + if (!searchStrategy.nextHKL(¤tHKL, &gonioMatrix, &runNumber)) break; } return outPeaks; -} // namespace +} } // namespace @@ -300,6 +344,7 @@ void PredictFractionalPeaks::init() { "If true then the predicted peaks are required to hit a " "detector pixel. 
Default=true", Direction::Input); + ModulationProperties::appendTo(this); // enable range properties if required using Kernel::EnabledWhenProperty; @@ -315,7 +360,34 @@ void PredictFractionalPeaks::init() { std::move(includeInRangeEqOne), std::move(reflConditionNotEmpty), Kernel::OR)); } - + // group offset/modulations options together + for (const auto &name : {PropertyNames::HOFFSET, PropertyNames::KOFFSET, + PropertyNames::LOFFSET}) { + setPropertyGroup(name, "Separate Offsets"); + } + for (const auto &name : + {ModulationProperties::ModVector1, ModulationProperties::ModVector2, + ModulationProperties::ModVector3, ModulationProperties::MaxOrder, + ModulationProperties::CrossTerms}) { + setPropertyGroup(name, "Modulation Vectors"); + } + // enable/disable offsets & modulation vectors appropriately + for (const auto &offsetName : {PropertyNames::HOFFSET, PropertyNames::KOFFSET, + PropertyNames::LOFFSET}) { + EnabledWhenProperty modVectorOneIsDefault{ModulationProperties::ModVector1, + Kernel::IS_DEFAULT}; + EnabledWhenProperty modVectorTwoIsDefault{ModulationProperties::ModVector2, + Kernel::IS_DEFAULT}; + EnabledWhenProperty modVectorThreeIsDefault{ + ModulationProperties::ModVector3, Kernel::IS_DEFAULT}; + EnabledWhenProperty modVectorOneAndTwoIsDefault{ + std::move(modVectorOneIsDefault), std::move(modVectorTwoIsDefault), + Kernel::AND}; + setPropertySettings(offsetName, + std::make_unique<Kernel::EnabledWhenProperty>( + std::move(modVectorOneAndTwoIsDefault), + std::move(modVectorThreeIsDefault), Kernel::AND)); + } // Outputs declareProperty( std::make_unique<WorkspaceProperty<PeaksWorkspace_sptr::element_type>>( @@ -323,6 +395,10 @@ void PredictFractionalPeaks::init() { "Workspace of Peaks with peaks with fractional h,k, and/or l values"); } +/** + * Validate the input once all values are set + * @return A map<string,string> containting an help messages for the user + */ std::map<std::string, std::string> PredictFractionalPeaks::validateInputs() { 
std::map<std::string, std::string> helpMessages; const PeaksWorkspace_sptr peaks = getProperty(PropertyNames::PEAKS); @@ -344,20 +420,23 @@ std::map<std::string, std::string> PredictFractionalPeaks::validateInputs() { validateRange(PropertyNames::KMIN, PropertyNames::KMAX); validateRange(PropertyNames::LMIN, PropertyNames::LMAX); + // If a modulation vector is provided then maxOrder is needed + const auto modVectors = + validModulationVectors(getProperty(ModulationProperties::ModVector1), + getProperty(ModulationProperties::ModVector2), + getProperty(ModulationProperties::ModVector3)); + const int maxOrder = getProperty(ModulationProperties::MaxOrder); + if (maxOrder == 0 && !modVectors.empty()) { + helpMessages[ModulationProperties::MaxOrder] = + "Maxorder required when specifying a modulation vector."; + } return helpMessages; } +/// Execute the algorithm void PredictFractionalPeaks::exec() { PeaksWorkspace_sptr inputPeaks = getProperty(PropertyNames::PEAKS); - std::vector<double> hOffsets = getProperty(PropertyNames::HOFFSET); - std::vector<double> kOffsets = getProperty(PropertyNames::KOFFSET); - std::vector<double> lOffsets = getProperty(PropertyNames::LOFFSET); - if (hOffsets.empty()) - hOffsets.emplace_back(0.0); - if (kOffsets.empty()) - kOffsets.emplace_back(0.0); - if (lOffsets.empty()) - lOffsets.emplace_back(0.0); + auto modulationInfo = getModulationInfo(); const bool includePeaksInRange = getProperty("IncludeAllPeaksInRange"); const V3D hklMin{getProperty(PropertyNames::HMIN), getProperty(PropertyNames::KMIN), @@ -386,17 +465,44 @@ void PredictFractionalPeaks::exec() { using Mantid::Geometry::HKLFilterNone; filter = std::make_unique<HKLFilterNone>(); } - outPeaks = predictPeaks( - this, hOffsets, kOffsets, lOffsets, requirePeakOnDetector, *inputPeaks, + outPeaks = predictFractionalPeaks( + this, requirePeakOnDetector, *inputPeaks, std::move(modulationInfo), PeaksInRangeStrategy(hklMin, hklMax, filter.get(), inputPeaks.get())); - } else { - outPeaks = 
- predictPeaks(this, hOffsets, kOffsets, lOffsets, requirePeakOnDetector, - *inputPeaks, PeaksFromIndexedStrategy(inputPeaks.get())); + outPeaks = predictFractionalPeaks( + this, requirePeakOnDetector, *inputPeaks, std::move(modulationInfo), + PeaksFromIndexedStrategy(inputPeaks.get())); } setProperty(PropertyNames::FRACPEAKS, outPeaks); } +/** + * @return The list of modulation vectors based on the user input. Anything + * specified by the modulation vector parameters takes precedence over the + * offsets + */ +ModulationProperties PredictFractionalPeaks::getModulationInfo() { + // Input validation ensures that we have either specified offests or + // modulation vectors + const int maxOrder = getProperty(ModulationProperties::MaxOrder); + + if (maxOrder == 0) { + std::vector<double> hOffsets = getProperty(PropertyNames::HOFFSET); + std::vector<double> kOffsets = getProperty(PropertyNames::KOFFSET); + std::vector<double> lOffsets = getProperty(PropertyNames::LOFFSET); + if (hOffsets.empty()) + hOffsets.emplace_back(0.0); + if (kOffsets.empty()) + kOffsets.emplace_back(0.0); + if (lOffsets.empty()) + lOffsets.emplace_back(0.0); + const bool crossTerms{false}, saveOnLattice{false}; + return {generateOffsetVectors(hOffsets, kOffsets, lOffsets), maxOrder, + crossTerms, saveOnLattice}; + } else { + return ModulationProperties::create(*this); + } +} + } // namespace Crystal } // namespace Mantid diff --git a/Framework/Crystal/test/IndexPeaksTest.h b/Framework/Crystal/test/IndexPeaksTest.h index b773976b1dcc83a084297eaecbc349d808f8b093..841ce5285798df834292e076775e256581813103 100644 --- a/Framework/Crystal/test/IndexPeaksTest.h +++ b/Framework/Crystal/test/IndexPeaksTest.h @@ -477,7 +477,8 @@ public: IndexPeaks alg; alg.initialize(); - TS_ASSERT_THROWS(alg.setProperty("MaxOrder", -1), std::invalid_argument) + TS_ASSERT_THROWS(alg.setProperty("MaxOrder", -1), + const std::invalid_argument &) } void test_modvector_with_list_length_not_three_throws() { @@ -485,10 +486,11 @@ 
public: alg.initialize(); for (const auto &propName : {"ModVector1", "ModVector2", "ModVector3"}) { - TS_ASSERT_THROWS(alg.setProperty(propName, "0"), std::invalid_argument) - TS_ASSERT_THROWS(alg.setProperty(propName, "0,0"), std::invalid_argument) + TS_ASSERT_THROWS(alg.setProperty(propName, "0"), std::invalid_argument &) + TS_ASSERT_THROWS(alg.setProperty(propName, "0,0"), + std::invalid_argument &) TS_ASSERT_THROWS(alg.setProperty(propName, "0,0,0,0"), - std::invalid_argument) + std::invalid_argument &) } } diff --git a/Framework/Crystal/test/PredictFractionalPeaksTest.h b/Framework/Crystal/test/PredictFractionalPeaksTest.h index c75544edab271dbae72f51bd30c18a49df90d01c..f0a179e4c1246e54450d1744c3dc9ef4e4d2b798 100644 --- a/Framework/Crystal/test/PredictFractionalPeaksTest.h +++ b/Framework/Crystal/test/PredictFractionalPeaksTest.h @@ -99,24 +99,31 @@ public: m_indexedPeaks, {{"HOffset", "-0.5,0,0.5"}, {"KOffset", "0.0"}, {"LOffset", "0.2"}}); - TS_ASSERT_EQUALS(117, fracPeaks->getNumberPeaks()) - const auto &peak0 = fracPeaks->getPeak(0); - TS_ASSERT_DELTA(peak0.getH(), -5.5, .0001) - TS_ASSERT_DELTA(peak0.getK(), 7.0, .0001) - TS_ASSERT_DELTA(peak0.getL(), -3.8, .0001) - TS_ASSERT_EQUALS(peak0.getDetectorID(), 1146353) - - const auto &peak3 = fracPeaks->getPeak(3); - TS_ASSERT_DELTA(peak3.getH(), -5.5, .0001) - TS_ASSERT_DELTA(peak3.getK(), 3.0, .0001) - TS_ASSERT_DELTA(peak3.getL(), -2.8, .0001) - TS_ASSERT_EQUALS(peak3.getDetectorID(), 1747163) - - const auto &peak6 = fracPeaks->getPeak(6); - TS_ASSERT_DELTA(peak6.getH(), -6.5, .0001) - TS_ASSERT_DELTA(peak6.getK(), 4.0, .0001) - TS_ASSERT_DELTA(peak6.getL(), -3.8, .0001) - TS_ASSERT_EQUALS(peak6.getDetectorID(), 1737894) + auto nPeaks = fracPeaks->getNumberPeaks(); + TS_ASSERT_EQUALS(117, nPeaks) + if (nPeaks > 0) { + const auto &peak0 = fracPeaks->getPeak(0); + TS_ASSERT_DELTA(peak0.getH(), -5.5, .0001) + TS_ASSERT_DELTA(peak0.getK(), 7.0, .0001) + TS_ASSERT_DELTA(peak0.getL(), -3.8, .0001) + 
TS_ASSERT_EQUALS(peak0.getDetectorID(), 1146353) + } + + if (nPeaks > 3) { + const auto &peak3 = fracPeaks->getPeak(3); + TS_ASSERT_DELTA(peak3.getH(), -5.5, .0001) + TS_ASSERT_DELTA(peak3.getK(), 3.0, .0001) + TS_ASSERT_DELTA(peak3.getL(), -2.8, .0001) + TS_ASSERT_EQUALS(peak3.getDetectorID(), 1747163) + } + + if (nPeaks > 6) { + const auto &peak6 = fracPeaks->getPeak(6); + TS_ASSERT_DELTA(peak6.getH(), -6.5, .0001) + TS_ASSERT_DELTA(peak6.getK(), 4.0, .0001) + TS_ASSERT_DELTA(peak6.getL(), -3.8, .0001) + TS_ASSERT_EQUALS(peak6.getDetectorID(), 1737894) + } } void test_exec_with_include_in_range_and_hit_detector() { @@ -229,6 +236,50 @@ public: TS_ASSERT_EQUALS(fracPeaks->getPeak(24).getDetectorID(), 3157981) } + void test_providing_modulation_vector_saves_properties_to_lattice() { + const auto fracPeaks = runPredictFractionalPeaks( + m_indexedPeaks, {{"ModVector1", "-0.5,0,0.5"}, + {"ModVector2", "0.0,0.5,0.5"}, + {"MaxOrder", "1"}, + {"CrossTerms", "0"}}); + + TS_ASSERT_EQUALS(124, fracPeaks->getNumberPeaks()) + + // check lattice + const auto &lattice = fracPeaks->sample().getOrientedLattice(); + TS_ASSERT_EQUALS(1, lattice.getMaxOrder()) + TS_ASSERT_EQUALS(false, lattice.getCrossTerm()) + const auto mod1 = lattice.getModVec(0); + TS_ASSERT_EQUALS(-0.5, mod1.X()) + TS_ASSERT_EQUALS(0.0, mod1.Y()) + TS_ASSERT_EQUALS(0.5, mod1.Z()) + const auto mod2 = lattice.getModVec(1); + TS_ASSERT_EQUALS(0.0, mod2.X()) + TS_ASSERT_EQUALS(0.5, mod2.Y()) + TS_ASSERT_EQUALS(0.5, mod2.Z()) + + // check a couple of peaks + const auto &peak0 = fracPeaks->getPeak(0); + TS_ASSERT_DELTA(peak0.getH(), -4.5, .0001) + TS_ASSERT_DELTA(peak0.getK(), 7.0, .0001) + TS_ASSERT_DELTA(peak0.getL(), -4.5, .0001) + TS_ASSERT_EQUALS(peak0.getDetectorID(), 1129591) + const auto mnp0 = peak0.getIntMNP(); + TS_ASSERT_DELTA(mnp0.X(), -1., 1e-08) + TS_ASSERT_DELTA(mnp0.Y(), 0., 1e-08) + TS_ASSERT_DELTA(mnp0.Z(), 0., 1e-08) + + const auto &peak34 = fracPeaks->getPeak(34); + TS_ASSERT_DELTA(peak34.getH(), 
-7, .0001) + TS_ASSERT_DELTA(peak34.getK(), 7.5, .0001) + TS_ASSERT_DELTA(peak34.getL(), -2.5, .0001) + TS_ASSERT_EQUALS(peak34.getDetectorID(), 1812163) + const auto mnp34 = peak34.getIntMNP(); + TS_ASSERT_DELTA(mnp34.X(), 0., 1e-08) + TS_ASSERT_DELTA(mnp34.Y(), 1., 1e-08) + TS_ASSERT_DELTA(mnp34.Z(), 0., 1e-08) + } + // ---------------- Failure tests ----------------------------- void test_empty_peaks_workspace_is_not_allowed() { PredictFractionalPeaks alg; @@ -253,6 +304,17 @@ public: doInvalidRangeTest("L"); } + void test_modulation_vector_requires_maxOrder_gt_0() { + PredictFractionalPeaks alg; + alg.initialize(); + alg.setProperty("Peaks", WorkspaceCreationHelper::createPeaksWorkspace(0)); + alg.setProperty("ModVector1", "0.5,0,0.5"); + + auto helpMsgs = alg.validateInputs(); + + TS_ASSERT(helpMsgs.find("MaxOrder") != helpMsgs.cend()) + } + private: void doInvalidRangeTest(const std::string &dimension) { PredictFractionalPeaks alg; diff --git a/Framework/Crystal/test/PredictPeaksTest.h b/Framework/Crystal/test/PredictPeaksTest.h index 1735df3fff00cc59a29529eadb1e6ce9c736458d..5db5d807c4b15307e07b87681ea9ad9594044890 100644 --- a/Framework/Crystal/test/PredictPeaksTest.h +++ b/Framework/Crystal/test/PredictPeaksTest.h @@ -244,15 +244,21 @@ public: void test_manual_U_and_gonio() { do_test_manual(22.5, 22.5); } void test_crystallography() { - Kernel::ConfigService::Instance().setString("Q.convention", - "Crystallography"); + using Kernel::ConfigService; + auto origQConv = ConfigService::Instance().getString("Q.convention"); + ConfigService::Instance().setString("Q.convention", "Crystallography"); do_test_exec("Primitive", 10, std::vector<V3D>(), -1); + ConfigService::Instance().setString("Q.convention", origQConv); } void test_edge() { do_test_exec("Primitive", 5, std::vector<V3D>(), 1, false, false, 10); } void test_exec_with_CalculateGoniometerForCW() { + using Kernel::ConfigService; + auto origQConv = ConfigService::Instance().getString("Q.convention"); + 
ConfigService::Instance().setString("Q.convention", "Crystallography"); + // Name of the output workspace. std::string outWSName("PredictPeaksTest_OutputWS"); @@ -296,6 +302,8 @@ public: // Remove workspace from the data service. AnalysisDataService::Instance().remove(outWSName); + + ConfigService::Instance().setString("Q.convention", origQConv); } }; diff --git a/Framework/Crystal/test/SCDCalibratePanelsTest.h b/Framework/Crystal/test/SCDCalibratePanelsTest.h index 872cd45c4209a77fa2b4ad54de11cbe003cef454..698224105715179985339ca28102f6e9e8a3f820 100644 --- a/Framework/Crystal/test/SCDCalibratePanelsTest.h +++ b/Framework/Crystal/test/SCDCalibratePanelsTest.h @@ -16,6 +16,7 @@ #include "MantidAPI/AnalysisDataService.h" #include "MantidCrystal/SCDCalibratePanels.h" +#include <boost/filesystem.hpp> #include <cxxtest/TestSuite.h> using namespace Mantid::API; @@ -57,7 +58,9 @@ public: alg->setProperty("alpha", 90.0); alg->setProperty("beta", 90.0); alg->setProperty("gamma", 120.0); - alg->setPropertyValue("DetCalFilename", "/tmp/topaz.detcal"); // deleteme + auto detCalTempPath = boost::filesystem::temp_directory_path(); + detCalTempPath /= "topaz.detcal"; + alg->setPropertyValue("DetCalFilename", detCalTempPath.string()); TS_ASSERT(alg->execute()); // verify the results @@ -82,6 +85,8 @@ public: AnalysisDataService::Instance().retrieveWS<ITableWorkspace>( "params_L1"); TS_ASSERT_DELTA(0.00529, resultsL1->cell<double>(2, 1), .01); + + remove(detCalTempPath.string().c_str()); } }; diff --git a/Framework/CurveFitting/test/Algorithms/VesuvioCalculateGammaBackgroundTest.h b/Framework/CurveFitting/test/Algorithms/VesuvioCalculateGammaBackgroundTest.h index 3a61f2724f55312b0c66be02b582a9299e9b2fc7..40ae80c6bcfcd7ca201a9060300faa4e0255fc0a 100644 --- a/Framework/CurveFitting/test/Algorithms/VesuvioCalculateGammaBackgroundTest.h +++ b/Framework/CurveFitting/test/Algorithms/VesuvioCalculateGammaBackgroundTest.h @@ -70,13 +70,13 @@ public: const auto &corrY(correctedWS->y(0)); 
TS_ASSERT_DELTA(corrY.front(), 0.0000012042, 1e-08); TS_ASSERT_DELTA(corrY[npts / 2], 0.1580361070, 1e-08); - TS_ASSERT_DELTA(corrY.back(), -0.0144493467, 1e-08); + TS_ASSERT_DELTA(corrY.back(), -0.0144492041, 1e-08); // Background Y values = 0.0 const auto &backY(backgroundWS->y(0)); TS_ASSERT_DELTA(backY.front(), -0.0000012042, 1e-08); TS_ASSERT_DELTA(backY[npts / 2], -0.0001317931, 1e-08); - TS_ASSERT_DELTA(backY.back(), 0.0144493467, 1e-08); + TS_ASSERT_DELTA(backY.back(), 0.0144492041, 1e-08); } void @@ -167,13 +167,13 @@ public: const auto &corrY(correctedWS->y(0)); TS_ASSERT_DELTA(corrY.front(), 0.0000012042, 1e-08); TS_ASSERT_DELTA(corrY[npts / 2], 0.1580361070, 1e-08); - TS_ASSERT_DELTA(corrY.back(), -0.0144493467, 1e-08); + TS_ASSERT_DELTA(corrY.back(), -0.0144492041, 1e-08); // Background Y values = 0.0 const auto &backY(backgroundWS->y(0)); TS_ASSERT_DELTA(backY.front(), -0.0000012042, 1e-08); TS_ASSERT_DELTA(backY[npts / 2], -0.0001317931, 1e-08); - TS_ASSERT_DELTA(backY.back(), 0.0144493467, 1e-08); + TS_ASSERT_DELTA(backY.back(), 0.0144492041, 1e-08); } //------------------------------------ Error cases diff --git a/Framework/DataHandling/src/SaveGSS.cpp b/Framework/DataHandling/src/SaveGSS.cpp index ff3121a8fa55bab15ac10a73554d25ba15ffded6..7cf51513c8b2956bb0170158c6fde10a4a4478fd 100644 --- a/Framework/DataHandling/src/SaveGSS.cpp +++ b/Framework/DataHandling/src/SaveGSS.cpp @@ -119,12 +119,13 @@ void writeBankHeader(std::stringstream &out, const std::string &bintype, //---------------------------------------------------------------------------------------------- // Initialise the algorithm void SaveGSS::init() { - declareProperty(std::make_unique<API::WorkspaceProperty<>>( + const std::vector<std::string> exts{".gsa", ".gss", ".gda", ".txt"}; + declareProperty(std::make_unique<API::WorkspaceProperty<MatrixWorkspace>>( "InputWorkspace", "", Kernel::Direction::Input), "The input workspace"); - 
declareProperty(std::make_unique<API::FileProperty>("Filename", "", - API::FileProperty::Save), + declareProperty(std::make_unique<API::FileProperty>( + "Filename", "", API::FileProperty::Save, exts), "The filename to use for the saved data"); declareProperty( @@ -715,7 +716,11 @@ std::map<std::string, std::string> SaveGSS::validateInputs() { std::map<std::string, std::string> result; API::MatrixWorkspace_const_sptr input_ws = getProperty("InputWorkspace"); - + if (!input_ws) { + result["InputWorkspace"] = + "The input workspace cannot be a GroupWorkspace."; + return result; + } // Check the number of histogram/spectra < 99 const auto nHist = static_cast<int>(input_ws->getNumberHistograms()); const bool split = getProperty("SplitFiles"); diff --git a/Framework/DataHandling/test/AppendGeometryToSNSNexusTest.h b/Framework/DataHandling/test/AppendGeometryToSNSNexusTest.h index fee0f869fedfaafbb8466993a1b4539c7982b7aa..83f9b5b216a99aedcc3b3410414b6d195d4e09a7 100644 --- a/Framework/DataHandling/test/AppendGeometryToSNSNexusTest.h +++ b/Framework/DataHandling/test/AppendGeometryToSNSNexusTest.h @@ -9,6 +9,8 @@ #include "MantidKernel/System.h" #include "MantidKernel/Timer.h" +#include <Poco/File.h> +#include <Poco/Path.h> #include <cxxtest/TestSuite.h> #include "MantidDataHandling/AppendGeometryToSNSNexus.h" @@ -17,6 +19,10 @@ using namespace Mantid; using namespace Mantid::DataHandling; using namespace Mantid::API; +namespace { +constexpr auto NXS_FILENAME = "HYS_11092_event.nxs"; +} + class AppendGeometryToSNSNexusTest : public CxxTest::TestSuite { public: // This pair of boilerplate methods prevent the suite being created statically @@ -43,12 +49,17 @@ public: TS_ASSERT(alg.isInitialized()) // TODO: Get a better test file. 
// Changed to use HYS_11088_event.nxs to test motors - TS_ASSERT_THROWS_NOTHING( - alg.setPropertyValue("Filename", "HYS_11092_event.nxs")); + TS_ASSERT_THROWS_NOTHING(alg.setPropertyValue("Filename", NXS_FILENAME)); TS_ASSERT_THROWS_NOTHING(alg.setProperty("MakeCopy", true)); TS_ASSERT_THROWS_NOTHING(alg.execute();); TS_ASSERT(alg.isExecuted()); + std::string fullpath(Poco::Path::temp() + NXS_FILENAME); + + if (Poco::File(fullpath).exists()) { + Poco::File(fullpath).remove(); + } + // Retrieve the workspace from data service. TODO: Change to your desired // type // Workspace_sptr ws; diff --git a/Framework/DataHandling/test/FindDetectorsParTest.h b/Framework/DataHandling/test/FindDetectorsParTest.h index 6104ecd863cf9b2204b468391a1442ac863f4526..1246d09dac6d8f2dfa8b4184abeefd95ac31287e 100644 --- a/Framework/DataHandling/test/FindDetectorsParTest.h +++ b/Framework/DataHandling/test/FindDetectorsParTest.h @@ -544,8 +544,8 @@ private: std::string azim_pattern("0,0,0,"); std::string pol_pattern("170.565,169.565,168.565,"); std::string sfp_pattern("1,1,1,"); - std::string polw_pattern("0.804071,0.804258,0.804442,"); - std::string azw_pattern("5.72472,5.72472,5.72472,"); + std::string polw_pattern("0.803981,0.804169,0.804354,"); + std::string azw_pattern("5.72481,5.72481,5.72481,"); std::array<std::stringstream, 5> bufs; for (int j = 0; j < 5; j++) { diff --git a/Framework/DataHandling/test/LoadNGEMTest.h b/Framework/DataHandling/test/LoadNGEMTest.h index c386efd37b39df7ae808d901c0e826c22f3b029f..b465a46c2e930a0654da439add019e17cce74373 100644 --- a/Framework/DataHandling/test/LoadNGEMTest.h +++ b/Framework/DataHandling/test/LoadNGEMTest.h @@ -136,7 +136,7 @@ public: TS_ASSERT_THROWS_NOTHING(alg.setProperty("MinEventsPerFrame", 20)); TS_ASSERT_THROWS_NOTHING(alg.setProperty("MaxEventsPerFrame", 10)); - TS_ASSERT_THROWS(alg.execute(), std::runtime_error); + TS_ASSERT_THROWS(alg.execute(), const std::runtime_error &); } void test_MinEventsPerFrame_removes_low_values() { 
diff --git a/Framework/DataHandling/test/LoadNexusProcessed2Test.h b/Framework/DataHandling/test/LoadNexusProcessed2Test.h index e061ed2fba0e39b2535363d0664b25bc539afc6f..50622570ef4419901644b8300f0b56d64c317b3d 100644 --- a/Framework/DataHandling/test/LoadNexusProcessed2Test.h +++ b/Framework/DataHandling/test/LoadNexusProcessed2Test.h @@ -178,7 +178,6 @@ public: using Mantid::SpectrumDefinition; using namespace Mantid::Indexing; FileResource fileInfo("test_spectra_miss_detectors.nxs"); - fileInfo.setDebugMode(true); auto instr = ComponentCreationHelper::createTestInstrumentRectangular( 2 /*numBanks*/, 10 /*numPixels*/); // 200 detectors in instrument diff --git a/Framework/DataHandling/test/SaveGSSTest.h b/Framework/DataHandling/test/SaveGSSTest.h index d555f6862480d7dbb91bc31d52d5a225264c75c7..33d0127fe126ae2090b0dc59ec28fec9260249f8 100644 --- a/Framework/DataHandling/test/SaveGSSTest.h +++ b/Framework/DataHandling/test/SaveGSSTest.h @@ -285,6 +285,9 @@ public: const std::string fileOnePath = outFilePath + "-0.gsas"; const std::string fileTwoPath = outFilePath + "-1.gsas"; + Poco::TemporaryFile::registerForDeletion(fileOnePath); + Poco::TemporaryFile::registerForDeletion(fileTwoPath); + TS_ASSERT(FileComparisonHelper::isEqualToReferenceFile( "SaveGSS-SplitRef-0.gsas", fileOnePath)); TS_ASSERT(FileComparisonHelper::isEqualToReferenceFile( diff --git a/Framework/DataHandling/test/SaveReflectometryAsciiTest.h b/Framework/DataHandling/test/SaveReflectometryAsciiTest.h index 3e3205f0a4ba031e427c6661579147ee873e261f..d0693246f208826abdbcc4977f484d162e210e54 100644 --- a/Framework/DataHandling/test/SaveReflectometryAsciiTest.h +++ b/Framework/DataHandling/test/SaveReflectometryAsciiTest.h @@ -90,6 +90,8 @@ public: TS_ASSERT_EQUALS(fullline, *(it++)); } TS_ASSERT(in.eof()) + in.close(); + Poco::File(filename).remove(); } void test_histogram_data() { @@ -127,6 +129,8 @@ public: TS_ASSERT_EQUALS(fullline, *(it++)) } TS_ASSERT(in.eof()) + in.close(); + 
Poco::File(filename).remove(); } void test_empty_workspace() { @@ -168,6 +172,8 @@ public: std::istreambuf_iterator<char>(), in.widen('\n')), 25) + in.close(); + Poco::File(filename).remove(); } void test_dx_values() { @@ -208,6 +214,8 @@ public: if (fullline.find(" : ") == std::string::npos) TS_ASSERT_EQUALS(fullline, *(it++)) } + in.close(); + Poco::File(filename).remove(); } void test_txt() { @@ -243,6 +251,9 @@ public: TS_ASSERT_EQUALS(fullline, *(it++)) } TS_ASSERT(in.eof()) + + in.close(); + Poco::File(filename).remove(); } void test_override_existing_file_txt() { @@ -287,6 +298,9 @@ public: TS_ASSERT_EQUALS(fullline, *(it++)); } TS_ASSERT(in.eof()) + + in.close(); + Poco::File(filename).remove(); } void test_more_than_nine_logs() { @@ -342,6 +356,9 @@ public: TS_ASSERT_EQUALS(line, "Number of file format : 40") std::getline(in, line); TS_ASSERT_EQUALS(line, "Number of data points : 2") + + in.close(); + Poco::File(filename).remove(); } void test_user_log() { @@ -403,6 +420,9 @@ public: TS_ASSERT_EQUALS(line, "Number of file format : 40") std::getline(in, line); TS_ASSERT_EQUALS(line, "Number of data points : 2") + + in.close(); + Poco::File(filename).remove(); } void test_user_log_overrides_fixed_log() { @@ -465,6 +485,9 @@ public: TS_ASSERT_EQUALS(line, "Number of file format : 40") std::getline(in, line); TS_ASSERT_EQUALS(line, "Number of data points : 2") + + in.close(); + Poco::File(filename).remove(); } void test_automatic_log_filling() { @@ -519,6 +542,9 @@ public: TS_ASSERT_EQUALS(line, "Number of file format : 40") std::getline(in, line); TS_ASSERT_EQUALS(line, "Number of data points : 2") + + in.close(); + Poco::File(filename).remove(); } void test_group_workspaces() { @@ -582,6 +608,11 @@ public: TS_ASSERT_EQUALS(fullline, *(it2++)); } TS_ASSERT(in2.eof()) + + in1.close(); + Poco::File(f1).remove(); + in2.close(); + Poco::File(f2).remove(); } void test_point_data_dat() { @@ -616,6 +647,9 @@ public: TS_ASSERT_EQUALS(fullline, *(it++)); } 
TS_ASSERT(in.eof()) + + in.close(); + Poco::File(filename).remove(); } void test_dx_values_with_header_custom() { @@ -658,6 +692,9 @@ public: if (fullline.find(" : ") == std::string::npos) TS_ASSERT_EQUALS(fullline, *(it++)) } + + in.close(); + Poco::File(filename).remove(); } void test_dx_values_no_header_custom() { @@ -694,6 +731,9 @@ public: while (std::getline(in, fullline)) { TS_ASSERT_EQUALS(fullline, *(it++)) } + + in.close(); + Poco::File(filename).remove(); } void test_no_header_no_resolution_separator_custom() { @@ -730,6 +770,8 @@ public: while (std::getline(in, fullline)) { TS_ASSERT_EQUALS(fullline, *(it++)) } + in.close(); + Poco::File(filename).remove(); } private: diff --git a/Framework/DataObjects/inc/MantidDataObjects/EventList.h b/Framework/DataObjects/inc/MantidDataObjects/EventList.h index 03646bc2326c50a0e18d9e15bd110d4407455a75..86ef07c9c0ca26bb201a91f0e4e080e024980890 100644 --- a/Framework/DataObjects/inc/MantidDataObjects/EventList.h +++ b/Framework/DataObjects/inc/MantidDataObjects/EventList.h @@ -243,6 +243,8 @@ public: void addPulsetime(const double seconds) override; + void addPulsetimes(const std::vector<double> &seconds) override; + void maskTof(const double tofMin, const double tofMax) override; void maskCondition(const std::vector<bool> &mask) override; @@ -454,6 +456,9 @@ private: template <class T> void addPulsetimeHelper(std::vector<T> &events, const double seconds); template <class T> + void addPulsetimesHelper(std::vector<T> &events, + const std::vector<double> &seconds); + template <class T> static std::size_t maskTofHelper(std::vector<T> &events, const double tofMin, const double tofMax); template <class T> diff --git a/Framework/DataObjects/src/EventList.cpp b/Framework/DataObjects/src/EventList.cpp index 5176fc224b1a0b320c2d3114e227481a2324a0ed..fa4a34c94d3ea7e4e7c9b2d87eea68ef7cc458e3 100644 --- a/Framework/DataObjects/src/EventList.cpp +++ b/Framework/DataObjects/src/EventList.cpp @@ -2632,6 +2632,23 @@ void 
EventList::addPulsetimeHelper(std::vector<T> &events, } } +/** Add an offset per event to the pulsetime (wall-clock time) of each event in + * the list. It is assumed that the vector sizes match. + * + * @param events :: reference to a vector of events to change. + * @param seconds :: The set of values to shift the pulsetime by, in seconds + */ +template <class T> +void EventList::addPulsetimesHelper(std::vector<T> &events, + const std::vector<double> &seconds) { + auto eventIterEnd{events.end()}; + auto secondsIter{seconds.cbegin()}; + for (auto eventIter = events.begin(); eventIter < eventIterEnd; + ++eventIter, ++secondsIter) { + eventIter->m_pulsetime += *secondsIter; + } +} + // -------------------------------------------------------------------------- /** Add an offset to the pulsetime (wall-clock time) of each event in the list. * @@ -2657,6 +2674,34 @@ void EventList::addPulsetime(const double seconds) { } } +// -------------------------------------------------------------------------- +/** Add an offset to the pulsetime (wall-clock time) of each event in the list. + * + * @param seconds :: A set of values to shift the pulsetime by, in seconds + */ +void EventList::addPulsetimes(const std::vector<double> &seconds) { + if (this->getNumberEvents() <= 0) + return; + if (this->getNumberEvents() != seconds.size()) { + throw std::runtime_error(""); + } + + // Convert the list + switch (eventType) { + case TOF: + this->addPulsetimesHelper(this->events, seconds); + break; + case WEIGHTED: + this->addPulsetimesHelper(this->weightedEvents, seconds); + break; + case WEIGHTED_NOTIME: + throw std::runtime_error("EventList::addPulsetime() called on an event " + "list with no pulse times. You must call this " + "algorithm BEFORE CompressEvents."); + break; + } +} + // -------------------------------------------------------------------------- /** Mask out events that have a tof between tofMin and tofMax (inclusively). * Events are removed from the list. 
diff --git a/Framework/DataObjects/test/EventListTest.h b/Framework/DataObjects/test/EventListTest.h index aea720f78c2cb72330b998cc0e6a668d874530fc..dc1876fb61b09746ae992531aef33dbd13840b79 100644 --- a/Framework/DataObjects/test/EventListTest.h +++ b/Framework/DataObjects/test/EventListTest.h @@ -1503,7 +1503,6 @@ public: } } - //----------------------------------------------------------------------------------------------- void test_addPulseTime_allTypes() { // Go through each possible EventType as the input for (int this_type = 0; this_type < 3; this_type++) { @@ -1531,6 +1530,45 @@ public: } } + void test_addPulseTimes_vector_throws_if_size_not_match_number_events() { + // Go through each possible EventType as the input + const std::vector<double> offsets = {1, 2, 3, 4, 5, 6}; + for (int this_type = 0; this_type < 3; this_type++) { + this->fake_uniform_time_data(); + el.switchTo(static_cast<EventType>(this_type)); + // Do convert + TS_ASSERT_THROWS(this->el.addPulsetimes(offsets), std::runtime_error); + } + } + + void test_addPulseTimes_vector_allTypes() { + // Go through each possible EventType as the input + for (int this_type = 0; this_type < 3; this_type++) { + this->fake_uniform_time_data(); + el.switchTo(static_cast<EventType>(this_type)); + const size_t old_num = this->el.getNumberEvents(); + std::vector<double> offsets(old_num, 123e-9); + // Do convert + if (static_cast<EventType>(this_type) == WEIGHTED_NOTIME) { + TS_ASSERT_THROWS_ANYTHING(this->el.addPulsetimes(offsets)) + } else { + this->el.addPulsetimes(offsets); + // Unchanged size + TS_ASSERT_EQUALS(old_num, this->el.getNumberEvents()); + // original times were 0, 1, etc. 
nansoeconds + TSM_ASSERT_EQUALS(this_type, + this->el.getEvent(0).pulseTime().totalNanoseconds(), + 123); + TSM_ASSERT_EQUALS(this_type, + this->el.getEvent(1).pulseTime().totalNanoseconds(), + 124); + TSM_ASSERT_EQUALS(this_type, + this->el.getEvent(2).pulseTime().totalNanoseconds(), + 125); + } + } + } + void test_sortByTimeAtSample_uniform_pulse_time() { // Go through each possible EventType (except the no-time one) as the input for (int this_type = 0; this_type < 3; this_type++) { diff --git a/Framework/Geometry/src/Surfaces/Plane.cpp b/Framework/Geometry/src/Surfaces/Plane.cpp index 8e65a5c6b7c009e7f513be88d1cd5dfa688760d5..2244fc064510d88f9b3b44ebb61226eab828d3b9 100644 --- a/Framework/Geometry/src/Surfaces/Plane.cpp +++ b/Framework/Geometry/src/Surfaces/Plane.cpp @@ -303,17 +303,23 @@ void Plane::write(std::ostream &OX) const { * @return The number of points of intersection */ int Plane::LineIntersectionWithPlane(V3D startpt, V3D endpt, V3D &output) { - double sprod = this->getNormal().scalar_prod(startpt - endpt); + double const sprod = this->getNormal().scalar_prod(startpt - endpt); if (sprod == 0) return 0; - double s1 = (NormV[0] * startpt[0] + NormV[1] * startpt[1] + - NormV[2] * startpt[2] - Dist) / - sprod; + double const projection = + NormV[0] * startpt[0] + NormV[1] * startpt[1] + NormV[2] * startpt[2]; + double s1 = (projection - Dist) / sprod; if (s1 < 0 || s1 > 1) return 0; - output[0] = startpt[0] + s1 * (endpt[0] - startpt[0]); - output[1] = startpt[1] + s1 * (endpt[1] - startpt[1]); - output[2] = startpt[2] + s1 * (endpt[2] - startpt[2]); + // The expressions below for resolving the point of intersection are + // resilient to the corner Dist << sprod. 
+ double const ratio = projection / sprod; + output[0] = ratio * endpt[0] + (1 - ratio) * startpt[0] - + ((endpt[0] - startpt[0]) / sprod) * Dist; + output[1] = ratio * endpt[1] + (1 - ratio) * startpt[1] - + ((endpt[1] - startpt[1]) / sprod) * Dist; + output[2] = ratio * endpt[2] + (1 - ratio) * startpt[2] - + ((endpt[2] - startpt[2]) / sprod) * Dist; return 1; } diff --git a/Framework/Geometry/test/CSGObjectTest.h b/Framework/Geometry/test/CSGObjectTest.h index 85ea785aeeaeeb5a22da66b42d185b0915a43bf0..4cef5555940d1cbd401ab9af6ff5a5f0a174396b 100644 --- a/Framework/Geometry/test/CSGObjectTest.h +++ b/Framework/Geometry/test/CSGObjectTest.h @@ -439,7 +439,7 @@ public: dir.normalize(); Track track(V3D(-10, 0, 0), dir); - TS_ASSERT_THROWS(geom_obj->distance(track), std::runtime_error) + TS_ASSERT_THROWS(geom_obj->distance(track), const std::runtime_error &) } void testTrackTwoIsolatedCubes() diff --git a/Framework/Geometry/test/ComponentInfoTest.h b/Framework/Geometry/test/ComponentInfoTest.h index 0dce9322e192fbbd18540cb43d1216f1ae34640d..afe6ff39a54f0eb01cd49944c98e4ba8cb366ebb 100644 --- a/Framework/Geometry/test/ComponentInfoTest.h +++ b/Framework/Geometry/test/ComponentInfoTest.h @@ -332,6 +332,53 @@ public: TS_ASSERT(boundingBox.isNull()); } + // Test calculation of the bounding box for a milimiter-sized + // capped-cylinder detector pixel + void test_boundingBox_single_component_capped_cylinder() { + + const double radius = 0.00275; + const double height = 0.0042; + const Mantid::Kernel::V3D baseCentre(0., 0., 0.); + const Mantid::Kernel::V3D axis(0, 1, 0); + const std::string id("cy-1"); + + Eigen::Vector3d position{1., 1., 1.}; + auto internalInfo = makeSingleBeamlineComponentInfo(position); + Mantid::Geometry::ObjComponent comp1( + "component1", ComponentCreationHelper::createCappedCylinder( + radius, height, baseCentre, axis, id)); + + auto componentIds = + boost::make_shared<std::vector<Mantid::Geometry::ComponentID>>( + 
std::vector<Mantid::Geometry::ComponentID>{&comp1}); + + auto shapes = boost::make_shared< + std::vector<boost::shared_ptr<const Geometry::IObject>>>(); + shapes->push_back(ComponentCreationHelper::createCappedCylinder( + radius, height, baseCentre, axis, id)); + + ComponentInfo componentInfo(std::move(internalInfo), componentIds, + makeComponentIDMap(componentIds), shapes); + + BoundingBox boundingBox = componentInfo.boundingBox(0 /*componentIndex*/); + + TS_ASSERT( + (boundingBox.width() - (Kernel::V3D{2 * radius, height, 2 * radius})) + .norm() < 1e-9); + TS_ASSERT( + (boundingBox.minPoint() - + (Kernel::V3D{position[0] - radius, position[1], position[2] - radius})) + .norm() < 1e-9); + TS_ASSERT((boundingBox.maxPoint() - + (Kernel::V3D{position[0] + radius, position[1] + height, + position[2] + radius})) + .norm() < 1e-9); + // Nullify shape and retest BoundingBox + shapes->at(0) = boost::shared_ptr<const Geometry::IObject>(nullptr); + boundingBox = componentInfo.boundingBox(0); + TS_ASSERT(boundingBox.isNull()); + } + void test_boundingBox_complex() { const V3D sourcePos(-1, 0, 0); const V3D samplePos(0, 0, 0); diff --git a/Framework/Geometry/test/InstrumentDefinitionParserTest.h b/Framework/Geometry/test/InstrumentDefinitionParserTest.h index bdf210f2db1b1f1ed319a820542bdcec6b67d351..d1a52c7c45141ee7fdc1b2b8a02915be78c19b08 100644 --- a/Framework/Geometry/test/InstrumentDefinitionParserTest.h +++ b/Framework/Geometry/test/InstrumentDefinitionParserTest.h @@ -815,7 +815,7 @@ public: // generated by the InstrumentDefinitionParser. 
Poco::Path path( Mantid::Kernel::ConfigService::Instance().getTempDir().c_str()); - path.append(instrumentEnv._instName + ".vtp"); + path.append(parser.getMangledName() + ".vtp"); remove(path.toString().c_str()); } diff --git a/Framework/Geometry/test/MeshObjectTest.h b/Framework/Geometry/test/MeshObjectTest.h index d5b5b782a8b08d360bcb193de4b7eede1d3a0fcd..160da94b438fcfbeffbf77880269ae09f0008d4b 100644 --- a/Framework/Geometry/test/MeshObjectTest.h +++ b/Framework/Geometry/test/MeshObjectTest.h @@ -399,7 +399,7 @@ public: dir.normalize(); Track track(V3D(-10, 0, 0), dir); - TS_ASSERT_THROWS(geom_obj->distance(track), std::runtime_error) + TS_ASSERT_THROWS(geom_obj->distance(track), const std::runtime_error &) } void testTrackTwoIsolatedCubes() diff --git a/Framework/Kernel/inc/MantidKernel/IPropertyManager.h b/Framework/Kernel/inc/MantidKernel/IPropertyManager.h index f34774d9590cde19dcb0fa12d5f2af341d62472b..ef28254cb5b1ed61cc0f9d8b375b4f96451ff71c 100644 --- a/Framework/Kernel/inc/MantidKernel/IPropertyManager.h +++ b/Framework/Kernel/inc/MantidKernel/IPropertyManager.h @@ -67,6 +67,132 @@ public: virtual void declareOrReplaceProperty(std::unique_ptr<Property> p, const std::string &doc = "") = 0; + /** Add a property of the template type to the list of managed properties + * @param name :: The name to assign to the property + * @param value :: The initial value to assign to the property + * @param validator :: Pointer to the (optional) validator. 
+ * @param doc :: The (optional) documentation string + * @param direction :: The (optional) direction of the property, in, out or + * inout + * @throw Exception::ExistsError if a property with the given name already + * exists + * @throw std::invalid_argument if the name argument is empty + */ + template <typename T> + void declareProperty( + const std::string &name, T value, + IValidator_sptr validator = boost::make_shared<NullValidator>(), + const std::string &doc = "", + const unsigned int direction = Direction::Input) { + std::unique_ptr<PropertyWithValue<T>> p = + std::make_unique<PropertyWithValue<T>>(name, value, validator, + direction); + declareProperty(std::move(p), doc); + } + + /** Add a property to the list of managed properties with no validator + * @param name :: The name to assign to the property + * @param value :: The initial value to assign to the property + * @param doc :: The documentation string + * @param direction :: The (optional) direction of the property, in + * (default), out or inout + * @throw Exception::ExistsError if a property with the given name already + * exists + * @throw std::invalid_argument if the name argument is empty + */ + template <typename T> + void declareProperty(const std::string &name, T value, const std::string &doc, + const unsigned int direction = Direction::Input) { + std::unique_ptr<PropertyWithValue<T>> p = + std::make_unique<PropertyWithValue<T>>( + name, value, boost::make_shared<NullValidator>(), direction); + declareProperty(std::move(p), doc); + } + + /** Add a property of the template type to the list of managed properties + * @param name :: The name to assign to the property + * @param value :: The initial value to assign to the property + * @param direction :: The direction of the property, in, out or inout + * @throw Exception::ExistsError if a property with the given name already + * exists + * @throw std::invalid_argument if the name argument is empty + */ + template <typename T> + void 
declareProperty(const std::string &name, T value, + const unsigned int direction) { + std::unique_ptr<PropertyWithValue<T>> p = + std::make_unique<PropertyWithValue<T>>( + name, value, boost::make_shared<NullValidator>(), direction); + declareProperty(std::move(p)); + } + + /** Specialised version of declareProperty template method to prevent the + * creation of a + * PropertyWithValue of type const char* if an argument in quotes is passed + * (it will be + * converted to a string). The validator, if provided, needs to be a string + * validator. + * @param name :: The name to assign to the property + * @param value :: The initial value to assign to the property + * @param validator :: Pointer to the (optional) validator. Ownership will be + * taken over. + * @param doc :: The (optional) documentation string + * @param direction :: The (optional) direction of the property, in, out or + * inout + * @throw Exception::ExistsError if a property with the given name already + * exists + * @throw std::invalid_argument if the name argument is empty + */ + void declareProperty( + const std::string &name, const char *value, + IValidator_sptr validator = boost::make_shared<NullValidator>(), + const std::string &doc = std::string(), + const unsigned int direction = Direction::Input) { + // Simply call templated method, converting character array to a string + declareProperty(name, std::string(value), std::move(validator), doc, + direction); + } + + /** Specialised version of declareProperty template method to prevent the + * creation of a + * PropertyWithValue of type const char* if an argument in quotes is passed + * (it will be + * converted to a string). The validator, if provided, needs to be a string + * validator. + * @param name :: The name to assign to the property + * @param value :: The initial value to assign to the property + * @param doc :: The (optional) documentation string + * @param validator :: Pointer to the (optional) validator. 
Ownership will be + * taken over. + * @param direction :: The (optional) direction of the property, in, out or + * inout + * @throw Exception::ExistsError if a property with the given name already + * exists + * @throw std::invalid_argument if the name argument is empty + */ + void declareProperty( + const std::string &name, const char *value, const std::string &doc, + IValidator_sptr validator = boost::make_shared<NullValidator>(), + const unsigned int direction = Direction::Input) { + // Simply call templated method, converting character array to a string + declareProperty(name, std::string(value), std::move(validator), doc, + direction); + } + + /** Add a property of string type to the list of managed properties + * @param name :: The name to assign to the property + * @param value :: The initial value to assign to the property + * @param direction :: The direction of the property, in, out or inout + * @throw Exception::ExistsError if a property with the given name already + * exists + * @throw std::invalid_argument if the name argument is empty + */ + void declareProperty(const std::string &name, const char *value, + const unsigned int direction) { + declareProperty(name, std::string(value), + boost::make_shared<NullValidator>(), "", direction); + } + /// Removes the property from management virtual void removeProperty(const std::string &name, const bool delproperty = true) = 0; @@ -226,132 +352,6 @@ public: virtual void filterByProperty(const TimeSeriesProperty<bool> & /*filte*/) = 0; protected: - /** Add a property of the template type to the list of managed properties - * @param name :: The name to assign to the property - * @param value :: The initial value to assign to the property - * @param validator :: Pointer to the (optional) validator. 
- * @param doc :: The (optional) documentation string - * @param direction :: The (optional) direction of the property, in, out or - * inout - * @throw Exception::ExistsError if a property with the given name already - * exists - * @throw std::invalid_argument if the name argument is empty - */ - template <typename T> - void declareProperty( - const std::string &name, T value, - IValidator_sptr validator = boost::make_shared<NullValidator>(), - const std::string &doc = "", - const unsigned int direction = Direction::Input) { - std::unique_ptr<PropertyWithValue<T>> p = - std::make_unique<PropertyWithValue<T>>(name, value, validator, - direction); - declareProperty(std::move(p), doc); - } - - /** Add a property to the list of managed properties with no validator - * @param name :: The name to assign to the property - * @param value :: The initial value to assign to the property - * @param doc :: The documentation string - * @param direction :: The (optional) direction of the property, in - * (default), out or inout - * @throw Exception::ExistsError if a property with the given name already - * exists - * @throw std::invalid_argument if the name argument is empty - */ - template <typename T> - void declareProperty(const std::string &name, T value, const std::string &doc, - const unsigned int direction = Direction::Input) { - std::unique_ptr<PropertyWithValue<T>> p = - std::make_unique<PropertyWithValue<T>>( - name, value, boost::make_shared<NullValidator>(), direction); - declareProperty(std::move(p), doc); - } - - /** Add a property of the template type to the list of managed properties - * @param name :: The name to assign to the property - * @param value :: The initial value to assign to the property - * @param direction :: The direction of the property, in, out or inout - * @throw Exception::ExistsError if a property with the given name already - * exists - * @throw std::invalid_argument if the name argument is empty - */ - template <typename T> - void 
declareProperty(const std::string &name, T value, - const unsigned int direction) { - std::unique_ptr<PropertyWithValue<T>> p = - std::make_unique<PropertyWithValue<T>>( - name, value, boost::make_shared<NullValidator>(), direction); - declareProperty(std::move(p)); - } - - /** Specialised version of declareProperty template method to prevent the - * creation of a - * PropertyWithValue of type const char* if an argument in quotes is passed - * (it will be - * converted to a string). The validator, if provided, needs to be a string - * validator. - * @param name :: The name to assign to the property - * @param value :: The initial value to assign to the property - * @param validator :: Pointer to the (optional) validator. Ownership will be - * taken over. - * @param doc :: The (optional) documentation string - * @param direction :: The (optional) direction of the property, in, out or - * inout - * @throw Exception::ExistsError if a property with the given name already - * exists - * @throw std::invalid_argument if the name argument is empty - */ - void declareProperty( - const std::string &name, const char *value, - IValidator_sptr validator = boost::make_shared<NullValidator>(), - const std::string &doc = std::string(), - const unsigned int direction = Direction::Input) { - // Simply call templated method, converting character array to a string - declareProperty(name, std::string(value), std::move(validator), doc, - direction); - } - - /** Specialised version of declareProperty template method to prevent the - * creation of a - * PropertyWithValue of type const char* if an argument in quotes is passed - * (it will be - * converted to a string). The validator, if provided, needs to be a string - * validator. - * @param name :: The name to assign to the property - * @param value :: The initial value to assign to the property - * @param doc :: The (optional) documentation string - * @param validator :: Pointer to the (optional) validator. 
Ownership will be - * taken over. - * @param direction :: The (optional) direction of the property, in, out or - * inout - * @throw Exception::ExistsError if a property with the given name already - * exists - * @throw std::invalid_argument if the name argument is empty - */ - void declareProperty( - const std::string &name, const char *value, const std::string &doc, - IValidator_sptr validator = boost::make_shared<NullValidator>(), - const unsigned int direction = Direction::Input) { - // Simply call templated method, converting character array to a string - declareProperty(name, std::string(value), std::move(validator), doc, - direction); - } - - /** Add a property of string type to the list of managed properties - * @param name :: The name to assign to the property - * @param value :: The initial value to assign to the property - * @param direction :: The direction of the property, in, out or inout - * @throw Exception::ExistsError if a property with the given name already - * exists - * @throw std::invalid_argument if the name argument is empty - */ - void declareProperty(const std::string &name, const char *value, - const unsigned int direction) { - declareProperty(name, std::string(value), - boost::make_shared<NullValidator>(), "", direction); - } - /// Get a property by an index virtual Property *getPointerToPropertyOrdinal(const int &index) const = 0; diff --git a/Framework/Kernel/inc/MantidKernel/MultiThreaded.h b/Framework/Kernel/inc/MantidKernel/MultiThreaded.h index 4111a097b632f9701363cf7c9405e3a2f97fb7a7..a05a85eabb9ef81f13243fb160a18dc6db2a6fba 100644 --- a/Framework/Kernel/inc/MantidKernel/MultiThreaded.h +++ b/Framework/Kernel/inc/MantidKernel/MultiThreaded.h @@ -85,8 +85,12 @@ void AtomicOp(std::atomic<T> &f, T d, BinaryOp op) { // GCC #ifdef _MSC_VER #define PRAGMA __pragma +#define PARALLEL_SET_CONFIG_THREADS #else //_MSC_VER #define PRAGMA(x) _Pragma(#x) +#define PARALLEL_SET_CONFIG_THREADS \ + setMaxCoresToConfig(); \ + 
PARALLEL_SET_DYNAMIC(false); #endif //_MSC_VER /** Begins a block to skip processing is the algorithm has been interupted @@ -136,8 +140,7 @@ void AtomicOp(std::atomic<T> &f, T d, BinaryOp op) { * code to be executed in parallel */ #define PARALLEL_FOR_IF(condition) \ - setMaxCoresToConfig(); \ - PARALLEL_SET_DYNAMIC(false); \ + PARALLEL_SET_CONFIG_THREADS \ PRAGMA(omp parallel for if (condition) ) /** Includes code to add OpenMP commands to run the next for loop in parallel. @@ -145,8 +148,7 @@ void AtomicOp(std::atomic<T> &f, T d, BinaryOp op) { * and therefore should not be used in any loops that access workspaces. */ #define PARALLEL_FOR_NO_WSP_CHECK() \ - setMaxCoresToConfig(); \ - PARALLEL_SET_DYNAMIC(false); \ + PARALLEL_SET_CONFIG_THREADS \ PRAGMA(omp parallel for) /** Includes code to add OpenMP commands to run the next for loop in parallel. @@ -155,13 +157,11 @@ void AtomicOp(std::atomic<T> &f, T d, BinaryOp op) { * and therefore should not be used in any loops that access workspace. 
*/ #define PARALLEL_FOR_NOWS_CHECK_FIRSTPRIVATE(variable) \ - setMaxCoresToConfig(); \ - PARALLEL_SET_DYNAMIC(false); \ + PARALLEL_SET_CONFIG_THREADS \ PRAGMA(omp parallel for firstprivate(variable) ) #define PARALLEL_FOR_NO_WSP_CHECK_FIRSTPRIVATE2(variable1, variable2) \ - setMaxCoresToConfig(); \ - PARALLEL_SET_DYNAMIC(false); \ + PARALLEL_SET_CONFIG_THREADS \ PRAGMA(omp parallel for firstprivate(variable1, variable2) ) /** Ensures that the next execution line or block is only executed if diff --git a/Framework/Kernel/inc/MantidKernel/Unit.h b/Framework/Kernel/inc/MantidKernel/Unit.h index 0ee9e5ea615a2a89322f1cd24e535d57d910341b..a5f10d184ecbf4e7a58aa4fcc5f0537676143fb8 100644 --- a/Framework/Kernel/inc/MantidKernel/Unit.h +++ b/Framework/Kernel/inc/MantidKernel/Unit.h @@ -55,6 +55,10 @@ public: /// @return The unit label virtual const UnitLabel label() const = 0; + /// Returns if the unit can be used in conversions + /// @return true if the unit can be used in unit conversions + virtual bool isConvertible() const { return true; } + // Equality operators based on the value returned by unitID(); bool operator==(const Unit &u) const; bool operator!=(const Unit &u) const; @@ -189,9 +193,6 @@ protected: void addConversion(std::string to, const double &factor, const double &power = 1.0) const; - /// Removes all registered 'quick conversions' - void clearConversions() const; - /// The unit values have been initialized bool initialized; /// l1 :: The source-sample distance (in metres) @@ -244,6 +245,7 @@ public: const std::string caption() const override { return ""; } const UnitLabel label() const override; + bool isConvertible() const override { return false; } double singleToTOF(const double x) const override; double singleFromTOF(const double tof) const override; void init() override; @@ -595,6 +597,7 @@ public: const std::string caption() const override { return "t"; } const UnitLabel label() const override; + bool isConvertible() const override { return false; } 
double singleToTOF(const double x) const override; double singleFromTOF(const double tof) const override; double conversionTOFMax() const override; diff --git a/Framework/Kernel/inc/MantidKernel/UnitFactory.h b/Framework/Kernel/inc/MantidKernel/UnitFactory.h index f523bb06832e8416e992997ed95da8106a4047ff..553bd44d0c58dff0419853dfa84768fce2b3024f 100644 --- a/Framework/Kernel/inc/MantidKernel/UnitFactory.h +++ b/Framework/Kernel/inc/MantidKernel/UnitFactory.h @@ -18,9 +18,9 @@ #define DECLARE_UNIT(classname) \ namespace { \ Mantid::Kernel::RegistrationHelper \ - register_alg_##classname(((Mantid::Kernel::UnitFactory::Instance() \ - .subscribe<classname>(#classname)), \ - 0)); \ + register_unit_##classname(((Mantid::Kernel::UnitFactory::Instance() \ + .subscribe<classname>(#classname)), \ + 0)); \ } \ const std::string Mantid::Kernel::Units::classname::unitID() const { \ return #classname; \ @@ -32,15 +32,11 @@ #include "MantidKernel/DllConfig.h" #include "MantidKernel/DynamicFactory.h" #include "MantidKernel/SingletonHolder.h" +#include "MantidKernel/Unit.h" namespace Mantid { namespace Kernel { -//---------------------------------------------------------------------- -// Forward declaration -//---------------------------------------------------------------------- -class Unit; - /** Creates instances of concrete units. The factory is a singleton that hands out shared pointers to the base Unit class. 
@@ -58,6 +54,21 @@ public: UnitFactoryImpl(const UnitFactoryImpl &) = delete; UnitFactoryImpl &operator=(const UnitFactoryImpl &) = delete; + /// Returns the names of the convertible units in the factory + /// @return A string vector of keys + const std::vector<std::string> getConvertibleUnits() const { + std::vector<std::string> convertibleUnits; + + for (const auto &unitKey : getKeys()) { + const auto unit_sptr = create(unitKey); + if (unit_sptr->isConvertible()) { + convertibleUnits.emplace_back(unitKey); + } + } + + return convertibleUnits; + } + private: friend struct CreateUsingNew<UnitFactoryImpl>; diff --git a/Framework/Kernel/src/ConfigService.cpp b/Framework/Kernel/src/ConfigService.cpp index 3e152df83f0885447c13c52838079df5542f60e3..6da6bb138987257a3c225278c92fd2c8ccb24266 100644 --- a/Framework/Kernel/src/ConfigService.cpp +++ b/Framework/Kernel/src/ConfigService.cpp @@ -1622,9 +1622,8 @@ void ConfigServiceImpl::appendDataSearchDir(const std::string &path) { } if (!isInDataSearchList(dirPath.toString())) { std::string newSearchString; - std::vector<std::string>::const_iterator it = m_DataSearchDirs.begin(); - for (; it != m_DataSearchDirs.end(); ++it) { - newSearchString.append(*it); + for (const std::string &it : m_DataSearchDirs) { + newSearchString.append(it); newSearchString.append(";"); } newSearchString.append(path); @@ -1935,11 +1934,9 @@ const std::vector<FacilityInfo *> ConfigServiceImpl::getFacilities() const { */ const std::vector<std::string> ConfigServiceImpl::getFacilityNames() const { auto names = std::vector<std::string>(m_facilities.size()); - auto itFacilities = m_facilities.begin(); - auto itNames = names.begin(); - for (; itFacilities != m_facilities.end(); ++itFacilities, ++itNames) { - *itNames = (**itFacilities).name(); - } + std::transform(m_facilities.cbegin(), m_facilities.cend(), names.begin(), + [](const FacilityInfo *facility) { return facility->name(); }); + return names; } @@ -1982,22 +1979,28 @@ 
ConfigServiceImpl::getFacility(const std::string &facilityName) const { * @throw NotFoundException if the facility is not found */ void ConfigServiceImpl::setFacility(const std::string &facilityName) { - bool found = false; - // Look through the facilities for a matching one. - std::vector<FacilityInfo *>::const_iterator it = m_facilities.begin(); - for (; it != m_facilities.end(); ++it) { - if ((**it).name() == facilityName) { - // Found the facility - found = true; - // So it's safe to set it as our default - setString("default.facility", facilityName); - } - } - if (!found) { + const FacilityInfo *foundFacility = nullptr; + + try { + // Get facility looks up by string - so re-use that to check if the facility + // is known + foundFacility = &getFacility(facilityName); + } catch (const Exception::NotFoundError &) { g_log.error("Failed to set default facility to be " + facilityName + ". Facility not found"); - throw Exception::NotFoundError("Facilities", facilityName); + throw; } + assert(foundFacility); + setString("default.facility", facilityName); + + const auto &associatedInsts = foundFacility->instruments(); + if (associatedInsts.empty()) { + throw std::invalid_argument( + "The selected facility has no instruments associated with it"); + } + + // Update the default instrument to be one from this facility + setString("default.instrument", associatedInsts[0].name()); } /** Add an observer to a notification diff --git a/Framework/Kernel/src/Unit.cpp b/Framework/Kernel/src/Unit.cpp index fddcf6f76696dc6f3b17aac1fc7aac207a3be034..4aef9211ee611000128b412bc31725c274f0097b 100644 --- a/Framework/Kernel/src/Unit.cpp +++ b/Framework/Kernel/src/Unit.cpp @@ -103,12 +103,6 @@ void Unit::addConversion(std::string to, const double &factor, s_conversionFactors[unitID()][to] = std::make_pair(factor, power); } -//--------------------------------------------------------------------------------------- -/** Removes all registered 'quick conversions' from the unit class on which 
this - * method is called. - */ -void Unit::clearConversions() const { s_conversionFactors.clear(); } - //--------------------------------------------------------------------------------------- /** Initialize the unit to perform conversion using singleToTof() and *singleFromTof() @@ -1068,7 +1062,7 @@ DECLARE_UNIT(SpinEchoLength) const UnitLabel SpinEchoLength::label() const { return Symbol::Nanometre; } -SpinEchoLength::SpinEchoLength() : Wavelength() { clearConversions(); } +SpinEchoLength::SpinEchoLength() : Wavelength() {} void SpinEchoLength::init() { // Efixed must be set to something @@ -1119,7 +1113,7 @@ DECLARE_UNIT(SpinEchoTime) const UnitLabel SpinEchoTime::label() const { return Symbol::Nanosecond; } -SpinEchoTime::SpinEchoTime() : Wavelength() { clearConversions(); } +SpinEchoTime::SpinEchoTime() : Wavelength() {} void SpinEchoTime::init() { // Efixed must be set to something diff --git a/Framework/Kernel/test/ConfigServiceTest.h b/Framework/Kernel/test/ConfigServiceTest.h index cd1d15b008858095b16cf6171295a6531df96ce7..a4181d9037a1e1ffb80552807328560f2fa1566b 100644 --- a/Framework/Kernel/test/ConfigServiceTest.h +++ b/Framework/Kernel/test/ConfigServiceTest.h @@ -148,6 +148,22 @@ public: TS_ASSERT_EQUALS(fac1.name(), "SNS"); } + void testChangingDefaultFacilityChangesInst() { + // When changing default facility using the setFacility method, we should + // also change the instrument to the default so we don't get weird + // inst/facility combinations + auto &config = ConfigService::Instance(); + + config.setFacility("ISIS"); + const auto isisFirstInst = config.getString("default.instrument"); + TS_ASSERT(!isisFirstInst.empty()); + + config.setFacility("SNS"); + const auto snsFirstInst = config.getString("default.instrument"); + TS_ASSERT(!snsFirstInst.empty()); + TS_ASSERT_DIFFERS(snsFirstInst, isisFirstInst); + } + void testFacilityList() { std::vector<FacilityInfo *> facilities = ConfigService::Instance().getFacilities(); diff --git 
a/Framework/Kernel/test/UnitFactoryTest.h b/Framework/Kernel/test/UnitFactoryTest.h index 3e2838c9f3ebf8ae71b2391f5b025065d0ef8e97..961c398660cb4e71cf6f9fb8a0acb69e8eef8f27 100644 --- a/Framework/Kernel/test/UnitFactoryTest.h +++ b/Framework/Kernel/test/UnitFactoryTest.h @@ -29,6 +29,22 @@ public: TS_ASSERT_THROWS(UnitFactory::Instance().create("_NOT_A_REAL_UNIT"), const Exception::NotFoundError &); } + + void test_getKeys_Includes_Label_and_TOF() { + auto keys = UnitFactory::Instance().getKeys(); + TSM_ASSERT("Cannot find Label in the keys of the unit factory", + std::find(keys.begin(), keys.end(), "Label") != keys.end()); + TSM_ASSERT("Cannot find TOF in the keys of the unit factory", + std::find(keys.begin(), keys.end(), "TOF") != keys.end()); + } + + void test_getConvertibleUnits_Includes_TOF_but_not_label() { + auto keys = UnitFactory::Instance().getConvertibleUnits(); + TSM_ASSERT("Can find Label in the ConvertibleUnits of the unit factory", + std::find(keys.begin(), keys.end(), "Label") == keys.end()); + TSM_ASSERT("Cannot find TOF in the ConvertibleUnits of the unit factory", + std::find(keys.begin(), keys.end(), "TOF") != keys.end()); + } }; #endif /*UNITFACTORYTEST_H_*/ diff --git a/Framework/MDAlgorithms/CMakeLists.txt b/Framework/MDAlgorithms/CMakeLists.txt index d762e82c830a070ad07ef5bd981398604c5d1a61..074dc265c7bbbec3a472f97eaaba12f07adb1ef4 100644 --- a/Framework/MDAlgorithms/CMakeLists.txt +++ b/Framework/MDAlgorithms/CMakeLists.txt @@ -33,6 +33,7 @@ set(SRC_FILES src/ConvertToMDMinMaxLocal.cpp src/ConvertToMDParent.cpp src/ConvertToReflectometryQ.cpp + src/ConvertHFIRSCDtoMDE.cpp src/CreateMD.cpp src/CreateMDHistoWorkspace.cpp src/CreateMDWorkspace.cpp @@ -149,6 +150,7 @@ set( inc/MantidMDAlgorithms/ConvertToMDMinMaxLocal.h inc/MantidMDAlgorithms/ConvertToMDParent.h inc/MantidMDAlgorithms/ConvertToReflectometryQ.h + inc/MantidMDAlgorithms/ConvertHFIRSCDtoMDE.h inc/MantidMDAlgorithms/CreateMD.h inc/MantidMDAlgorithms/CreateMDHistoWorkspace.h 
inc/MantidMDAlgorithms/CreateMDWorkspace.h @@ -267,6 +269,7 @@ set(TEST_FILES ConvertToMDTest.h ConvertToQ3DdETest.h ConvertToReflectometryQTest.h + ConvertHFIRSCDtoMDETest.h CreateMDHistoWorkspaceTest.h CreateMDTest.h CreateMDWorkspaceTest.h diff --git a/Framework/MDAlgorithms/inc/MantidMDAlgorithms/ConvertHFIRSCDtoMDE.h b/Framework/MDAlgorithms/inc/MantidMDAlgorithms/ConvertHFIRSCDtoMDE.h new file mode 100644 index 0000000000000000000000000000000000000000..1acd2681c87dd49c69caee6830ba055cbafe5a83 --- /dev/null +++ b/Framework/MDAlgorithms/inc/MantidMDAlgorithms/ConvertHFIRSCDtoMDE.h @@ -0,0 +1,38 @@ +// Mantid Repository : https://github.com/mantidproject/mantid +// +// Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI, +// NScD Oak Ridge National Laboratory, European Spallation Source +// & Institut Laue - Langevin +// SPDX - License - Identifier: GPL - 3.0 + +#ifndef MANTID_MDALGORITHMS_CONVERTHFIRSCDTOMDE_H_ +#define MANTID_MDALGORITHMS_CONVERTHFIRSCDTOMDE_H_ + +#include "MantidMDAlgorithms/BoxControllerSettingsAlgorithm.h" +#include "MantidMDAlgorithms/DllConfig.h" + +namespace Mantid { +namespace MDAlgorithms { + +/** ConvertHFIRSCDtoMDE : Convert from the detector vs scan index MDHistoWorkspace into a MDEventWorkspace with units in Q_sample. + */ +class MANTID_MDALGORITHMS_DLL ConvertHFIRSCDtoMDE + : public BoxControllerSettingsAlgorithm { +public: + const std::string name() const override; + int version() const override; + const std::vector<std::string> seeAlso() const override { + return {"ConvertWANDSCDtoQ", "LoadWANDSCD"}; + } + const std::string category() const override; + const std::string summary() const override; + std::map<std::string, std::string> validateInputs() override; + +private: + void init() override; + void exec() override; +}; + +} // namespace MDAlgorithms +} // namespace Mantid + +#endif /* MANTID_MDALGORITHMS_CONVERTHFIRSCDTOMDE_H_ */ diff --git a/Framework/MDAlgorithms/src/ConvertHFIRSCDtoMDE.cpp b/Framework/MDAlgorithms/src/ConvertHFIRSCDtoMDE.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..ffed865f6bc9f1562017647179e7660285fab410 --- /dev/null +++ b/Framework/MDAlgorithms/src/ConvertHFIRSCDtoMDE.cpp @@ -0,0 +1,305 @@ +// Mantid Repository : https://github.com/mantidproject/mantid +// +// Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI, +// NScD Oak Ridge National Laboratory, European Spallation Source +// & Institut Laue - Langevin +// SPDX - License - Identifier: GPL - 3.0 + + +#include "MantidMDAlgorithms/ConvertHFIRSCDtoMDE.h" +#include "MantidAPI/IMDEventWorkspace.h" +#include "MantidAPI/IMDHistoWorkspace.h" +#include "MantidAPI/IMDIterator.h" +#include "MantidAPI/Run.h" +#include "MantidDataObjects/MDBoxBase.h" +#include "MantidDataObjects/MDEventFactory.h" +#include "MantidDataObjects/MDEventInserter.h" +#include "MantidGeometry/Instrument/DetectorInfo.h" +#include "MantidGeometry/MDGeometry/QSample.h" +#include "MantidKernel/ArrayProperty.h" +#include "MantidKernel/BoundedValidator.h" +#include "MantidKernel/ConfigService.h" +#include "MantidKernel/PropertyWithValue.h" +#include "MantidKernel/TimeSeriesProperty.h" +#include "MantidKernel/UnitLabelTypes.h" + +#include "Eigen/Dense" +#include "boost/math/constants/constants.hpp" + +namespace Mantid { +namespace MDAlgorithms { + +using Mantid::API::WorkspaceProperty; +using Mantid::Kernel::Direction; +using namespace Mantid::Geometry; +using namespace Mantid::Kernel; +using namespace Mantid::MDAlgorithms; +using namespace Mantid::API; +using namespace Mantid::DataObjects; + +// Register the algorithm into the AlgorithmFactory +DECLARE_ALGORITHM(ConvertHFIRSCDtoMDE) + +//---------------------------------------------------------------------------------------------- + +/// Algorithms name for identification. @see Algorithm::name +const std::string ConvertHFIRSCDtoMDE::name() const { + return "ConvertHFIRSCDtoMDE"; +} + +/// Algorithm's version for identification. 
@see Algorithm::version +int ConvertHFIRSCDtoMDE::version() const { return 1; } + +/// Algorithm's category for identification. @see Algorithm::category +const std::string ConvertHFIRSCDtoMDE::category() const { + return "MDAlgorithms\\Creation"; +} + +/// Algorithm's summary for use in the GUI and help. @see Algorithm::summary +const std::string ConvertHFIRSCDtoMDE::summary() const { + return "Convert from the detector vs scan index MDHistoWorkspace into a " + "MDEventWorkspace with units in Q_sample."; +} + +std::map<std::string, std::string> ConvertHFIRSCDtoMDE::validateInputs() { + std::map<std::string, std::string> result; + + API::IMDHistoWorkspace_sptr inputWS = this->getProperty("InputWorkspace"); + std::stringstream inputWSmsg; + if (inputWS->getNumDims() != 3) { + inputWSmsg << "Incorrect number of dimensions"; + } else if (inputWS->getDimension(0)->getName() != "y" || + inputWS->getDimension(1)->getName() != "x" || + inputWS->getDimension(2)->getName() != "scanIndex") { + inputWSmsg << "Wrong dimensions"; + } else if (inputWS->getNumExperimentInfo() == 0) { + inputWSmsg << "Missing experiment info"; + } else if (inputWS->getExperimentInfo(0)->getInstrument()->getName() != + "HB3A" && + inputWS->getExperimentInfo(0)->getInstrument()->getName() != + "WAND") { + inputWSmsg << "This only works for DEMAND (HB3A) or WAND (HB2C)"; + } else { + std::string instrument = + inputWS->getExperimentInfo(0)->getInstrument()->getName(); + const auto run = inputWS->getExperimentInfo(0)->run(); + size_t number_of_runs = inputWS->getDimension(2)->getNBins(); + std::vector<std::string> logs; + if (instrument == "HB3A") + logs = {"omega", "chi", "phi", "monitor", "time"}; + else + logs = {"duration", "monitor_count", "s1"}; + for (auto log : logs) { + if (run.hasProperty(log)) { + if (static_cast<size_t>(run.getLogData(log)->size()) != number_of_runs) + inputWSmsg << "Log " << log << " has incorrect length, "; + } else { + inputWSmsg << "Missing required log " << log << ", 
"; + } + } + } + if (!inputWSmsg.str().empty()) + result["InputWorkspace"] = inputWSmsg.str(); + + std::vector<double> minVals = this->getProperty("MinValues"); + std::vector<double> maxVals = this->getProperty("MaxValues"); + + if (minVals.size() != 3 || maxVals.size() != 3) { + std::stringstream msg; + msg << "Must provide 3 values, 1 for every dimension"; + result["MinValues"] = msg.str(); + result["MaxValues"] = msg.str(); + } else { + std::stringstream msg; + + size_t rank = minVals.size(); + for (size_t i = 0; i < rank; ++i) { + if (minVals[i] >= maxVals[i]) { + if (msg.str().empty()) + msg << "max not bigger than min "; + else + msg << ", "; + msg << "at index=" << (i + 1) << " (" << minVals[i] + << ">=" << maxVals[i] << ")"; + } + } + + if (!msg.str().empty()) { + result["MinValues"] = msg.str(); + result["MaxValues"] = msg.str(); + } + } + + return result; +} + +//---------------------------------------------------------------------------------------------- +/** Initialize the algorithm's properties. + */ +void ConvertHFIRSCDtoMDE::init() { + + declareProperty(std::make_unique<WorkspaceProperty<API::IMDHistoWorkspace>>( + "InputWorkspace", "", Direction::Input), + "An input workspace."); + declareProperty( + std::make_unique<PropertyWithValue<double>>( + "Wavelength", DBL_MAX, + boost::make_shared<BoundedValidator<double>>(0.0, 100.0, true), + Direction::Input), + "Wavelength"); + declareProperty( + std::make_unique<ArrayProperty<double>>("MinValues", "-10,-10,-10"), + "It has to be 3 comma separated values, one for each dimension in " + "q_sample." + "Values smaller then specified here will not be added to " + "workspace."); + declareProperty( + std::make_unique<ArrayProperty<double>>("MaxValues", "10,10,10"), + "A list of the same size and the same units as MinValues " + "list. Values higher or equal to the specified by " + "this list will be ignored"); + // Box controller properties. 
These are the defaults + this->initBoxControllerProps("5" /*SplitInto*/, 1000 /*SplitThreshold*/, + 20 /*MaxRecursionDepth*/); + declareProperty(std::make_unique<WorkspaceProperty<API::IMDEventWorkspace>>( + "OutputWorkspace", "", Direction::Output), + "An output workspace."); +} + +//---------------------------------------------------------------------------------------------- +/** Execute the algorithm. + */ +void ConvertHFIRSCDtoMDE::exec() { + double wavelength = this->getProperty("Wavelength"); + + API::IMDHistoWorkspace_sptr inputWS = this->getProperty("InputWorkspace"); + auto &expInfo = *(inputWS->getExperimentInfo(static_cast<uint16_t>(0))); + std::string instrument = expInfo.getInstrument()->getName(); + + std::vector<double> twotheta, azimuthal; + std::vector<double> s1, omega, chi, phi; + if (instrument == "HB3A") { + auto omegaLog = dynamic_cast<Kernel::TimeSeriesProperty<double> *>( + expInfo.run().getLogData("omega")); + omega = omegaLog->valuesAsVector(); + auto chiLog = dynamic_cast<Kernel::TimeSeriesProperty<double> *>( + expInfo.run().getLogData("chi")); + chi = chiLog->valuesAsVector(); + auto phiLog = dynamic_cast<Kernel::TimeSeriesProperty<double> *>( + expInfo.run().getLogData("phi")); + phi = phiLog->valuesAsVector(); + const auto &di = expInfo.detectorInfo(); + for (size_t x = 0; x < 512; x++) { + for (size_t y = 0; y < 512 * 3; y++) { + size_t n = x + y * 512; + if (!di.isMonitor(n)) { + twotheta.push_back(di.twoTheta(n)); + azimuthal.push_back(di.azimuthal(n)); + } + } + } + } else { // HB2C + s1 = (*(dynamic_cast<Kernel::PropertyWithValue<std::vector<double>> *>( + expInfo.getLog("s1"))))(); + azimuthal = + (*(dynamic_cast<Kernel::PropertyWithValue<std::vector<double>> *>( + expInfo.getLog("azimuthal"))))(); + twotheta = + (*(dynamic_cast<Kernel::PropertyWithValue<std::vector<double>> *>( + expInfo.getLog("twotheta"))))(); + } + + auto outputWS = DataObjects::MDEventFactory::CreateMDWorkspace(3, "MDEvent"); + Mantid::Geometry::QSample 
frame; + std::vector<double> minVals = this->getProperty("MinValues"); + std::vector<double> maxVals = this->getProperty("MaxValues"); + outputWS->addDimension(boost::make_shared<Geometry::MDHistoDimension>( + "Q_sample_x", "Q_sample_x", frame, static_cast<coord_t>(minVals[0]), + static_cast<coord_t>(maxVals[0]), 1)); + + outputWS->addDimension(boost::make_shared<Geometry::MDHistoDimension>( + "Q_sample_y", "Q_sample_y", frame, static_cast<coord_t>(minVals[1]), + static_cast<coord_t>(maxVals[1]), 1)); + + outputWS->addDimension(boost::make_shared<Geometry::MDHistoDimension>( + "Q_sample_z", "Q_sample_z", frame, static_cast<coord_t>(minVals[2]), + static_cast<coord_t>(maxVals[2]), 1)); + outputWS->setCoordinateSystem(Mantid::Kernel::QSample); + outputWS->initialize(); + + BoxController_sptr bc = outputWS->getBoxController(); + this->setBoxController(bc); + outputWS->splitBox(); + + auto mdws_mdevt_3 = + boost::dynamic_pointer_cast<MDEventWorkspace<MDEvent<3>, 3>>(outputWS); + MDEventInserter<MDEventWorkspace<MDEvent<3>, 3>::sptr> inserter(mdws_mdevt_3); + + float k = + boost::math::float_constants::two_pi / static_cast<float>(wavelength); + std::vector<Eigen::Vector3f> q_lab_pre; + q_lab_pre.reserve(azimuthal.size()); + for (size_t m = 0; m < azimuthal.size(); ++m) { + auto twotheta_f = static_cast<float>(twotheta[m]); + auto azimuthal_f = static_cast<float>(azimuthal[m]); + q_lab_pre.push_back({-sin(twotheta_f) * cos(azimuthal_f) * k, + -sin(twotheta_f) * sin(azimuthal_f) * k, + (1.f - cos(twotheta_f)) * k}); + } + + for (size_t n = 0; n < inputWS->getDimension(2)->getNBins(); n++) { + Eigen::Matrix3f goniometer; + if (instrument == "HB3A") { + float omega_radian = + static_cast<float>(omega[n]) * boost::math::float_constants::degree; + float chi_radian = + static_cast<float>(chi[n]) * boost::math::float_constants::degree; + float phi_radian = + static_cast<float>(phi[n]) * boost::math::float_constants::degree; + Eigen::Matrix3f r1; + r1 << cos(omega_radian), 0, 
-sin(omega_radian), 0, 1, 0, + sin(omega_radian), 0, cos(omega_radian); // omega 0,1,0,-1 + Eigen::Matrix3f r2; + r2 << cos(chi_radian), sin(chi_radian), 0, -sin(chi_radian), + cos(chi_radian), 0, 0, 0, 1; // chi 0,0,1,-1 + Eigen::Matrix3f r3; + r3 << cos(phi_radian), 0, -sin(phi_radian), 0, 1, 0, sin(phi_radian), 0, + cos(phi_radian); // phi 0,1,0,-1 + goniometer = r1 * r2 * r3; + } else { // HB2C + float s1_radian = + static_cast<float>(s1[n]) * boost::math::float_constants::degree; + goniometer << cos(s1_radian), 0, sin(s1_radian), 0, 1, 0, -sin(s1_radian), + 0, cos(s1_radian); // s1 0,1,0,1 + } + goniometer = goniometer.inverse().eval(); + for (size_t m = 0; m < azimuthal.size(); m++) { + size_t idx = n * azimuthal.size() + m; + coord_t signal = static_cast<coord_t>(inputWS->getSignalAt(idx)); + if (signal > 0.f) { + Eigen::Vector3f q_sample = goniometer * q_lab_pre[m]; + inserter.insertMDEvent(signal, signal, 0, 0, q_sample.data()); + } + } + } + + auto *ts = new ThreadSchedulerFIFO(); + ThreadPool tp(ts); + outputWS->splitAllIfNeeded(ts); + tp.joinAll(); + + outputWS->refreshCache(); + outputWS->copyExperimentInfos(*inputWS); + + auto user_convention = + Kernel::ConfigService::Instance().getString("Q.convention"); + auto ws_convention = outputWS->getConvention(); + if (user_convention != ws_convention) { + auto convention_alg = createChildAlgorithm("ChangeQConvention"); + convention_alg->setProperty("InputWorkspace", outputWS); + convention_alg->executeAsChildAlg(); + } + setProperty("OutputWorkspace", outputWS); +} + +} // namespace MDAlgorithms +} // namespace Mantid diff --git a/Framework/MDAlgorithms/src/IntegrateFlux.cpp b/Framework/MDAlgorithms/src/IntegrateFlux.cpp index 539559faf12011829874837aeac94a3951c7b9ff..f3c845fef3684513d96b3b7f7651d62362594e86 100644 --- a/Framework/MDAlgorithms/src/IntegrateFlux.cpp +++ b/Framework/MDAlgorithms/src/IntegrateFlux.cpp @@ -51,7 +51,7 @@ const std::string IntegrateFlux::category() const { /// Algorithm's summary 
for use in the GUI and help. @see Algorithm::summary const std::string IntegrateFlux::summary() const { - return "Interates spectra in a matrix workspace at a set of points."; + return "Integrates spectra in a matrix workspace at a set of points."; } //---------------------------------------------------------------------------------------------- diff --git a/Framework/MDAlgorithms/test/ConvertHFIRSCDtoMDETest.h b/Framework/MDAlgorithms/test/ConvertHFIRSCDtoMDETest.h new file mode 100644 index 0000000000000000000000000000000000000000..c19f9f4f973358933d49a06ba06a77a1a431d444 --- /dev/null +++ b/Framework/MDAlgorithms/test/ConvertHFIRSCDtoMDETest.h @@ -0,0 +1,99 @@ +// Mantid Repository : https://github.com/mantidproject/mantid +// +// Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI, +// NScD Oak Ridge National Laboratory, European Spallation Source +// & Institut Laue - Langevin +// SPDX - License - Identifier: GPL - 3.0 + +#ifndef MANTID_MDALGORITHMS_CONVERTHFIRSCDTOMDETEST_H_ +#define MANTID_MDALGORITHMS_CONVERTHFIRSCDTOMDETEST_H_ + +#include <cxxtest/TestSuite.h> + +#include "MantidAPI/AnalysisDataService.h" +#include "MantidAPI/FileFinder.h" +#include "MantidAPI/IMDEventWorkspace.h" +#include "MantidAPI/Workspace_fwd.h" +#include "MantidDataObjects/MDEventWorkspace.h" +#include "MantidMDAlgorithms/ConvertHFIRSCDtoMDE.h" +#include "MantidMDAlgorithms/LoadMD.h" + +using namespace Mantid::API; +using namespace Mantid::MDAlgorithms; + +class ConvertHFIRSCDtoMDETest : public CxxTest::TestSuite { +public: + // This pair of boilerplate methods prevent the suite being created statically + // This means the constructor isn't called when running other tests + static ConvertHFIRSCDtoMDETest *createSuite() { + return new ConvertHFIRSCDtoMDETest(); + } + static void destroySuite(ConvertHFIRSCDtoMDETest *suite) { delete suite; } + + void test_Init() { + ConvertHFIRSCDtoMDE alg; + TS_ASSERT_THROWS_NOTHING(alg.initialize()) + TS_ASSERT(alg.isInitialized()) + } + + 
void test_exec() { + // Create test input if necessary + LoadMD loader; + loader.initialize(); + loader.setPropertyValue( + "Filename", + Mantid::API::FileFinder::Instance().getFullPath("HB3A_data.nxs")); + loader.setPropertyValue("OutputWorkspace", "ConvertHFIRSCDtoMDETest_data"); + loader.setProperty("FileBackEnd", false); + loader.execute(); + auto inputWS = Mantid::API::AnalysisDataService::Instance() + .retrieveWS<Mantid::API::IMDHistoWorkspace>( + "ConvertHFIRSCDtoMDETest_data"); + + ConvertHFIRSCDtoMDE alg; + // Don't put output in ADS by default + alg.setChild(true); + TS_ASSERT_THROWS_NOTHING(alg.initialize()) + TS_ASSERT(alg.isInitialized()) + TS_ASSERT_THROWS_NOTHING( + alg.setProperty("InputWorkspace", "ConvertHFIRSCDtoMDETest_data")); + TS_ASSERT_THROWS_NOTHING(alg.setProperty("Wavelength", "1.008")); + TS_ASSERT_THROWS_NOTHING( + alg.setPropertyValue("OutputWorkspace", "_unused_for_child")); + TS_ASSERT_THROWS_NOTHING(alg.execute();); + TS_ASSERT(alg.isExecuted()); + + // Retrieve the workspace from the algorithm. The type here will probably + // need to change. It should be the type used in declareProperty for the + // "OutputWorkspace" type. We can't use auto as it's an implicit conversion. 
+ IMDEventWorkspace_sptr outWS = alg.getProperty("OutputWorkspace"); + TS_ASSERT(outWS); + + // check dimensions + TS_ASSERT_EQUALS(3, outWS->getNumDims()); + TS_ASSERT_EQUALS(Mantid::Kernel::QSample, + outWS->getSpecialCoordinateSystem()); + TS_ASSERT_EQUALS("QSample", outWS->getDimension(0)->getMDFrame().name()); + TS_ASSERT_EQUALS(true, outWS->getDimension(0)->getMDUnits().isQUnit()); + TS_ASSERT_EQUALS(-10, outWS->getDimension(0)->getMinimum()); + TS_ASSERT_EQUALS(10, outWS->getDimension(0)->getMaximum()); + TS_ASSERT_EQUALS("QSample", outWS->getDimension(1)->getMDFrame().name()); + TS_ASSERT_EQUALS(true, outWS->getDimension(1)->getMDUnits().isQUnit()); + TS_ASSERT_EQUALS(-10, outWS->getDimension(1)->getMinimum()); + TS_ASSERT_EQUALS(10, outWS->getDimension(1)->getMaximum()); + TS_ASSERT_EQUALS("QSample", outWS->getDimension(2)->getMDFrame().name()); + TS_ASSERT_EQUALS(true, outWS->getDimension(2)->getMDUnits().isQUnit()); + TS_ASSERT_EQUALS(-10, outWS->getDimension(2)->getMinimum()); + TS_ASSERT_EQUALS(10, outWS->getDimension(2)->getMaximum()); + + // check other things + TS_ASSERT_EQUALS(1, outWS->getNumExperimentInfo()); + TS_ASSERT_EQUALS(9038, outWS->getNEvents()); + const Mantid::coord_t coords[3] = { + -0.42f, 1.71f, 2.3f}; // roughly the location of maximum intensity + TS_ASSERT_DELTA( + outWS->getSignalAtCoord(coords, Mantid::API::NoNormalization), 568, + 1e-5); + } +}; + +#endif /* MANTID_MDALGORITHMS_CONVERTHFIRSCDTOMDETEST_H_ */ diff --git a/Framework/NexusGeometry/inc/MantidNexusGeometry/TubeHelpers.h b/Framework/NexusGeometry/inc/MantidNexusGeometry/TubeHelpers.h index f91ab9769b9f6e5ea1c82ff8c73748e753bbced4..411b0ac64165215dd4efb5275ba994db0afa10f9 100644 --- a/Framework/NexusGeometry/inc/MantidNexusGeometry/TubeHelpers.h +++ b/Framework/NexusGeometry/inc/MantidNexusGeometry/TubeHelpers.h @@ -29,7 +29,7 @@ findAndSortTubes(const Mantid::Geometry::IObject &detShape, const std::vector<Mantid::detid_t> &detIDs); MANTID_NEXUSGEOMETRY_DLL 
std::vector<Mantid::detid_t> notInTubes(const std::vector<detail::TubeBuilder> &tubes, - const std::vector<Mantid::detid_t> &detIDs); + std::vector<Mantid::detid_t> detIDs); } // namespace TubeHelpers } // namespace NexusGeometry } // namespace Mantid diff --git a/Framework/NexusGeometry/src/TubeHelpers.cpp b/Framework/NexusGeometry/src/TubeHelpers.cpp index eb21e27c9e2dc8b34db45354424ca4537621f95f..2dbec668110cdb62c829a00364f68f7d0f1a6cdc 100644 --- a/Framework/NexusGeometry/src/TubeHelpers.cpp +++ b/Framework/NexusGeometry/src/TubeHelpers.cpp @@ -67,7 +67,7 @@ findAndSortTubes(const Mantid::Geometry::IObject &detShape, */ std::vector<Mantid::detid_t> notInTubes(const std::vector<detail::TubeBuilder> &tubes, - const std::vector<Mantid::detid_t> &detIDs) { + std::vector<Mantid::detid_t> detIDs) { std::vector<Mantid::detid_t> used; for (const auto &tube : tubes) { for (const auto &id : tube.detIDs()) { @@ -75,6 +75,8 @@ notInTubes(const std::vector<detail::TubeBuilder> &tubes, } } std::vector<Mantid::detid_t> diff; + std::sort(detIDs.begin(), detIDs.end()); + std::sort(used.begin(), used.end()); std::set_difference(detIDs.begin(), detIDs.end(), used.begin(), used.end(), std::inserter(diff, diff.begin())); return diff; diff --git a/Framework/NexusGeometry/test/JSONInstrumentBuilderTest.h b/Framework/NexusGeometry/test/JSONInstrumentBuilderTest.h index 0019c8f74d3d9e710f9fc74f5211399510571694..b1371c7b4b52afb424a42ba660c38e7aee6d9ab6 100644 --- a/Framework/NexusGeometry/test/JSONInstrumentBuilderTest.h +++ b/Framework/NexusGeometry/test/JSONInstrumentBuilderTest.h @@ -32,7 +32,7 @@ public: } void test_constructor_fail_invalid_instrument() { - TS_ASSERT_THROWS(JSONInstrumentBuilder(""), std::invalid_argument); + TS_ASSERT_THROWS(JSONInstrumentBuilder(""), const std::invalid_argument &); } void test_build_geometry() { diff --git a/Framework/NexusGeometry/test/NexusGeometrySaveTest.h b/Framework/NexusGeometry/test/NexusGeometrySaveTest.h index 
17d5d26b6c459ac60f957905066b06cc0de932f8..5c8fe184e854faddbb366c64f15b02abe5e3a0f0 100644 --- a/Framework/NexusGeometry/test/NexusGeometrySaveTest.h +++ b/Framework/NexusGeometry/test/NexusGeometrySaveTest.h @@ -109,7 +109,7 @@ used. TS_ASSERT_THROWS(NexusGeometrySave::saveInstrument( instr, destinationFile, DEFAULT_ROOT_ENTRY_NAME, logger, false /*append*/), - std::invalid_argument); + const std::invalid_argument &); TS_ASSERT(testing::Mock::VerifyAndClearExpectations(&logger)); } diff --git a/Framework/PythonInterface/mantid/api/src/Exports/ADSValidator.cpp b/Framework/PythonInterface/mantid/api/src/Exports/ADSValidator.cpp index 21ff2b4bfc6a0b5c6c1538de3e1eeeae7eb707d4..2746f3785167d83ddcb41fbdfb1dd76e7e6b968a 100644 --- a/Framework/PythonInterface/mantid/api/src/Exports/ADSValidator.cpp +++ b/Framework/PythonInterface/mantid/api/src/Exports/ADSValidator.cpp @@ -16,10 +16,11 @@ using namespace boost::python; /// This is the base TypedValidator for most of the WorkspaceValidators void export_ADSValidator() { - TypedValidatorExporter<std::string>::define("StringTypedValidator"); + TypedValidatorExporter<std::vector<std::string>>::define( + "StringTypedValidator"); - class_<ADSValidator, bases<TypedValidator<std::string>>, boost::noncopyable>( - "ADSValidator", init<>("Default constructor")) + class_<ADSValidator, bases<TypedValidator<std::vector<std::string>>>, + boost::noncopyable>("ADSValidator", init<>("Default constructor")) .def(init<const bool, const bool>( "Constructor setting allowMultiple and isOptional.", args("allowMultipleSelection", "isOptional"))) diff --git a/Framework/PythonInterface/mantid/api/src/Exports/IEventList.cpp b/Framework/PythonInterface/mantid/api/src/Exports/IEventList.cpp index 30453b7dbde7cdb4ba692cfddafc98930d5d8c09..825c5ddfaef7b31f55559dae4db6a5292288805c 100644 --- a/Framework/PythonInterface/mantid/api/src/Exports/IEventList.cpp +++ b/Framework/PythonInterface/mantid/api/src/Exports/IEventList.cpp @@ -26,10 +26,15 @@ using 
namespace boost::python; GET_POINTER_SPECIALIZATION(IEventList) namespace { +void addPulsetimes(IEventList &self, const NDArray &data) { + self.addPulsetimes(Converters::NDArrayToVector<double>(data)()); +} + void maskCondition(IEventList &self, const NDArray &data) { self.maskCondition(Converters::NDArrayToVector<bool>(data)()); } } // namespace + /// return_value_policy for copied numpy array using return_clone_numpy = return_value_policy<Policies::VectorToNumpy>; @@ -71,6 +76,9 @@ void export_IEventList() { .def("addPulsetime", &IEventList::addPulsetime, args("self", "seconds"), "Add an offset to the pulsetime (wall-clock time) of each event in " "the list.") + .def("addPulsetimes", &addPulsetimes, args("self", "seconds"), + "Add offsets to the pulsetime (wall-clock time) of each event in " + "the list.") .def("maskTof", &IEventList::maskTof, args("self", "tofMin", "tofMax"), "Mask out events that have a tof between tofMin and tofMax " "(inclusively)") diff --git a/Framework/PythonInterface/mantid/geometry/src/Exports/MDFrame.cpp b/Framework/PythonInterface/mantid/geometry/src/Exports/MDFrame.cpp index ef696e91ae61dd2b51bbca27dfc153c65bf66a6a..964d9e99b846f62e8b41f096b1c314b626db7995 100644 --- a/Framework/PythonInterface/mantid/geometry/src/Exports/MDFrame.cpp +++ b/Framework/PythonInterface/mantid/geometry/src/Exports/MDFrame.cpp @@ -22,5 +22,6 @@ void export_MDFrame() { class_<MDFrame, boost::noncopyable>("MDFrame", no_init) .def("getUnitLabel", &MDFrame::getUnitLabel, arg("self")) - .def("name", &MDFrame::name, arg("self")); + .def("name", &MDFrame::name, arg("self")) + .def("isQ", &MDFrame::isQ, arg("self")); } diff --git a/Framework/PythonInterface/mantid/kernel/src/Exports/Material.cpp b/Framework/PythonInterface/mantid/kernel/src/Exports/Material.cpp index 9f5ab31bd8eb233892fc6122ee8764fb7ce0fe81..23d3ca860a5f34936724ff32564fa9002605fcd0 100644 --- a/Framework/PythonInterface/mantid/kernel/src/Exports/Material.cpp +++ 
b/Framework/PythonInterface/mantid/kernel/src/Exports/Material.cpp @@ -87,24 +87,19 @@ void export_Material() { "Returns True if any of the scattering values are non-zero") #endif .def("cohScatterXSection", - (double (Material::*)(double) const)(&Material::cohScatterXSection), - (arg("self"), - arg("lambda") = static_cast<double>(NeutronAtom::ReferenceLambda)), + (double (Material::*)() const)(&Material::cohScatterXSection), + (arg("self")), "Coherent Scattering Cross-Section for the given wavelength in " "barns") - .def( - "incohScatterXSection", - (double (Material::*)(double) const)(&Material::incohScatterXSection), - (arg("self"), - arg("lambda") = static_cast<double>(NeutronAtom::ReferenceLambda)), - "Incoherent Scattering Cross-Section for the given wavelength in " - "barns") - .def( - "totalScatterXSection", - (double (Material::*)(double) const)(&Material::totalScatterXSection), - (arg("self"), - arg("lambda") = static_cast<double>(NeutronAtom::ReferenceLambda)), - "Total Scattering Cross-Section for the given wavelength in barns") + .def("incohScatterXSection", + (double (Material::*)() const)(&Material::incohScatterXSection), + (arg("self")), + "Incoherent Scattering Cross-Section for the given wavelength in " + "barns") + .def("totalScatterXSection", + (double (Material::*)() const)(&Material::totalScatterXSection), + (arg("self")), + "Total Scattering Cross-Section for the given wavelength in barns") .def("absorbXSection", (double (Material::*)(double) const)(&Material::absorbXSection), (arg("self"), @@ -154,6 +149,7 @@ void export_Material() { "Imaginary part of Incoherent Scattering Length for the given " "wavelength " "in fm") + .def( "cohScatterLengthSqrd", (double (Material::*)(double) const)(&Material::cohScatterLengthSqrd), diff --git a/Framework/PythonInterface/mantid/plots/__init__.py b/Framework/PythonInterface/mantid/plots/__init__.py index f6c6188e9327f92ecd2d6e80886e396f274f59a6..86c9847721583276fe74cf9b441691ae020c0bbc 100644 --- 
a/Framework/PythonInterface/mantid/plots/__init__.py +++ b/Framework/PythonInterface/mantid/plots/__init__.py @@ -150,6 +150,7 @@ class _WorkspaceArtists(object): else: new_artists = self._data_replace_cb(self._artists, workspace) self._set_artists(new_artists) + return len(self._artists) == 0 def _set_artists(self, artists): """Ensure the stored artists is an iterable""" @@ -421,10 +422,7 @@ class MantidAxes(Axes): :param unary_predicate: A predicate taking a single matplotlib artist object :return: True if the artist_info is empty, false if artist_info remain """ - is_empty_list = [] - for workspace_artist in artist_info: - empty = workspace_artist.remove_if(self, unary_predicate) - is_empty_list.append(empty) + is_empty_list = [workspace_artist.remove_if(self, unary_predicate) for workspace_artist in artist_info] for index, empty in reversed(list(enumerate(is_empty_list))): if empty: @@ -444,8 +442,12 @@ class MantidAxes(Axes): except KeyError: return False - for workspace_artist in artist_info: - workspace_artist.replace_data(workspace) + is_empty_list = [workspace_artist.replace_data(workspace) for workspace_artist in artist_info] + + for index, empty in reversed(list(enumerate(is_empty_list))): + if empty: + artist_info.pop(index) + return True def replot_artist(self, artist, errorbars=False, **kwargs): @@ -592,16 +594,37 @@ class MantidAxes(Axes): if helperfunctions.validate_args(*args): logger.debug('using plotfunctions') - autoscale = kwargs.pop("autoscale_on_update", self.get_autoscale_on()) + autoscale_on = kwargs.pop("autoscale_on_update", self.get_autoscale_on()) def _data_update(artists, workspace, new_kwargs=None): # It's only possible to plot 1 line at a time from a workspace + try: + if new_kwargs: + x, y, _, _ = plotfunctions._plot_impl(self, workspace, args, new_kwargs) + else: + x, y, _, _ = plotfunctions._plot_impl(self, workspace, args, kwargs) + artists[0].set_data(x, y) + except RuntimeError as ex: + # if curve couldn't be plotted then 
remove it - can happen if the workspace doesn't contain the + # spectrum any more following execution of an algorithm + logger.information('Curve not plotted: {0}'.format(ex.args[0])) + + # remove the curve using similar to logic that in _WorkspaceArtists._remove + artists[0].remove() + + # blank out list that will be returned + artists = [] + + # also remove the curve from the legend + if (not self.is_empty(self)) and self.legend_ is not None: + self.legend().draggable() + if new_kwargs: - x, y, _, __ = plotfunctions._plot_impl(self, workspace, args, new_kwargs) + _autoscale_on = new_kwargs.pop("autoscale_on_update", self.get_autoscale_on()) else: - x, y, _, __ = plotfunctions._plot_impl(self, workspace, args, kwargs) - artists[0].set_data(x, y) - if new_kwargs and new_kwargs.pop('autoscale_on_update', self.get_autoscale_on()): + _autoscale_on = self.get_autoscale_on() + + if _autoscale_on: self.relim() self.autoscale() return artists @@ -611,14 +634,7 @@ class MantidAxes(Axes): normalize_by_bin_width, kwargs = get_normalize_by_bin_width(workspace, self, **kwargs) is_normalized = normalize_by_bin_width or workspace.isDistribution() - # If we are making the first plot on an axes object - # i.e. self.lines is empty, axes has default ylim values. - # Therefore we need to autoscale regardless of autoscale_on_update. - if self.lines: - # Otherwise set autoscale to autoscale_on_update. 
- self.set_autoscaley_on(autoscale_on_update) - - with autoscale_on_update(self, autoscale): + with autoscale_on_update(self, autoscale_on): artist = self.track_workspace_artist(workspace, plotfunctions.plot(self, *args, **kwargs), _data_update, spec_num, is_normalized, @@ -675,13 +691,13 @@ class MantidAxes(Axes): if helperfunctions.validate_args(*args): logger.debug('using plotfunctions') - autoscale = kwargs.pop("autoscale_on_update", self.get_autoscale_on()) + autoscale_on = kwargs.pop("autoscale_on_update", self.get_autoscale_on()) def _data_update(artists, workspace, new_kwargs=None): if new_kwargs: - _autoscale = new_kwargs.pop("autoscale_on_update", self.get_autoscale_on()) + _autoscale_on = new_kwargs.pop("autoscale_on_update", self.get_autoscale_on()) else: - _autoscale = self.get_autoscale_on() + _autoscale_on = self.get_autoscale_on() # errorbar with workspaces can only return a single container container_orig = artists[0] # It is not possible to simply reset the error bars so @@ -694,28 +710,38 @@ class MantidAxes(Axes): self.containers.remove(container_orig) except ValueError: pass - with autoscale_on_update(self, _autoscale): - # this gets pushed back onto the containers list - if new_kwargs: - container_new = plotfunctions.errorbar(self, workspace, **new_kwargs) - else: - container_new = plotfunctions.errorbar(self, workspace, **kwargs) - self.containers.insert(orig_idx, container_new) - self.containers.pop() - - # Update joining line - if container_new[0] and container_orig[0]: - container_new[0].update_from(container_orig[0]) - # Update caps - for orig_caps, new_caps in zip(container_orig[1], container_new[1]): - new_caps.update_from(orig_caps) - # Update bars - for orig_bars, new_bars in zip(container_orig[2], container_new[2]): - new_bars.update_from(orig_bars) - - # Re-plotting in the config dialog will assign this attr - if hasattr(container_orig, 'errorevery'): - setattr(container_new, 'errorevery', container_orig.errorevery) + # this gets 
pushed back onto the containers list + try: + with autoscale_on_update(self, _autoscale_on): + # this gets pushed back onto the containers list + if new_kwargs: + container_new = plotfunctions.errorbar(self, workspace, **new_kwargs) + else: + container_new = plotfunctions.errorbar(self, workspace, **kwargs) + + self.containers.insert(orig_idx, container_new) + self.containers.pop() + # Update joining line + if container_new[0] and container_orig[0]: + container_new[0].update_from(container_orig[0]) + # Update caps + for orig_caps, new_caps in zip(container_orig[1], container_new[1]): + new_caps.update_from(orig_caps) + # Update bars + for orig_bars, new_bars in zip(container_orig[2], container_new[2]): + new_bars.update_from(orig_bars) + # Re-plotting in the config dialog will assign this attr + if hasattr(container_orig, 'errorevery'): + setattr(container_new, 'errorevery', container_orig.errorevery) + + # ax.relim does not support collections... + self._update_line_limits(container_new[0]) + except RuntimeError as ex: + logger.information('Error bar not plotted: {0}'.format(ex.args[0])) + container_new = [] + # also remove the curve from the legend + if (not self.is_empty(self)) and self.legend_ is not None: + self.legend().draggable() return container_new @@ -723,10 +749,7 @@ class MantidAxes(Axes): spec_num = self.get_spec_number_or_bin(workspace, kwargs) is_normalized, kwargs = get_normalize_by_bin_width(workspace, self, **kwargs) - if self.lines: - self.set_autoscaley_on(autoscale_on_update) - - with autoscale_on_update(self, autoscale): + with autoscale_on_update(self, autoscale_on): artist = self.track_workspace_artist(workspace, plotfunctions.errorbar(self, *args, **kwargs), _data_update, spec_num, is_normalized, @@ -843,7 +866,8 @@ class MantidAxes(Axes): workspace = args[0] normalize_by_bin_width, _ = get_normalize_by_bin_width(workspace, self, **kwargs) - is_normalized = normalize_by_bin_width or workspace.isDistribution() + is_normalized = 
normalize_by_bin_width or \ + (hasattr(workspace, 'isDistribution') and workspace.isDistribution()) # We return the last mesh so the return type is a single artist like the standard Axes artists = self.track_workspace_artist(workspace, plotfunctions_func(self, *args, **kwargs), diff --git a/Framework/PythonInterface/mantid/plots/helperfunctions.py b/Framework/PythonInterface/mantid/plots/helperfunctions.py index bdd05ac1c4d49bc97d22894e64f18f086c707647..2cfcdb642634c0f68ea60f8059a24b86e3802aea 100644 --- a/Framework/PythonInterface/mantid/plots/helperfunctions.py +++ b/Framework/PythonInterface/mantid/plots/helperfunctions.py @@ -17,7 +17,7 @@ from scipy.interpolate import interp1d import mantid.api import mantid.kernel -from mantid.api import MultipleExperimentInfos +from mantid.api import MultipleExperimentInfos, MatrixWorkspace from mantid.dataobjects import EventWorkspace, MDHistoWorkspace, Workspace2D from mantid.plots.utility import MantidAxType @@ -228,9 +228,15 @@ def _get_wksp_index_and_spec_num(workspace, axis, **kwargs): # convert the spectrum number to a workspace index and vice versa if spectrum_number is not None: - workspace_index = workspace.getIndexFromSpectrumNumber(int(spectrum_number)) + try: + workspace_index = workspace.getIndexFromSpectrumNumber(int(spectrum_number)) + except RuntimeError: + raise RuntimeError('Spectrum Number {0} not found in workspace {1}'.format(spectrum_number,workspace.name())) elif axis == MantidAxType.SPECTRUM: # Only get a spectrum number if we're traversing the spectra - spectrum_number = workspace.getSpectrum(workspace_index).getSpectrumNo() + try: + spectrum_number = workspace.getSpectrum(workspace_index).getSpectrumNo() + except RuntimeError: + raise RuntimeError('Workspace index {0} not found in workspace {1}'.format(workspace_index,workspace.name())) return workspace_index, spectrum_number, kwargs @@ -566,7 +572,7 @@ def get_data_uneven_flag(workspace, **kwargs): def check_resample_to_regular_grid(ws, **kwargs): 
- if not isinstance(ws, MDHistoWorkspace): + if isinstance(ws, MatrixWorkspace): aligned = kwargs.pop('axisaligned', False) if not ws.isCommonBins() or aligned: return True, kwargs diff --git a/Framework/PythonInterface/mantid/plots/utility.py b/Framework/PythonInterface/mantid/plots/utility.py index 6b42a239e0c162a62a8e8c1ab7f5c83da73ea161..4e02c746b12961434f6a1d815bab46bf307c7321 100644 --- a/Framework/PythonInterface/mantid/plots/utility.py +++ b/Framework/PythonInterface/mantid/plots/utility.py @@ -54,23 +54,28 @@ def autoscale_on_update(ax, state, axis='both'): """ original_state = ax.get_autoscale_on() try: - if axis == 'both': - original_state = ax.get_autoscale_on() - ax.set_autoscale_on(state) - elif axis == 'x': - original_state = ax.get_autoscalex_on() - ax.set_autoscalex_on(state) - elif axis == 'y': - original_state = ax.get_autoscaley_on() - ax.set_autoscaley_on(state) + # If we are making the first plot on an axes object + # i.e. ax.lines is empty, axes has default ylim values. + # Therefore we need to autoscale regardless of state parameter. 
+ if ax.lines: + if axis == 'both': + original_state = ax.get_autoscale_on() + ax.set_autoscale_on(state) + elif axis == 'x': + original_state = ax.get_autoscalex_on() + ax.set_autoscalex_on(state) + elif axis == 'y': + original_state = ax.get_autoscaley_on() + ax.set_autoscaley_on(state) yield finally: - if axis == 'both': - ax.set_autoscale_on(original_state) - elif axis == 'x': - ax.set_autoscalex_on(original_state) - elif axis == 'y': - ax.set_autoscaley_on(original_state) + if ax.lines: + if axis == 'both': + ax.set_autoscale_on(original_state) + elif axis == 'x': + ax.set_autoscalex_on(original_state) + elif axis == 'y': + ax.set_autoscaley_on(original_state) def find_errorbar_container(line, containers): diff --git a/Framework/PythonInterface/plugins/algorithms/ConvertWANDSCDtoQ.py b/Framework/PythonInterface/plugins/algorithms/ConvertWANDSCDtoQ.py index dfd0d3bf02742b9a3fdf1bc4d5b648d12c03d02b..ed5e07670ead025a9bf8ef5cc5e5aa09f4a5ff47 100644 --- a/Framework/PythonInterface/plugins/algorithms/ConvertWANDSCDtoQ.py +++ b/Framework/PythonInterface/plugins/algorithms/ConvertWANDSCDtoQ.py @@ -19,7 +19,7 @@ class ConvertWANDSCDtoQ(PythonAlgorithm): return 'DataHandling\\Nexus' def seeAlso(self): - return [ "LoadWANDSCD" ] + return [ "LoadWANDSCD", "ConvertHFIRSCDtoMDE" ] def name(self): return 'ConvertWANDSCDtoQ' diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/ApplyPaalmanPingsCorrection.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/ApplyPaalmanPingsCorrection.py index b474259ad07d0a6a8d0d74abd596cc963eb81a25..753e69b81ec58d40c55f4db9e0ce2a6d73441d3e 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/ApplyPaalmanPingsCorrection.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/ApplyPaalmanPingsCorrection.py @@ -7,7 +7,7 @@ # pylint: disable=no-init,too-many-instance-attributes from __future__ import (absolute_import, division, print_function) -import 
mantid.simpleapi as s_api +from mantid.simpleapi import * from mantid.api import PythonAlgorithm, AlgorithmFactory, MatrixWorkspaceProperty, WorkspaceGroupProperty, \ PropertyMode, MatrixWorkspace, Progress, WorkspaceGroup from mantid.kernel import Direction, logger @@ -63,6 +63,7 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): # pylint: disable=too-many-branches def PyExec(self): + self._setup() if not self._use_corrections: logger.information('Not using corrections') if not self._use_can: @@ -95,10 +96,10 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): correction_type = 'sample_corrections_only' # Add corrections filename to log values prog_corr.report('Correcting sample') - s_api.AddSampleLog(Workspace=output_workspace, - LogName='corrections_filename', - LogType='String', - LogText=self._corrections_ws_name) + AddSampleLog(Workspace=output_workspace, + LogName='corrections_filename', + LogType='String', + LogText=self._corrections_ws_name) else: # Do simple subtraction output_workspace = self._subtract(sample_ws_wavelength, container_ws_wavelength) @@ -107,34 +108,34 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): can_base = self.getPropertyValue("CanWorkspace") can_base = can_base[:can_base.index('_')] prog_corr.report('Adding container filename') - s_api.AddSampleLog(Workspace=output_workspace, - LogName='container_filename', - LogType='String', - LogText=can_base) + AddSampleLog(Workspace=output_workspace, + LogName='container_filename', + LogType='String', + LogText=can_base) prog_wrkflow = Progress(self, 0.6, 1.0, nreports=5) # Record the container scale factor if self._use_can and self._scale_can: prog_wrkflow.report('Adding container scaling') - s_api.AddSampleLog(Workspace=output_workspace, - LogName='container_scale', - LogType='Number', - LogText=str(self._can_scale_factor)) + AddSampleLog(Workspace=output_workspace, + LogName='container_scale', + LogType='Number', + LogText=str(self._can_scale_factor)) # Record the container 
shift amount if self._use_can and self._shift_can: prog_wrkflow.report('Adding container shift') - s_api.AddSampleLog(Workspace=output_workspace, - LogName='container_shift', - LogType='Number', - LogText=str(self._can_shift_factor)) + AddSampleLog(Workspace=output_workspace, + LogName='container_shift', + LogType='Number', + LogText=str(self._can_shift_factor)) # Record the type of corrections applied prog_wrkflow.report('Adding correction type') - s_api.AddSampleLog(Workspace=output_workspace, - LogName='corrections_type', - LogType='String', - LogText=correction_type) + AddSampleLog(Workspace=output_workspace, + LogName='corrections_type', + LogType='String', + LogText=correction_type) # Add original sample as log entry sam_base = self.getPropertyValue("SampleWorkspace") @@ -142,18 +143,21 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): if '_' in sam_base: sam_base = sam_base[:sam_base.index('_')] prog_wrkflow.report('Adding sample filename') - s_api.AddSampleLog(Workspace=output_workspace, - LogName='sample_filename', - LogType='String', - LogText=sam_base) + AddSampleLog(Workspace=output_workspace, + LogName='sample_filename', + LogType='String', + LogText=sam_base) # Convert Units back to original emode = str(output_workspace.getEMode()) efixed = 0.0 if emode == "Indirect": efixed = self._get_e_fixed(output_workspace) - output_workspace = self._convert_units(output_workspace, sample_unit, emode, efixed) + if sample_unit != 'Label': + output_workspace = self._convert_units(output_workspace, sample_unit, emode, efixed) + if output_workspace.name(): + RenameWorkspace(InputWorkspace=output_workspace, OutputWorkspace=self.getPropertyValue('OutputWorkspace')) self.setProperty('OutputWorkspace', output_workspace) prog_wrkflow.report('Algorithm Complete') @@ -188,15 +192,13 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): if corrections_issues: issues['CorrectionsWorkspace'] = "\n".join(corrections_issues) - sample_ws = 
self.getProperty("SampleWorkspace").value - if isinstance(sample_ws, MatrixWorkspace): - sample_unit_id = sample_ws.getAxis(0).getUnit().unitID() + if isinstance(self._sample_workspace, MatrixWorkspace): + sample_unit_id = self._sample_workspace.getAxis(0).getUnit().unitID() # Check sample and container X axis units match if self._use_can: - can_ws = self.getProperty("CanWorkspace").value - if isinstance(can_ws, MatrixWorkspace): - can_unit_id = can_ws.getAxis(0).getUnit().unitID() + if isinstance(self._container_workspace, MatrixWorkspace): + can_unit_id = self._container_workspace.getAxis(0).getUnit().unitID() if can_unit_id != sample_unit_id: issues['CanWorkspace'] = 'X axis unit must match SampleWorkspace' else: @@ -243,10 +245,10 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): self._corrections_approximation = self._three_factor_corrections_approximation def _shift_workspace(self, workspace, shift_factor): - return s_api.ScaleX(InputWorkspace=workspace, - Factor=shift_factor, - OutputWorkspace="__shifted", - Operation="Add", StoreInADS=False) + return ScaleX(InputWorkspace=workspace, + Factor=shift_factor, + OutputWorkspace="__shifted", + Operation="Add", StoreInADS=False) def _convert_units_wavelength(self, workspace): unit = workspace.getAxis(0).getUnit().unitID() @@ -261,14 +263,17 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): emode = 'Indirect' efixed = self._get_e_fixed(workspace) return self._convert_units(workspace, "Wavelength", emode, efixed) + else: + # for fixed window scans the unit might be empty (e.g. 
temperature) + return workspace else: return workspace def _convert_units(self, workspace, target, emode, efixed): - return s_api.ConvertUnits(InputWorkspace=workspace, - OutputWorkspace="__units_converted", - Target=target, EMode=emode, - EFixed=efixed, StoreInADS=False) + return ConvertUnits(InputWorkspace=workspace, + OutputWorkspace="__units_converted", + Target=target, EMode=emode, + EFixed=efixed, StoreInADS=False) def _get_e_fixed(self, workspace): from IndirectCommon import getEfixed @@ -328,10 +333,10 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): if self._rebin_container_ws: logger.information('Rebining container to ensure Minus') - subtrahend_workspace = s_api.RebinToWorkspace(WorkspaceToRebin=subtrahend_workspace, - WorkspaceToMatch=minuend_workspace, - OutputWorkspace="__rebinned", - StoreInADS=False) + subtrahend_workspace = RebinToWorkspace(WorkspaceToRebin=subtrahend_workspace, + WorkspaceToMatch=minuend_workspace, + OutputWorkspace="__rebinned", + StoreInADS=False) return minuend_workspace - subtrahend_workspace def _clone(self, workspace): @@ -340,17 +345,18 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): :param workspace: The workspace to clone. :return: A clone of the specified workspace. """ - return s_api.CloneWorkspace(InputWorkspace=workspace, - OutputWorkspace="cloned", - StoreInADS=False) + return CloneWorkspace(InputWorkspace=workspace, + OutputWorkspace="cloned", + StoreInADS=False) def _correct_sample(self, sample_workspace, a_ss_workspace): """ Correct for sample only (when no container is given). 
""" - logger.information('Correcting sample') - return sample_workspace / self._convert_units_wavelength(a_ss_workspace) + correction_in_lambda = self._convert_units_wavelength(a_ss_workspace) + corrected = Divide(LHSWorkspace=sample_workspace, RHSWorkspace=correction_in_lambda) + return corrected def _correct_sample_can(self, sample_workspace, container_workspace, factor_workspaces): """ @@ -363,10 +369,10 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): in factor_workspaces.items()} if self._rebin_container_ws: - container_workspace = s_api.RebinToWorkspace(WorkspaceToRebin=container_workspace, - WorkspaceToMatch=factor_workspaces_wavelength['acc'], - OutputWorkspace="rebinned", - StoreInADS=False) + container_workspace = RebinToWorkspace(WorkspaceToRebin=container_workspace, + WorkspaceToMatch=factor_workspaces_wavelength['acc'], + OutputWorkspace="rebinned", + StoreInADS=False) return self._corrections_approximation(sample_workspace, container_workspace, factor_workspaces_wavelength) def _three_factor_corrections_approximation(self, sample_workspace, container_workspace, factor_workspaces): @@ -380,6 +386,5 @@ class ApplyPaalmanPingsCorrection(PythonAlgorithm): ass = factor_workspaces['ass'] return (sample_workspace / ass) - (container_workspace / acc) - # Register algorithm with Mantid AlgorithmFactory.subscribe(ApplyPaalmanPingsCorrection) diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/MatchAndMergeWorkspaces.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/MatchAndMergeWorkspaces.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb57638599a96984abef2a44dbb70c2cbbf1e0a --- /dev/null +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/MatchAndMergeWorkspaces.py @@ -0,0 +1,152 @@ +# Mantid Repository : https://github.com/mantidproject/mantid +# +# Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +# NScD Oak Ridge National Laboratory, European Spallation 
Source +# & Institut Laue - Langevin +# SPDX - License - Identifier: GPL - 3.0 + +from __future__ import (absolute_import, division, print_function) + +from mantid.simpleapi import (AnalysisDataService, CloneWorkspace, ConjoinWorkspaces, CropWorkspaceRagged, + DeleteWorkspace, MatchSpectra, Rebin, SumSpectra) +from mantid.api import (AlgorithmFactory, DataProcessorAlgorithm, WorkspaceProperty, WorkspaceGroup, ADSValidator) +from mantid.dataobjects import Workspace2D +from mantid.kernel import (Direction, FloatArrayProperty, StringArrayProperty) +import numpy as np + + +class MatchAndMergeWorkspaces(DataProcessorAlgorithm): + + def name(self): + return 'MatchAndMergeWorkspaces' + + def category(self): + return 'Workflow\\Diffraction' + + def seeAlso(self): + return [] + + def summary(self): + return 'Merges a group workspace using weighting from a set of range limits for each workspace.' + + def checkGroups(self): + return False + + def version(self): + return 1 + + def validateInputs(self): + # given workspaces must exist + # and must be public of ExperimentInfo + issues = dict() + ws_list = self.getProperty('InputWorkspaces').value + spectra_count = 0 + for name_in_list in ws_list: + ws_in_list = AnalysisDataService.retrieve(name_in_list) + if isinstance(ws_in_list, Workspace2D): + spectra_count += ws_in_list.getNumberHistograms() + if isinstance(ws_in_list, WorkspaceGroup): + for ws_in_group in ws_in_list: + spectra_count += ws_in_group.getNumberHistograms() + + x_min = self.getProperty('XMin').value + if not x_min.size == spectra_count: + issues['XMin'] = 'XMin entries does not match size of workspace group' + + x_max = self.getProperty('XMax').value + if not x_max.size == spectra_count: + issues['XMax'] = 'XMax entries does not match size of workspace group' + + return issues + + def PyInit(self): + self.declareProperty(StringArrayProperty('InputWorkspaces', direction=Direction.Input, validator=ADSValidator()), + doc='List of workspaces or group workspace 
containing workspaces to be merged.') + self.declareProperty(WorkspaceProperty('OutputWorkspace', '', direction=Direction.Output), + doc='The merged workspace.') + self.declareProperty(FloatArrayProperty('XMin', [], direction=Direction.Input), + doc='Array of minimum X values for each workspace.') + self.declareProperty(FloatArrayProperty('XMax', [], direction=Direction.Input), + doc='Array of maximum X values for each workspace.') + self.declareProperty('CalculateScale', True, + doc='Calculate scale factor when matching spectra.') + self.declareProperty('CalculateOffset', True, + doc='Calculate vertical shift when matching spectra.') + + def PyExec(self): + ws_list = self.getProperty('InputWorkspaces').value + x_min = self.getProperty('XMin').value + x_max = self.getProperty('XMax').value + scale_bool = self.getProperty('CalculateScale').value + offset_bool = self.getProperty('CalculateOffset').value + flattened_list = self.unwrap_groups(ws_list) + largest_range_spectrum, rebin_param = self.get_common_bin_range_and_largest_spectra(flattened_list) + CloneWorkspace(InputWorkspace=flattened_list[0], OutputWorkspace='ws_conjoined') + Rebin(InputWorkspace='ws_conjoined', OutputWorkspace='ws_conjoined', Params=rebin_param) + for ws in flattened_list[1:]: + temp = CloneWorkspace(InputWorkspace=ws) + temp = Rebin(InputWorkspace=temp, Params=rebin_param) + ConjoinWorkspaces(InputWorkspace1='ws_conjoined', + InputWorkspace2=temp, + CheckOverlapping=False) + ws_conjoined = AnalysisDataService.retrieve('ws_conjoined') + ref_spec = ws_conjoined.getSpectrum(largest_range_spectrum).getSpectrumNo() + ws_conjoined, offset, scale, chisq = MatchSpectra(InputWorkspace=ws_conjoined, + ReferenceSpectrum=ref_spec, + CalculateScale=scale_bool, + CalculateOffset=offset_bool) + x_min, x_max, bin_width = self.fit_x_lims_to_match_histogram_bins(ws_conjoined, x_min, x_max) + + ws_conjoined = CropWorkspaceRagged(InputWorkspace=ws_conjoined, XMin=x_min, XMax=x_max) + ws_conjoined = 
Rebin(InputWorkspace=ws_conjoined, Params=[min(x_min), bin_width, max(x_max)]) + merged_ws = SumSpectra(InputWorkspace=ws_conjoined, WeightedSum=True, MultiplyBySpectra=False, StoreInADS=False) + DeleteWorkspace(ws_conjoined) + self.setProperty('OutputWorkspace', merged_ws) + + @staticmethod + def unwrap_groups(inputs): + output = [] + for name_in_list in inputs: + ws_in_list = AnalysisDataService.retrieve(name_in_list) + if isinstance(ws_in_list, Workspace2D): + output.append(ws_in_list) + if isinstance(ws_in_list, WorkspaceGroup): + for ws_in_group in ws_in_list: + output.append(ws_in_group) + return output + + @staticmethod + def get_common_bin_range_and_largest_spectra(ws_list): + x_min = np.inf + x_max = -np.inf + x_num = -np.inf + ws_max_range = 0 + largest_range_spectrum = 0 + for i in range(len(ws_list)): + for j in range(ws_list[i].getNumberHistograms()): + x_data = ws_list[i].dataX(j) + x_min = min(np.min(x_data), x_min) + x_max = max(np.max(x_data), x_max) + x_num = max(x_data.size, x_num) + ws_range = np.max(x_data) - np.min(x_data) + if ws_range > ws_max_range: + largest_range_spectrum = i + ws_max_range = ws_range + if x_min.any() == -np.inf or x_min.any() == np.inf: + raise AttributeError("Workspace x range contains an infinite value.") + return largest_range_spectrum, [x_min, (x_max - x_min) / x_num, x_max] + + @staticmethod + def fit_x_lims_to_match_histogram_bins(ws_conjoined, x_min, x_max): + bin_width = np.inf + for i in range(x_min.size): + pdf_x_array = ws_conjoined.readX(i) + x_min[i] = pdf_x_array[np.amin(np.where(pdf_x_array >= x_min[i]))] + x_max[i] = pdf_x_array[np.amax(np.where(pdf_x_array <= x_max[i]))] + bin_width = min(pdf_x_array[1] - pdf_x_array[0], bin_width) + if x_min.any() == -np.inf or x_min.any() == np.inf: + raise AttributeError("Limits contains an infinite value.") + return x_min, x_max, bin_width + + +# Register algorithm with Mantid +AlgorithmFactory.subscribe(MatchAndMergeWorkspaces) diff --git 
a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinder.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinder.py index 86bf54065a0907606827d6db398ee88c1bcdd630..892bf9a7181e97990e2c0be6ca14f6102db6168f 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinder.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinder.py @@ -76,9 +76,9 @@ class SANSBeamCentreFinder(DataProcessorAlgorithm): doc='The can direct data') # The component, i.e. HAB or LAB - allowed_detectors = StringListValidator([DetectorType.to_string(DetectorType.LAB), - DetectorType.to_string(DetectorType.HAB)]) - self.declareProperty("Component", DetectorType.to_string(DetectorType.LAB), + allowed_detectors = StringListValidator([DetectorType.LAB.value, + DetectorType.HAB.value]) + self.declareProperty("Component", DetectorType.LAB.value, validator=allowed_detectors, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") @@ -94,7 +94,7 @@ class SANSBeamCentreFinder(DataProcessorAlgorithm): self.declareProperty('Tolerance', 0.0001251, direction=Direction.Input, doc="The search tolerance") - self.declareProperty('Direction', FindDirectionEnum.to_string(FindDirectionEnum.All), direction=Direction.Input, + self.declareProperty('Direction', FindDirectionEnum.ALL.value, direction=Direction.Input, doc="The search direction is an enumerable which can be either All, LeftRight or UpDown") self.declareProperty('Verbose', False, direction=Direction.Input, @@ -153,9 +153,9 @@ class SANSBeamCentreFinder(DataProcessorAlgorithm): position_2_step = position_1_step find_direction = self.getProperty("Direction").value - if find_direction == FindDirectionEnum.to_string(FindDirectionEnum.Left_Right): + if find_direction == FindDirectionEnum.LEFT_RIGHT.value: position_2_step = 0.0 - elif find_direction == 
FindDirectionEnum.to_string(FindDirectionEnum.Up_Down): + elif find_direction == FindDirectionEnum.UP_DOWN.value: position_1_step = 0.0 centre1 = x_start centre2 = y_start @@ -188,10 +188,10 @@ class SANSBeamCentreFinder(DataProcessorAlgorithm): if verbose: self._rename_and_group_workspaces(j, output_workspaces) - residueLR.append(self._calculate_residuals(sample_quartiles[MaskingQuadrant.Left], - sample_quartiles[MaskingQuadrant.Right])) - residueTB.append(self._calculate_residuals(sample_quartiles[MaskingQuadrant.Top], - sample_quartiles[MaskingQuadrant.Bottom])) + residueLR.append(self._calculate_residuals(sample_quartiles[MaskingQuadrant.LEFT], + sample_quartiles[MaskingQuadrant.RIGHT])) + residueTB.append(self._calculate_residuals(sample_quartiles[MaskingQuadrant.TOP], + sample_quartiles[MaskingQuadrant.BOTTOM])) if j == 0: self.logger.notice("Itr {0}: ( {1:.3f}, {2:.3f} ) SX={3:.5f} SY={4:.5f}". format(j, self.scale_1 * centre1, @@ -307,12 +307,12 @@ class SANSBeamCentreFinder(DataProcessorAlgorithm): out_right = strip_end_nans(alg.getProperty("OutputWorkspaceRight").value, self) out_top = strip_end_nans(alg.getProperty("OutputWorkspaceTop").value, self) out_bottom = strip_end_nans(alg.getProperty("OutputWorkspaceBottom").value, self) - return {MaskingQuadrant.Left: out_left, MaskingQuadrant.Right: out_right, MaskingQuadrant.Top: out_top, - MaskingQuadrant.Bottom: out_bottom} + return {MaskingQuadrant.LEFT: out_left, MaskingQuadrant.RIGHT: out_right, MaskingQuadrant.TOP: out_top, + MaskingQuadrant.BOTTOM: out_bottom} def _get_component(self, workspace): component_as_string = self.getProperty("Component").value - component = DetectorType.from_string(component_as_string) + component = DetectorType(component_as_string) return get_component_name(workspace, component) def _get_state(self): diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinderCore.py 
b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinderCore.py index 65df93b6fce248771d9689ede5217da3ce1e1de5..bf1050837d382b416549377c3d4edcd5b6b5ec2e 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinderCore.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinderCore.py @@ -9,22 +9,22 @@ """ Finds the beam centre.""" from __future__ import (absolute_import, division, print_function) + from mantid.api import (DataProcessorAlgorithm, MatrixWorkspaceProperty, AlgorithmFactory, PropertyMode, Progress, IEventWorkspace) from mantid.kernel import (Direction, PropertyManagerProperty, StringListValidator) - from sans.algorithm_detail.CreateSANSAdjustmentWorkspaces import CreateSANSAdjustmentWorkspaces from sans.algorithm_detail.convert_to_q import convert_workspace -from sans.algorithm_detail.scale_sans_workspace import scale_workspace from sans.algorithm_detail.crop_helper import get_component_name from sans.algorithm_detail.mask_sans_workspace import mask_workspace from sans.algorithm_detail.move_sans_instrument_component import move_component, MoveTypes +from sans.algorithm_detail.scale_sans_workspace import scale_workspace from sans.algorithm_detail.slice_sans_event import slice_sans_event +from sans.algorithm_detail.xml_shapes import quadrant_xml from sans.common.constants import EMPTY_NAME +from sans.common.enums import (DetectorType, DataType, MaskingQuadrant) from sans.common.general_functions import create_child_algorithm, append_to_sans_file_tag from sans.state.state_base import create_deserialized_sans_state_from_property_manager -from sans.common.enums import (DetectorType, DataType, MaskingQuadrant, RangeStepType, RebinType) -from sans.algorithm_detail.xml_shapes import quadrant_xml class SANSBeamCentreFinderCore(DataProcessorAlgorithm): @@ -65,16 +65,16 @@ class SANSBeamCentreFinderCore(DataProcessorAlgorithm): 
self.setPropertyGroup("TransmissionWorkspace", 'Data') self.setPropertyGroup("DirectWorkspace", 'Data') - allowed_detectors = StringListValidator([DetectorType.to_string(DetectorType.LAB), - DetectorType.to_string(DetectorType.HAB)]) - self.declareProperty("Component", DetectorType.to_string(DetectorType.LAB), + allowed_detectors = StringListValidator([DetectorType.LAB.value, + DetectorType.HAB.value]) + self.declareProperty("Component", DetectorType.LAB.value, validator=allowed_detectors, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") # The data type - allowed_data = StringListValidator([DataType.to_string(DataType.Sample), - DataType.to_string(DataType.Can)]) - self.declareProperty("DataType", DataType.to_string(DataType.Sample), + allowed_data = StringListValidator([DataType.SAMPLE.value, + DataType.CAN.value]) + self.declareProperty("DataType", DataType.SAMPLE.value, validator=allowed_data, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") @@ -123,9 +123,9 @@ class SANSBeamCentreFinderCore(DataProcessorAlgorithm): # state.compatibility.use_compatibility_mode = self.getProperty('CompatibilityMode').value # Set test centre - state.move.detectors[DetectorType.to_string(DetectorType.LAB)].sample_centre_pos1 = \ + state.move.detectors[DetectorType.LAB.value].sample_centre_pos1 = \ self.getProperty("Centre1").value - state.move.detectors[DetectorType.to_string(DetectorType.LAB)].sample_centre_pos2 = \ + state.move.detectors[DetectorType.LAB.value].sample_centre_pos2 = \ self.getProperty("Centre2").value component_as_string = self.getProperty("Component").value @@ -241,7 +241,7 @@ class SANSBeamCentreFinderCore(DataProcessorAlgorithm): # ------------------------------------------------------------ # 10. 
Split workspace into 4 quadrant workspaces # ------------------------------------------------------------ - quadrants = [MaskingQuadrant.Left, MaskingQuadrant.Right, MaskingQuadrant.Top, MaskingQuadrant.Bottom] + quadrants = [MaskingQuadrant.LEFT, MaskingQuadrant.RIGHT, MaskingQuadrant.TOP, MaskingQuadrant.BOTTOM] quadrant_scatter_reduced = {} centre = [0, 0, 0] r_min = self.getProperty("RMin").value @@ -263,10 +263,10 @@ class SANSBeamCentreFinderCore(DataProcessorAlgorithm): # # ------------------------------------------------------------ # # Populate the output # # ------------------------------------------------------------ - self.setProperty("OutputWorkspaceLeft", quadrant_scatter_reduced[MaskingQuadrant.Left]) - self.setProperty("OutputWorkspaceRight", quadrant_scatter_reduced[MaskingQuadrant.Right]) - self.setProperty("OutputWorkspaceTop", quadrant_scatter_reduced[MaskingQuadrant.Top]) - self.setProperty("OutputWorkspaceBottom", quadrant_scatter_reduced[MaskingQuadrant.Bottom]) + self.setProperty("OutputWorkspaceLeft", quadrant_scatter_reduced[MaskingQuadrant.LEFT]) + self.setProperty("OutputWorkspaceRight", quadrant_scatter_reduced[MaskingQuadrant.RIGHT]) + self.setProperty("OutputWorkspaceTop", quadrant_scatter_reduced[MaskingQuadrant.TOP]) + self.setProperty("OutputWorkspaceBottom", quadrant_scatter_reduced[MaskingQuadrant.BOTTOM]) def _mask_quadrants(self, workspace, shape): mask_name = "MaskDetectorsInShape" @@ -279,7 +279,7 @@ class SANSBeamCentreFinderCore(DataProcessorAlgorithm): scatter_workspace = self.getProperty("ScatterWorkspace").value alg_name = "CropToComponent" - component_to_crop = DetectorType.from_string(component) + component_to_crop = DetectorType(component) component_to_crop = get_component_name(scatter_workspace, component_to_crop) crop_options = {"InputWorkspace": scatter_workspace, @@ -323,9 +323,8 @@ class SANSBeamCentreFinderCore(DataProcessorAlgorithm): "WavelengthLow": wavelength_state.wavelength_low[0], "WavelengthHigh": 
wavelength_state.wavelength_high[0], "WavelengthStep": wavelength_state.wavelength_step, - "WavelengthStepType": RangeStepType.to_string( - wavelength_state.wavelength_step_type), - "RebinMode": RebinType.to_string(wavelength_state.rebin_type)} + "WavelengthStepType": wavelength_state.wavelength_step_type.value, + "RebinMode": wavelength_state.rebin_type.value} wavelength_alg = create_child_algorithm(self, wavelength_name, **wavelength_options) wavelength_alg.execute() diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinderMassMethod.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinderMassMethod.py index b4493e25dec972d38344e9453918129fc8937353..d07bdb3797e7e44afe149917849a0e9bf5c5d8da 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinderMassMethod.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSBeamCentreFinderMassMethod.py @@ -9,18 +9,19 @@ """ Finds the beam centre for SANS""" from __future__ import (absolute_import, division, print_function) + from mantid.api import (DataProcessorAlgorithm, MatrixWorkspaceProperty, AlgorithmFactory, PropertyMode, Progress, IEventWorkspace) from mantid.kernel import (Direction, PropertyManagerProperty, StringListValidator) -from sans.algorithm_detail.scale_sans_workspace import scale_workspace from sans.algorithm_detail.crop_helper import get_component_name from sans.algorithm_detail.mask_sans_workspace import mask_workspace from sans.algorithm_detail.move_sans_instrument_component import move_component, MoveTypes +from sans.algorithm_detail.scale_sans_workspace import scale_workspace from sans.algorithm_detail.slice_sans_event import slice_sans_event from sans.common.constants import EMPTY_NAME +from sans.common.enums import (DetectorType) from sans.common.general_functions import create_child_algorithm, append_to_sans_file_tag from sans.state.state_base 
import create_deserialized_sans_state_from_property_manager -from sans.common.enums import (DetectorType, RangeStepType, RebinType) class SANSBeamCentreFinderMassMethod(DataProcessorAlgorithm): @@ -55,9 +56,9 @@ class SANSBeamCentreFinderMassMethod(DataProcessorAlgorithm): self.declareProperty('Iterations', 10, direction=Direction.Input) - allowed_detectors = StringListValidator([DetectorType.to_string(DetectorType.LAB), - DetectorType.to_string(DetectorType.HAB)]) - self.declareProperty("Component", DetectorType.to_string(DetectorType.LAB), + allowed_detectors = StringListValidator([DetectorType.LAB.value, + DetectorType.HAB.value]) + self.declareProperty("Component", DetectorType.LAB.value, validator=allowed_detectors, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") @@ -193,7 +194,7 @@ class SANSBeamCentreFinderMassMethod(DataProcessorAlgorithm): scatter_workspace = self.getProperty("SampleScatterWorkspace").value alg_name = "CropToComponent" - component_to_crop = DetectorType.from_string(component) + component_to_crop = DetectorType(component) component_to_crop = get_component_name(scatter_workspace, component_to_crop) crop_options = {"InputWorkspace": scatter_workspace, @@ -238,9 +239,8 @@ class SANSBeamCentreFinderMassMethod(DataProcessorAlgorithm): "WavelengthLow": wavelength_state.wavelength_low[0], "WavelengthHigh": wavelength_state.wavelength_high[0], "WavelengthStep": wavelength_state.wavelength_step, - "WavelengthStepType": RangeStepType.to_string( - wavelength_state.wavelength_step_type), - "RebinMode": RebinType.to_string(wavelength_state.rebin_type)} + "WavelengthStepType": wavelength_state.wavelength_step_type.value, + "RebinMode": wavelength_state.rebin_type.value} wavelength_alg = create_child_algorithm(self, wavelength_name, **wavelength_options) wavelength_alg.execute() diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSConvertToWavelengthAndRebin.py 
b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSConvertToWavelengthAndRebin.py index 964e631119e0644ba98f663b73e7502ce9700638..dd2a52a771cc241e5d080373691448f0e8c4e327 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSConvertToWavelengthAndRebin.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSConvertToWavelengthAndRebin.py @@ -9,13 +9,15 @@ """ SANSConvertToWavelengthAndRebin algorithm converts to wavelength units and performs a rebin.""" from __future__ import (absolute_import, division, print_function) -from mantid.kernel import (Direction, StringListValidator, Property) + +from mantid.api import (DistributedDataProcessorAlgorithm, MatrixWorkspaceProperty, AlgorithmFactory, PropertyMode, + Progress) from mantid.dataobjects import EventWorkspace -from mantid.api import (DistributedDataProcessorAlgorithm, MatrixWorkspaceProperty, AlgorithmFactory, PropertyMode, Progress) +from mantid.kernel import (Direction, StringListValidator, Property) from sans.common.constants import EMPTY_NAME +from sans.common.enums import (RebinType, RangeStepType) from sans.common.general_functions import (create_unmanaged_algorithm, append_to_sans_file_tag, get_input_workspace_as_copy_if_not_same_as_output_workspace) -from sans.common.enums import (RebinType, RangeStepType) class SANSConvertToWavelengthAndRebin(DistributedDataProcessorAlgorithm): @@ -39,16 +41,16 @@ class SANSConvertToWavelengthAndRebin(DistributedDataProcessorAlgorithm): doc='The step size of the wavelength binning.') # Step type - allowed_step_types = StringListValidator([RangeStepType.to_string(RangeStepType.Log), - RangeStepType.to_string(RangeStepType.Lin)]) - self.declareProperty('WavelengthStepType', RangeStepType.to_string(RangeStepType.Lin), + allowed_step_types = StringListValidator([RangeStepType.LOG.value, + RangeStepType.LIN.value]) + self.declareProperty('WavelengthStepType', RangeStepType.LIN.value, 
validator=allowed_step_types, direction=Direction.Input, doc='The step type for rebinning.') # Rebin type - allowed_rebin_methods = StringListValidator([RebinType.to_string(RebinType.Rebin), - RebinType.to_string(RebinType.InterpolatingRebin)]) - self.declareProperty("RebinMode", RebinType.to_string(RebinType.Rebin), + allowed_rebin_methods = StringListValidator([RebinType.REBIN.value, + RebinType.INTERPOLATING_REBIN.value]) + self.declareProperty("RebinMode", RebinType.REBIN.value, validator=allowed_rebin_methods, direction=Direction.Input, doc="The method which is to be applied to the rebinning.") @@ -66,9 +68,9 @@ class SANSConvertToWavelengthAndRebin(DistributedDataProcessorAlgorithm): workspace = self._convert_units_to_wavelength(workspace) # Get the rebin option - rebin_type = RebinType.from_string(self.getProperty("RebinMode").value) + rebin_type = RebinType(self.getProperty("RebinMode").value) rebin_string = self._get_rebin_string(workspace) - if rebin_type is RebinType.Rebin: + if rebin_type is RebinType.REBIN: rebin_options = {"InputWorkspace": workspace, "PreserveEvents": True, "Params": rebin_string} @@ -105,8 +107,8 @@ class SANSConvertToWavelengthAndRebin(DistributedDataProcessorAlgorithm): # Check the workspace workspace = self.getProperty("InputWorkspace").value - rebin_type = RebinType.from_string(self.getProperty("RebinMode").value) - if rebin_type is RebinType.InterpolatingRebin and isinstance(workspace, EventWorkspace): + rebin_type = RebinType(self.getProperty("RebinMode").value) + if rebin_type is RebinType.INTERPOLATING_REBIN and isinstance(workspace, EventWorkspace): errors.update({"RebinMode": "An interpolating rebin cannot be applied to an EventWorkspace."}) return errors @@ -134,13 +136,13 @@ class SANSConvertToWavelengthAndRebin(DistributedDataProcessorAlgorithm): wavelength_high = max(workspace.readX(0)) wavelength_step = self.getProperty("WavelengthStep").value - step_type = 
RangeStepType.from_string(self.getProperty("WavelengthStepType").value) - pre_factor = -1 if step_type == RangeStepType.Log else 1 + step_type = RangeStepType(self.getProperty("WavelengthStepType").value) + pre_factor = -1 if step_type == RangeStepType.LOG else 1 wavelength_step *= pre_factor return str(wavelength_low) + "," + str(wavelength_step) + "," + str(wavelength_high) def _perform_rebin(self, rebin_type, rebin_options, workspace): - rebin_name = "Rebin" if rebin_type is RebinType.Rebin else "InterpolatingRebin" + rebin_name = "Rebin" if rebin_type is RebinType.REBIN else "InterpolatingRebin" rebin_alg = create_unmanaged_algorithm(rebin_name, **rebin_options) rebin_alg.setPropertyValue("OutputWorkspace", EMPTY_NAME) rebin_alg.setProperty("OutputWorkspace", workspace) diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSCreateAdjustmentWorkspaces.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSCreateAdjustmentWorkspaces.py index 4bfd2b90871c26fa54e7b3a8ceae0eef8f7efc6c..c9351a18a22367e54891295af5480a271b363278 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSCreateAdjustmentWorkspaces.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSCreateAdjustmentWorkspaces.py @@ -11,11 +11,12 @@ """ from __future__ import (absolute_import, division, print_function) -from mantid.kernel import (Direction, PropertyManagerProperty, StringListValidator, CompositeValidator) + from mantid.api import (DistributedDataProcessorAlgorithm, MatrixWorkspaceProperty, AlgorithmFactory, PropertyMode, WorkspaceUnitValidator) -from sans.algorithm_detail.normalize_to_sans_monitor import normalize_to_monitor +from mantid.kernel import (Direction, PropertyManagerProperty, StringListValidator, CompositeValidator) from sans.algorithm_detail.calculate_sans_transmission import calculate_transmission +from sans.algorithm_detail.normalize_to_sans_monitor import 
normalize_to_monitor from sans.common.constants import EMPTY_NAME from sans.common.enums import (DataType, DetectorType) from sans.common.general_functions import create_unmanaged_algorithm @@ -58,16 +59,16 @@ class SANSCreateAdjustmentWorkspaces(DistributedDataProcessorAlgorithm): 'This used to verify the solid angle. The workspace is not modified, just inspected.') # The component - allowed_detector_types = StringListValidator([DetectorType.to_string(DetectorType.HAB), - DetectorType.to_string(DetectorType.LAB)]) - self.declareProperty("Component", DetectorType.to_string(DetectorType.LAB), + allowed_detector_types = StringListValidator([DetectorType.HAB.value, + DetectorType.LAB.value]) + self.declareProperty("Component", DetectorType.LAB.value, validator=allowed_detector_types, direction=Direction.Input, doc="The component of the instrument which is currently being investigated.") # The data type - allowed_data = StringListValidator([DataType.to_string(DataType.Sample), - DataType.to_string(DataType.Can)]) - self.declareProperty("DataType", DataType.to_string(DataType.Sample), + allowed_data = StringListValidator([DataType.SAMPLE.value, + DataType.CAN.value]) + self.declareProperty("DataType", DataType.SAMPLE.value, validator=allowed_data, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSLoad.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSLoad.py index 36f723d38ebe12cccc5ab6189490b8a23c8cfe70..bdbe213a5b89bafd55d84c54b4f0b7e8306c7937 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSLoad.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSLoad.py @@ -9,15 +9,15 @@ """ SANSLoad algorithm which handles loading SANS files""" from __future__ import (absolute_import, division, print_function) -from mantid.kernel import (Direction, 
PropertyManagerProperty, FloatArrayProperty) + from mantid.api import (ParallelDataProcessorAlgorithm, MatrixWorkspaceProperty, AlgorithmFactory, PropertyMode, Progress, WorkspaceProperty) +from mantid.kernel import (Direction, PropertyManagerProperty, FloatArrayProperty) +from sans.algorithm_detail.load_data import SANSLoadDataFactory from sans.algorithm_detail.move_sans_instrument_component import move_component, MoveTypes - -from sans.state.state_base import create_deserialized_sans_state_from_property_manager from sans.common.enums import SANSDataType -from sans.algorithm_detail.load_data import SANSLoadDataFactory +from sans.state.state_base import create_deserialized_sans_state_from_property_manager class SANSLoad(ParallelDataProcessorAlgorithm): @@ -263,25 +263,25 @@ class SANSLoad(ParallelDataProcessorAlgorithm): return errors def set_output_for_workspaces(self, workspace_type, workspaces): - if workspace_type is SANSDataType.SampleScatter: + if workspace_type is SANSDataType.SAMPLE_SCATTER: self.set_property_with_number_of_workspaces("SampleScatterWorkspace", workspaces) - elif workspace_type is SANSDataType.SampleTransmission: + elif workspace_type is SANSDataType.SAMPLE_TRANSMISSION: self.set_property_with_number_of_workspaces("SampleTransmissionWorkspace", workspaces) - elif workspace_type is SANSDataType.SampleDirect: + elif workspace_type is SANSDataType.SAMPLE_DIRECT: self.set_property_with_number_of_workspaces("SampleDirectWorkspace", workspaces) - elif workspace_type is SANSDataType.CanScatter: + elif workspace_type is SANSDataType.CAN_SCATTER: self.set_property_with_number_of_workspaces("CanScatterWorkspace", workspaces) - elif workspace_type is SANSDataType.CanTransmission: + elif workspace_type is SANSDataType.CAN_TRANSMISSION: self.set_property_with_number_of_workspaces("CanTransmissionWorkspace", workspaces) - elif workspace_type is SANSDataType.CanDirect: + elif workspace_type is SANSDataType.CAN_DIRECT: 
self.set_property_with_number_of_workspaces("CanDirectWorkspace", workspaces) else: raise RuntimeError("SANSLoad: Unknown data output workspace format: {0}".format(str(workspace_type))) def set_output_for_monitor_workspaces(self, workspace_type, workspaces): - if workspace_type is SANSDataType.SampleScatter: + if workspace_type is SANSDataType.SAMPLE_SCATTER: self.set_property("SampleScatterMonitorWorkspace", workspaces) - elif workspace_type is SANSDataType.CanScatter: + elif workspace_type is SANSDataType.CAN_SCATTER: self.set_property("CanScatterMonitorWorkspace", workspaces) else: raise RuntimeError("SANSLoad: Unknown data output workspace format: {0}".format(str(workspace_type))) @@ -328,8 +328,8 @@ class SANSLoad(ParallelDataProcessorAlgorithm): # The workspaces are stored in a dict: workspace_names (sample_scatter, etc) : ListOfWorkspaces for key, workspace_list in workspaces.items(): - is_trans = SANSDataType.to_string(key) in \ - ("SampleTransmission", "CanTransmission", "CanDirect", "SampleDirect") + is_trans = key in [SANSDataType.CAN_DIRECT, SANSDataType.CAN_TRANSMISSION, + SANSDataType.SAMPLE_TRANSMISSION, SANSDataType.SAMPLE_DIRECT] for workspace in workspace_list: move_component(component_name="", move_info=state.move, diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCore.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCore.py index d5a8f328c56d25a17dcbc862ed0e81ec58a21aa3..183727bfb5622e438a352fb768688f0898d813d9 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCore.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCore.py @@ -9,13 +9,13 @@ """ SANSReductionCore algorithm runs the sequence of reduction steps which are necessary to reduce a data set.""" from __future__ import (absolute_import, division, print_function) -from mantid.api import AlgorithmFactory, Progress +from 
SANSReductionCoreBase import SANSReductionCoreBase + +from mantid.api import AlgorithmFactory, Progress from sans.algorithm_detail.mask_workspace import mask_bins from sans.common.enums import DetectorType -from SANSReductionCoreBase import SANSReductionCoreBase - class SANSReductionCore(SANSReductionCoreBase): def category(self): @@ -96,7 +96,7 @@ class SANSReductionCore(SANSReductionCoreBase): # Convert and rebin the dummy workspace to get correct bin flags if use_dummy_workspace: dummy_mask_workspace = mask_bins(state.mask, dummy_mask_workspace, - DetectorType.from_string(component_as_string)) + DetectorType(component_as_string)) dummy_mask_workspace = self._convert_to_wavelength(state=state, workspace=dummy_mask_workspace) # -------------------------------------------------------------------------------------------------------------- diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCoreBase.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCoreBase.py index 4b9b5aea05b4c1d9f4c020f39776b64ce2dc34aa..8c7f43c7f3eb926238b1ba8d97bf310e6c7fbd5f 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCoreBase.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCoreBase.py @@ -10,20 +10,19 @@ from __future__ import (absolute_import, division, print_function) -from mantid.kernel import (Direction, PropertyManagerProperty, StringListValidator) from mantid.api import (DistributedDataProcessorAlgorithm, MatrixWorkspaceProperty, PropertyMode, IEventWorkspace) - +from mantid.kernel import (Direction, PropertyManagerProperty, StringListValidator) from sans.algorithm_detail.CreateSANSAdjustmentWorkspaces import CreateSANSAdjustmentWorkspaces from sans.algorithm_detail.convert_to_q import convert_workspace -from sans.algorithm_detail.scale_sans_workspace import scale_workspace from sans.algorithm_detail.crop_helper 
import get_component_name from sans.algorithm_detail.mask_sans_workspace import mask_workspace from sans.algorithm_detail.move_sans_instrument_component import move_component, MoveTypes +from sans.algorithm_detail.scale_sans_workspace import scale_workspace from sans.algorithm_detail.slice_sans_event import slice_sans_event -from sans.state.state_base import create_deserialized_sans_state_from_property_manager from sans.common.constants import EMPTY_NAME +from sans.common.enums import (DetectorType, DataType) from sans.common.general_functions import (create_child_algorithm, append_to_sans_file_tag) -from sans.common.enums import (DetectorType, DataType, RangeStepType, RebinType) +from sans.state.state_base import create_deserialized_sans_state_from_property_manager class SANSReductionCoreBase(DistributedDataProcessorAlgorithm): @@ -59,16 +58,16 @@ class SANSReductionCoreBase(DistributedDataProcessorAlgorithm): self.setPropertyGroup("DirectWorkspace", 'Data') # The component - allowed_detectors = StringListValidator([DetectorType.to_string(DetectorType.LAB), - DetectorType.to_string(DetectorType.HAB)]) - self.declareProperty("Component", DetectorType.to_string(DetectorType.LAB), + allowed_detectors = StringListValidator([DetectorType.LAB.value, + DetectorType.HAB.value]) + self.declareProperty("Component", DetectorType.LAB.value, validator=allowed_detectors, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") # The data type - allowed_data = StringListValidator([DataType.to_string(DataType.Sample), - DataType.to_string(DataType.Can)]) - self.declareProperty("DataType", DataType.to_string(DataType.Sample), + allowed_data = StringListValidator([DataType.SAMPLE.value, + DataType.CAN.value]) + self.declareProperty("DataType", DataType.SAMPLE.value, validator=allowed_data, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") @@ -101,7 +100,7 @@ class 
SANSReductionCoreBase(DistributedDataProcessorAlgorithm): scatter_workspace = self.getProperty("ScatterWorkspace").value alg_name = "CropToComponent" - component_to_crop = DetectorType.from_string(component) + component_to_crop = DetectorType(component) component_to_crop = get_component_name(scatter_workspace, component_to_crop) crop_options = {"InputWorkspace": scatter_workspace, @@ -148,9 +147,8 @@ class SANSReductionCoreBase(DistributedDataProcessorAlgorithm): "WavelengthLow": wavelength_state.wavelength_low[0], "WavelengthHigh": wavelength_state.wavelength_high[0], "WavelengthStep": wavelength_state.wavelength_step, - "WavelengthStepType": RangeStepType.to_string( - wavelength_state.wavelength_step_type), - "RebinMode": RebinType.to_string(wavelength_state.rebin_type)} + "WavelengthStepType": wavelength_state.wavelength_step_type.value, + "RebinMode": wavelength_state.rebin_type.value} wavelength_alg = create_child_algorithm(self, wavelength_name, **wavelength_options) wavelength_alg.execute() diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCoreEventSlice.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCoreEventSlice.py index c900da9b6050b2d46a2072a93e200ce5bc6f3019..77f101549b7fab0d3cf183f697b9c58b206dc925 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCoreEventSlice.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCoreEventSlice.py @@ -11,13 +11,13 @@ for which data must be event sliced. 
These steps are: slicing, adjustment, conve from __future__ import (absolute_import, division, print_function) +from SANSReductionCoreBase import SANSReductionCoreBase + from mantid.api import (MatrixWorkspaceProperty, AlgorithmFactory, PropertyMode, Progress) from mantid.kernel import (Direction, PropertyManagerProperty, StringListValidator) from sans.common.enums import (DetectorType, DataType) -from SANSReductionCoreBase import SANSReductionCoreBase - class SANSReductionCoreEventSlice(SANSReductionCoreBase): def category(self): @@ -62,16 +62,16 @@ class SANSReductionCoreEventSlice(SANSReductionCoreBase): self.setPropertyGroup("TransmissionWorkspace", 'Data') # The component - allowed_detectors = StringListValidator([DetectorType.to_string(DetectorType.LAB), - DetectorType.to_string(DetectorType.HAB)]) - self.declareProperty("Component", DetectorType.to_string(DetectorType.LAB), + allowed_detectors = StringListValidator([DetectorType.LAB.value, + DetectorType.HAB.value]) + self.declareProperty("Component", DetectorType.LAB.value, validator=allowed_detectors, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") # The data type - allowed_data = StringListValidator([DataType.to_string(DataType.Sample), - DataType.to_string(DataType.Can)]) - self.declareProperty("DataType", DataType.to_string(DataType.Sample), + allowed_data = StringListValidator([DataType.SAMPLE.value, + DataType.CAN.value]) + self.declareProperty("DataType", DataType.SAMPLE.value, validator=allowed_data, direction=Direction.Input, doc="The component of the instrument which is to be reduced.") diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCorePreprocess.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCorePreprocess.py index 2f805bb153f62b77089484dc4ca6ecf7872a9e2d..4bee74ea586acf296e5e872297678ce2f42fe3cc 100644 --- 
a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCorePreprocess.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSReductionCorePreprocess.py @@ -11,13 +11,13 @@ which can be performed before event slicing.""" from __future__ import (absolute_import, division, print_function) +from SANSReductionCoreBase import SANSReductionCoreBase + from mantid.api import MatrixWorkspaceProperty, AlgorithmFactory, Progress from mantid.kernel import Direction from sans.algorithm_detail.mask_workspace import mask_bins from sans.common.enums import DetectorType -from SANSReductionCoreBase import SANSReductionCoreBase - class SANSReductionCorePreprocess(SANSReductionCoreBase): def category(self): @@ -98,7 +98,7 @@ class SANSReductionCorePreprocess(SANSReductionCoreBase): # Convert and rebin the dummy workspace to get correct bin flags if use_dummy_workspace: dummy_mask_workspace = mask_bins(state.mask, dummy_mask_workspace, - DetectorType.from_string(component_as_string)) + DetectorType(component_as_string)) dummy_mask_workspace = self._convert_to_wavelength(state=state, workspace=dummy_mask_workspace) # -------------------------------------------------------------------------------------------------------------- diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSave.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSave.py index 74333195a723e9f547c04adc10ab8f85b700f19f..8ad85228dd871f91f383b8b5f4f7a22fd6149b14 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSave.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSave.py @@ -11,11 +11,12 @@ """ SANSSave algorithm performs saving of SANS reduction data.""" from __future__ import (absolute_import, division, print_function) -from mantid.kernel import (Direction) + from mantid.api import (DataProcessorAlgorithm, MatrixWorkspaceProperty, 
AlgorithmFactory, PropertyMode, FileProperty, FileAction, Progress) -from sans.common.enums import (SaveType) +from mantid.kernel import (Direction) from sans.algorithm_detail.save_workspace import (save_to_file, get_zero_error_free_workspace, file_format_with_append) +from sans.common.enums import (SaveType) class SANSSave(DataProcessorAlgorithm): @@ -115,7 +116,7 @@ class SANSSave(DataProcessorAlgorithm): progress = Progress(self, start=0.0, end=1.0, nreports=len(file_formats) + 1) for file_format in file_formats: - progress_message = "Saving to {0}.".format(SaveType.to_string(file_format.file_format)) + progress_message = "Saving to {0}.".format(file_format.file_format.value) progress.report(progress_message) progress.report(progress_message) save_to_file(workspace, file_format, file_name, transmission_workspaces, additional_run_numbers) @@ -151,10 +152,10 @@ class SANSSave(DataProcessorAlgorithm): def _get_file_formats(self): file_types = [] - self._check_file_types(file_types, "Nexus", SaveType.Nexus) - self._check_file_types(file_types, "CanSAS", SaveType.CanSAS) - self._check_file_types(file_types, "NXcanSAS", SaveType.NXcanSAS) - self._check_file_types(file_types, "NistQxy", SaveType.NistQxy) + self._check_file_types(file_types, "Nexus", SaveType.NEXUS) + self._check_file_types(file_types, "CanSAS", SaveType.CAN_SAS) + self._check_file_types(file_types, "NXcanSAS", SaveType.NX_CAN_SAS) + self._check_file_types(file_types, "NistQxy", SaveType.NIST_QXY) self._check_file_types(file_types, "RKH", SaveType.RKH) self._check_file_types(file_types, "CSV", SaveType.CSV) @@ -163,20 +164,20 @@ class SANSSave(DataProcessorAlgorithm): file_formats = [] # SaveNexusProcessed clashes with SaveNXcanSAS, but we only alter NXcanSAS data - self.add_file_format_with_appended_name_requirement(file_formats, SaveType.Nexus, file_types, []) + self.add_file_format_with_appended_name_requirement(file_formats, SaveType.NEXUS, file_types, []) # SaveNXcanSAS clashes with 
SaveNexusProcessed - self.add_file_format_with_appended_name_requirement(file_formats, SaveType.NXcanSAS, file_types, + self.add_file_format_with_appended_name_requirement(file_formats, SaveType.NX_CAN_SAS, file_types, []) # SaveNISTDAT clashes with SaveRKH, both can save to .dat - self.add_file_format_with_appended_name_requirement(file_formats, SaveType.NistQxy, file_types, [SaveType.RKH]) + self.add_file_format_with_appended_name_requirement(file_formats, SaveType.NIST_QXY, file_types, [SaveType.RKH]) # SaveRKH clashes with SaveNISTDAT, but we alter SaveNISTDAT self.add_file_format_with_appended_name_requirement(file_formats, SaveType.RKH, file_types, []) # SaveCanSAS1D does not clash with anyone - self.add_file_format_with_appended_name_requirement(file_formats, SaveType.CanSAS, file_types, []) + self.add_file_format_with_appended_name_requirement(file_formats, SaveType.CAN_SAS, file_types, []) # SaveCSV does not clash with anyone self.add_file_format_with_appended_name_requirement(file_formats, SaveType.CSV, file_types, []) diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReduction.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReduction.py index 2a8a680f654f791af065023ce27e0c265a2cbef3..f1c94d6c6dc9cae4f463cf4b907920f46ba0acee 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReduction.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReduction.py @@ -10,15 +10,15 @@ from __future__ import (absolute_import, division, print_function) +from SANSSingleReductionBase import SANSSingleReductionBase + from mantid.api import (MatrixWorkspaceProperty, AlgorithmFactory, PropertyMode) from mantid.kernel import (Direction, Property) from mantid.simpleapi import CloneWorkspace from sans.algorithm_detail.single_execution import (run_core_reduction, run_optimized_for_can) -from sans.common.enums import 
(ReductionMode, DataType, ISISReductionMode, FitType) +from sans.common.enums import (DataType, ReductionMode, FitType) from sans.common.general_functions import does_can_workspace_exist_on_ads -from SANSSingleReductionBase import SANSSingleReductionBase - class SANSSingleReduction(SANSSingleReductionBase): def category(self): @@ -142,7 +142,7 @@ class SANSSingleReduction(SANSSingleReductionBase): progress.report("Running a single reduction ...") # We want to make use of optimizations here. If a can workspace has already been reduced with the same can # settings and is stored in the ADS, then we should use it (provided the user has optimizations enabled). - if use_optimizations and reduction_setting_bundle.data_type is DataType.Can: + if use_optimizations and reduction_setting_bundle.data_type is DataType.CAN: output_bundle, output_parts_bundle, \ output_transmission_bundle = run_optimized_for_can(reduction_alg, reduction_setting_bundle) else: @@ -173,11 +173,11 @@ class SANSSingleReduction(SANSSingleReductionBase): # In an MPI reduction output_workspace is produced on the master rank, skip others. if output_workspace is None: continue - if reduction_mode is ReductionMode.Merged: + if reduction_mode is ReductionMode.MERGED: self.setProperty("OutputWorkspaceMerged", output_workspace) - elif reduction_mode is ISISReductionMode.LAB: + elif reduction_mode is ReductionMode.LAB: self.setProperty("OutputWorkspaceLAB", output_workspace) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: self.setProperty("OutputWorkspaceHAB", output_workspace) else: raise RuntimeError("SANSSingleReduction: Cannot set the output workspace. 
The selected reduction " @@ -195,15 +195,15 @@ class SANSSingleReduction(SANSSingleReductionBase): # Find the LAB Can and HAB Can entries if they exist output_bundles = output_bundles[0] for output_bundle in output_bundles: - if output_bundle.data_type is DataType.Can: + if output_bundle.data_type is DataType.CAN: reduction_mode = output_bundle.reduction_mode output_workspace = output_bundle.output_workspace # Make sure that the output workspace is not None which can be the case if there has never been a # can set for the reduction. if output_workspace is not None and not does_can_workspace_exist_on_ads(output_workspace): - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: self.setProperty("OutputWorkspaceLABCan", output_workspace) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: self.setProperty("OutputWorkspaceHABCan", output_bundle.output_workspace) else: raise RuntimeError("SANSSingleReduction: The reduction mode {0} should not" @@ -223,7 +223,7 @@ class SANSSingleReduction(SANSSingleReductionBase): # Find the partial output bundles fo LAB Can and HAB Can if they exist output_bundles_part = output_bundles_parts[0] for output_bundle_part in output_bundles_part: - if output_bundle_part.data_type is DataType.Can: + if output_bundle_part.data_type is DataType.CAN: reduction_mode = output_bundle_part.reduction_mode output_workspace_count = output_bundle_part.output_workspace_count output_workspace_norm = output_bundle_part.output_workspace_norm @@ -232,10 +232,10 @@ class SANSSingleReduction(SANSSingleReductionBase): if output_workspace_norm is not None and output_workspace_count is not None and \ not does_can_workspace_exist_on_ads(output_workspace_norm) and \ not does_can_workspace_exist_on_ads(output_workspace_count): - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: self.setProperty("OutputWorkspaceLABCanCount", output_workspace_count) 
self.setProperty("OutputWorkspaceLABCanNorm", output_workspace_norm) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: self.setProperty("OutputWorkspaceHABCanCount", output_workspace_count) self.setProperty("OutputWorkspaceHABCanNorm", output_workspace_norm) else: @@ -255,27 +255,27 @@ class SANSSingleReduction(SANSSingleReductionBase): """ output_bundles = output_bundles[0] for output_bundle in output_bundles: - if output_bundle.data_type is DataType.Can: + if output_bundle.data_type is DataType.CAN: reduction_mode = output_bundle.reduction_mode output_workspace = output_bundle.output_workspace if output_workspace is not None and not does_can_workspace_exist_on_ads(output_workspace): - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: self.setProperty("OutputWorkspaceLABCan", output_workspace) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: self.setProperty("OutputWorkspaceHABCan", output_bundle.output_workspace) else: raise RuntimeError("SANSSingleReduction: The reduction mode {0} should not" " be set with a can.".format(reduction_mode)) - elif output_bundle.data_type is DataType.Sample: + elif output_bundle.data_type is DataType.SAMPLE: reduction_mode = output_bundle.reduction_mode output_workspace = output_bundle.output_workspace if output_workspace is not None: - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: self.setProperty("OutputWorkspaceLABSample", output_workspace) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: self.setProperty("OutputWorkspaceHABSample", output_bundle.output_workspace) else: raise RuntimeError("SANSSingleReduction: The reduction mode {0} should not" @@ -283,10 +283,10 @@ class SANSSingleReduction(SANSSingleReductionBase): def set_transmission_workspaces_on_output(self, transmission_bundles, fit_state): for transmission_bundle in 
transmission_bundles: - fit_performed = fit_state[DataType.to_string(transmission_bundle.data_type)].fit_type != FitType.NoFit + fit_performed = fit_state[transmission_bundle.data_type].fit_type != FitType.NO_FIT calculated_transmission_workspace = transmission_bundle.calculated_transmission_workspace unfitted_transmission_workspace = transmission_bundle.unfitted_transmission_workspace - if transmission_bundle.data_type is DataType.Can: + if transmission_bundle.data_type is DataType.CAN: if does_can_workspace_exist_on_ads(calculated_transmission_workspace): # The workspace is cloned here because the transmission runs are diagnostic output so even though # the values already exist they need to be labelled seperately for each reduction. @@ -296,7 +296,7 @@ class SANSSingleReduction(SANSSingleReductionBase): if fit_performed: self.setProperty("OutputWorkspaceCalculatedTransmissionCan", calculated_transmission_workspace) self.setProperty("OutputWorkspaceUnfittedTransmissionCan", unfitted_transmission_workspace) - elif transmission_bundle.data_type is DataType.Sample: + elif transmission_bundle.data_type is DataType.SAMPLE: if fit_performed: self.setProperty("OutputWorkspaceCalculatedTransmission", calculated_transmission_workspace) self.setProperty("OutputWorkspaceUnfittedTransmission", unfitted_transmission_workspace) diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReduction2.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReduction2.py index 8b566a12fcd7a869567cb92c682ecc0fea0d8c6f..aa6b8222e74596aaa755259514bee08c555d09ef 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReduction2.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReduction2.py @@ -12,20 +12,20 @@ from __future__ import (absolute_import, division, print_function) from copy import deepcopy +from SANSSingleReductionBase import 
SANSSingleReductionBase + from mantid.api import (AlgorithmFactory, AnalysisDataService, MatrixWorkspaceProperty, PropertyMode, WorkspaceGroup, WorkspaceGroupProperty) -from mantid.simpleapi import CloneWorkspace from mantid.kernel import Direction +from mantid.simpleapi import CloneWorkspace from sans.algorithm_detail.bundles import EventSliceSettingBundle from sans.algorithm_detail.single_execution import (run_initial_event_slice_reduction, run_core_event_slice_reduction, get_reduction_mode_vs_output_bundles, run_optimized_for_can) -from sans.common.enums import (ReductionMode, DataType, ISISReductionMode, FitType) +from sans.common.enums import (DataType, ReductionMode, FitType) from sans.common.general_functions import (create_child_algorithm, does_can_workspace_exist_on_ads, get_transmission_output_name, get_output_name) -from SANSSingleReductionBase import SANSSingleReductionBase - class SANSSingleReduction(SANSSingleReductionBase): def category(self): @@ -168,7 +168,7 @@ class SANSSingleReduction(SANSSingleReductionBase): # If a can workspace has already been reduced with the same can # settings and is stored in the ADS, then we should use it # (provided the user has optimizations enabled). 
- if use_optimizations and slice_bundle.data_type is DataType.Can: + if use_optimizations and slice_bundle.data_type is DataType.CAN: output_bundle, output_parts_bundle, \ output_transmission_bundle = run_optimized_for_can(reduction_alg, slice_bundle, @@ -271,11 +271,11 @@ class SANSSingleReduction(SANSSingleReductionBase): continue else: AnalysisDataService.addOrReplace(output_name, output_workspace) - if reduction_mode is ReductionMode.Merged: + if reduction_mode is ReductionMode.MERGED: workspace_group_merged.addWorkspace(output_workspace) - elif reduction_mode is ISISReductionMode.LAB: + elif reduction_mode is ReductionMode.LAB: workspace_group_lab.addWorkspace(output_workspace) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: workspace_group_hab.addWorkspace(output_workspace) else: raise RuntimeError("SANSSingleReduction: Cannot set the output workspace. " @@ -301,7 +301,7 @@ class SANSSingleReduction(SANSSingleReductionBase): # Find the LAB Can and HAB Can entries if they exist for component_bundle in output_bundles: for output_bundle in component_bundle: - if output_bundle.data_type is DataType.Can: + if output_bundle.data_type is DataType.CAN: reduction_mode = output_bundle.reduction_mode output_workspace = output_bundle.output_workspace # Make sure that the output workspace is not None which can be the case if there has never been a @@ -310,9 +310,9 @@ class SANSSingleReduction(SANSSingleReductionBase): name = self._get_output_workspace_name(output_bundle.state, output_bundle.reduction_mode, can=True) AnalysisDataService.addOrReplace(name, output_workspace) - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: workspace_group_lab_can.addWorkspace(output_workspace) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: workspace_group_hab_can.addWorkspace(output_workspace) else: raise RuntimeError("SANSSingleReduction: The reduction mode {0} 
should not" @@ -341,7 +341,7 @@ class SANSSingleReduction(SANSSingleReductionBase): # Find the partial output bundles fo LAB Can and HAB Can if they exist for event_slice_bundles in output_bundles_parts: for output_bundle_part in event_slice_bundles: - if output_bundle_part.data_type is DataType.Can: + if output_bundle_part.data_type is DataType.CAN: reduction_mode = output_bundle_part.reduction_mode output_workspace_count = output_bundle_part.output_workspace_count output_workspace_norm = output_bundle_part.output_workspace_norm @@ -355,10 +355,10 @@ class SANSSingleReduction(SANSSingleReductionBase): norm_name = name + "_hab_can_norm" AnalysisDataService.addOrReplace(count_name, output_workspace_count) AnalysisDataService.addOrReplace(norm_name, output_workspace_norm) - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: workspace_group_lab_can_count.addWorkspace(output_workspace_count) workspace_group_lab_can_norm.addWorkspace(output_workspace_norm) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: workspace_group_hab_can_count.addWorkspace(output_workspace_count) workspace_group_hab_can_norm.addWorkspace(output_workspace_norm) else: @@ -391,7 +391,7 @@ class SANSSingleReduction(SANSSingleReductionBase): for component_bundle in output_bundles: for output_bundle in component_bundle: - if output_bundle.data_type is DataType.Can: + if output_bundle.data_type is DataType.CAN: reduction_mode = output_bundle.reduction_mode output_workspace = output_bundle.output_workspace @@ -399,14 +399,14 @@ class SANSSingleReduction(SANSSingleReductionBase): can_name = self._get_output_workspace_name(output_bundle.state, output_bundle.reduction_mode, can=True) AnalysisDataService.addOrReplace(can_name, output_workspace) - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: workspace_group_lab_can.addWorkspace(output_workspace) - elif reduction_mode is ISISReductionMode.HAB: 
+ elif reduction_mode is ReductionMode.HAB: workspace_group_hab_can.addWorkspace(output_workspace) else: raise RuntimeError("SANSSingleReduction: The reduction mode {0} should not" " be set with a can.".format(reduction_mode)) - elif output_bundle.data_type is DataType.Sample: + elif output_bundle.data_type is DataType.SAMPLE: reduction_mode = output_bundle.reduction_mode output_workspace = output_bundle.output_workspace @@ -414,9 +414,9 @@ class SANSSingleReduction(SANSSingleReductionBase): sample_name = self._get_output_workspace_name(output_bundle.state, output_bundle.reduction_mode, sample=True) AnalysisDataService.addOrReplace(sample_name, output_workspace) - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: workspace_group_lab_sample.addWorkspace(output_workspace) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: workspace_group_hab_sample.addWorkspace(output_workspace) else: raise RuntimeError("SANSSingleReduction: The reduction mode {0} should not" @@ -433,10 +433,10 @@ class SANSSingleReduction(SANSSingleReductionBase): def set_transmission_workspaces_on_output(self, transmission_bundles, fit_state): for transmission_bundle in transmission_bundles: - fit_performed = fit_state[DataType.to_string(transmission_bundle.data_type)].fit_type != FitType.NoFit + fit_performed = fit_state[transmission_bundle.data_type].fit_type != FitType.NO_FIT calculated_transmission_workspace = transmission_bundle.calculated_transmission_workspace unfitted_transmission_workspace = transmission_bundle.unfitted_transmission_workspace - if transmission_bundle.data_type is DataType.Can: + if transmission_bundle.data_type is DataType.CAN: if does_can_workspace_exist_on_ads(calculated_transmission_workspace): # The workspace is cloned here because the transmission runs are diagnostic output so even though # the values already exist they need to be labelled seperately for each reduction. 
@@ -447,7 +447,7 @@ class SANSSingleReduction(SANSSingleReductionBase): if fit_performed: self.setProperty("OutputWorkspaceCalculatedTransmissionCan", calculated_transmission_workspace) self.setProperty("OutputWorkspaceUnfittedTransmissionCan", unfitted_transmission_workspace) - elif transmission_bundle.data_type is DataType.Sample: + elif transmission_bundle.data_type is DataType.SAMPLE: if fit_performed: self.setProperty("OutputWorkspaceCalculatedTransmission", calculated_transmission_workspace) self.setProperty("OutputWorkspaceUnfittedTransmission", unfitted_transmission_workspace) @@ -473,7 +473,7 @@ class SANSSingleReduction(SANSSingleReductionBase): # Find the sample in the data collection state, reduction_mode = next(((output_bundle.state, output_bundle.reduction_mode) for output_bundle in output_bundles - if output_bundle.data_type == DataType.Sample), None) + if output_bundle.data_type == DataType.SAMPLE), None) # Get the workspace name name = self._get_output_workspace_name(state, reduction_mode=reduction_mode) @@ -488,7 +488,7 @@ class SANSSingleReduction(SANSSingleReductionBase): :return: a workspace name """ state = output_parts_bundle[0].state - return self._get_output_workspace_name(state, reduction_mode=ReductionMode.Merged) + return self._get_output_workspace_name(state, reduction_mode=ReductionMode.MERGED) def _get_output_workspace_name(self, state, reduction_mode=None, data_type=None, can=False, sample=False, transmission=False, fitted=False): @@ -496,7 +496,7 @@ class SANSSingleReduction(SANSSingleReductionBase): Get the output names for the sliced workspaces (within the group workspaces, which are already named). :param state: a SANS state object - :param reduction_mode: an optional ISISReductionMode enum: "HAB", "LAB", "Merged", or "All" + :param reduction_mode: an optional ReductionMode enum: "HAB", "LAB", "Merged", or "All" :param data_type: an optional DataType enum: "Sample" or "Can" :param can: optional bool. 
If true then creating name for a can workspace :param sample: optional bool. If true then creating name for a sample workspace. Sample and can cannot both be @@ -512,14 +512,14 @@ class SANSSingleReduction(SANSSingleReductionBase): if not transmission: _suffix = "" if can: - if reduction_mode == ISISReductionMode.HAB: + if reduction_mode == ReductionMode.HAB: _suffix = "_hab_can" - elif reduction_mode == ISISReductionMode.LAB: + elif reduction_mode == ReductionMode.LAB: _suffix = "_lab_can" elif sample: - if reduction_mode == ISISReductionMode.HAB: + if reduction_mode == ReductionMode.HAB: _suffix = "_hab_sample" - elif reduction_mode == ISISReductionMode.LAB: + elif reduction_mode == ReductionMode.LAB: _suffix = "_lab_sample" return get_output_name(state, reduction_mode, True, suffix=_suffix, multi_reduction_type=_multi)[0] else: diff --git a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReductionBase.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReductionBase.py index 1867b034e48b6d408291440014dc371ece5c6e45..e2bb4ba5c5a8a7289a8faad319264cc3c1bbd1ad 100644 --- a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReductionBase.py +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/SANS/SANSSingleReductionBase.py @@ -19,7 +19,7 @@ from sans.algorithm_detail.bundles import ReductionSettingBundle from sans.algorithm_detail.single_execution import (get_final_output_workspaces, get_merge_bundle_for_merge_request) from sans.algorithm_detail.strip_end_nans_and_infs import strip_end_nans -from sans.common.enums import (ReductionMode, DataType, ISISReductionMode) +from sans.common.enums import (DataType, ReductionMode) from sans.common.general_functions import create_child_algorithm from sans.state.state_base import create_deserialized_sans_state_from_property_manager @@ -147,22 +147,22 @@ class SANSSingleReductionBase(DistributedDataProcessorAlgorithm): # 
Merge if required with stitching etc. scale_factors = [] shift_factors = [] - if overall_reduction_mode is ReductionMode.Merged: + if overall_reduction_mode is ReductionMode.MERGED: progress.report("Merging reductions ...") for i, event_slice_part_bundle in enumerate(output_parts_bundles): merge_bundle = get_merge_bundle_for_merge_request(event_slice_part_bundle, self) scale_factors.append(merge_bundle.scale) shift_factors.append(merge_bundle.shift) - reduction_mode_vs_output_workspaces[ReductionMode.Merged].append(merge_bundle.merged_workspace) + reduction_mode_vs_output_workspaces[ReductionMode.MERGED].append(merge_bundle.merged_workspace) merged_name = self._get_merged_workspace_name(event_slice_part_bundle) - reduction_mode_vs_workspace_names[ReductionMode.Merged].append(merged_name) + reduction_mode_vs_workspace_names[ReductionMode.MERGED].append(merged_name) scaled_HAB = strip_end_nans(merge_bundle.scaled_hab_workspace, self) - reduction_mode_vs_output_workspaces[ISISReductionMode.HAB].append(scaled_HAB) + reduction_mode_vs_output_workspaces[ReductionMode.HAB].append(scaled_HAB) # Get HAB workspace name state = event_slice_part_bundle[0].state - hab_name = self._get_output_workspace_name(state, reduction_mode=ISISReductionMode.HAB) - reduction_mode_vs_workspace_names[ISISReductionMode.HAB].append(hab_name) + hab_name = self._get_output_workspace_name(state, reduction_mode=ReductionMode.HAB) + reduction_mode_vs_workspace_names[ReductionMode.HAB].append(hab_name) self.set_shift_and_scale_output(scale_factors, shift_factors) @@ -246,15 +246,15 @@ class SANSSingleReductionBase(DistributedDataProcessorAlgorithm): def _get_reduction_setting_bundles(self, state, reduction_mode): # We need to output the parts if we request a merged reduction mode. This is necessary for stitching later on. 
- output_parts = reduction_mode is ReductionMode.Merged + output_parts = reduction_mode is ReductionMode.MERGED # If the reduction mode is MERGED, then we need to make sure that all reductions for that selection # are executed, i.e. we need to split it up - if reduction_mode is ReductionMode.Merged: + if reduction_mode is ReductionMode.MERGED: # If we are dealing with a merged reduction we need to know which detectors should be merged. reduction_info = state.reduction reduction_modes = reduction_info.get_merge_strategy() - elif reduction_mode is ReductionMode.All: + elif reduction_mode is ReductionMode.ALL: reduction_info = state.reduction reduction_modes = reduction_info.get_all_reduction_modes() else: @@ -262,7 +262,7 @@ class SANSSingleReductionBase(DistributedDataProcessorAlgorithm): # Create the Scatter information sample_info = self._create_reduction_bundles_for_data_type(state=state, - data_type=DataType.Sample, + data_type=DataType.SAMPLE, reduction_modes=reduction_modes, output_parts=output_parts, scatter_name="SampleScatterWorkspace", @@ -272,7 +272,7 @@ class SANSSingleReductionBase(DistributedDataProcessorAlgorithm): # Create the Can information can_info = self._create_reduction_bundles_for_data_type(state=state, - data_type=DataType.Can, + data_type=DataType.CAN, reduction_modes=reduction_modes, output_parts=output_parts, scatter_name="CanScatterWorkspace", @@ -312,6 +312,6 @@ class SANSSingleReductionBase(DistributedDataProcessorAlgorithm): return reduction_setting_bundles def _get_progress(self, number_of_reductions, overall_reduction_mode): - number_from_merge = 1 if overall_reduction_mode is ReductionMode.Merged else 0 + number_from_merge = 1 if overall_reduction_mode is ReductionMode.MERGED else 0 number_of_progress_reports = number_of_reductions + number_from_merge + 1 return Progress(self, start=0.0, end=1.0, nreports=number_of_progress_reports) diff --git 
a/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/TotScatCalculateSelfScattering.py b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/TotScatCalculateSelfScattering.py new file mode 100644 index 0000000000000000000000000000000000000000..95b7f2e907ce2161f47c5b0d476d3f17ff09b0b6 --- /dev/null +++ b/Framework/PythonInterface/plugins/algorithms/WorkflowAlgorithms/TotScatCalculateSelfScattering.py @@ -0,0 +1,111 @@ +# Mantid Repository : https://github.com/mantidproject/mantid +# +# Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +# NScD Oak Ridge National Laboratory, European Spallation Source +# & Institut Laue - Langevin +# SPDX - License - Identifier: GPL - 3.0 + +from __future__ import (absolute_import, division, print_function) + +from mantid.simpleapi import (CalculatePlaczekSelfScattering, ConvertToDistribution, ConvertUnits, CreateWorkspace, + DeleteWorkspace, DiffractionFocussing, Divide, ExtractSpectra, FitIncidentSpectrum, + LoadCalFile, SetSample) +from mantid.api import (AlgorithmFactory, DataProcessorAlgorithm, FileAction, FileProperty, WorkspaceProperty) +from mantid.kernel import Direction +import numpy as np + + +class TotScatCalculateSelfScattering(DataProcessorAlgorithm): + + def name(self): + return 'TotScatCalculateSelfScattering CalculateSelfScatteringCorrection' + + def category(self): + return "Workflow\\Diffraction" + + def seeAlso(self): + return [] + + def summary(self): + return "Calculates the self scattering correction factor for total scattering data." 
+ + def checkGroups(self): + return False + + def version(self): + return 1 + + def PyInit(self): + self.declareProperty(WorkspaceProperty('InputWorkspace', '', direction=Direction.Input), + doc='Raw workspace.') + self.declareProperty(WorkspaceProperty('OutputWorkspace', '', direction=Direction.Output), + doc='Focused corrected workspace.') + self.declareProperty(FileProperty("CalFileName", "", + direction=Direction.Input, + action=FileAction.Load), + doc='File path for the instrument calibration file.') + self.declareProperty(name='SampleGeometry', defaultValue={}, + doc='Geometry of the sample material.') + self.declareProperty(name='SampleMaterial', defaultValue={}, + doc='Chemical formula for the sample material.') + + def PyExec(self): + raw_ws = self.getProperty('InputWorkspace').value + sample_geometry = self.getPropertyValue('SampleGeometry') + sample_material = self.getPropertyValue('SampleMaterial') + cal_file_name = self.getPropertyValue('CalFileName') + SetSample(InputWorkspace=raw_ws, + Geometry=sample_geometry, + Material=sample_material) + # find the closest monitor to the sample for incident spectrum + raw_spec_info = raw_ws.spectrumInfo() + incident_index = None + for i in range(raw_spec_info.size()): + if raw_spec_info.isMonitor(i): + l2 = raw_spec_info.position(i)[2] + if not incident_index: + incident_index = i + else: + if raw_spec_info.position(incident_index)[2] < l2 < 0: + incident_index = i + monitor = ExtractSpectra(InputWorkspace=raw_ws, WorkspaceIndexList=[incident_index]) + monitor = ConvertUnits(InputWorkspace=monitor, Target="Wavelength") + x_data = monitor.dataX(0) + min_x = np.min(x_data) + max_x = np.max(x_data) + width_x = (max_x - min_x) / x_data.size + fit_spectra = FitIncidentSpectrum(InputWorkspace=monitor, + BinningForCalc=[min_x, 1 * width_x, max_x], + BinningForFit=[min_x, 10 * width_x, max_x], + FitSpectrumWith="CubicSpline") + self_scattering_correction = CalculatePlaczekSelfScattering(InputWorkspace=raw_ws, + 
IncidentSpecta=fit_spectra) + cal_workspace = LoadCalFile(InputWorkspace=self_scattering_correction, + CalFileName=cal_file_name, + Workspacename='cal_workspace', + MakeOffsetsWorkspace=False, + MakeMaskWorkspace=False) + self_scattering_correction = DiffractionFocussing(InputWorkspace=self_scattering_correction, + GroupingFilename=cal_file_name) + + n_pixel = np.zeros(self_scattering_correction.getNumberHistograms()) + + for i in range(cal_workspace.getNumberHistograms()): + grouping = cal_workspace.dataY(i) + if grouping[0] > 0: + n_pixel[int(grouping[0] - 1)] += 1 + correction_ws = CreateWorkspace(DataY=n_pixel, DataX=[0, 1], + NSpec=self_scattering_correction.getNumberHistograms()) + self_scattering_correction = Divide(LHSWorkspace=self_scattering_correction, RHSWorkspace=correction_ws) + ConvertToDistribution(Workspace=self_scattering_correction) + self_scattering_correction = ConvertUnits(InputWorkspace=self_scattering_correction, + Target="MomentumTransfer", EMode='Elastic') + DeleteWorkspace('cal_workspace_group') + DeleteWorkspace(correction_ws) + DeleteWorkspace(fit_spectra) + DeleteWorkspace(monitor) + DeleteWorkspace(raw_ws) + self.setProperty('OutputWorkspace', self_scattering_correction) + + +# Register algorithm with Mantid +AlgorithmFactory.subscribe(TotScatCalculateSelfScattering) diff --git a/Framework/PythonInterface/test/python/mantid/plots/plots__init__Test.py b/Framework/PythonInterface/test/python/mantid/plots/plots__init__Test.py index 5a8b3ceac4d6dc1f0d346c6db1a1b4fb33e7896e..b4b86f23126e5985f6e0545b2688a14fd6b1cf72 100644 --- a/Framework/PythonInterface/test/python/mantid/plots/plots__init__Test.py +++ b/Framework/PythonInterface/test/python/mantid/plots/plots__init__Test.py @@ -174,6 +174,26 @@ class Plots__init__Test(unittest.TestCase): # try deleting self.ax.remove_workspace_artists(plot_data) + def test_replace_workspace_data_plot_with_fewer_spectra(self): + plot_data = CreateWorkspace(DataX=[10, 20, 30, 10, 20, 30, 10, 20, 30], + 
DataY=[3, 4, 5, 3, 4, 5], + DataE=[1, 2, 3, 4, 1, 1], + NSpec=3) + line_ws2d_histo_spec_1 = self.ax.plot(plot_data, specNum=1, color='r')[0] + line_ws2d_histo_spec_2 = self.ax.plot(plot_data, specNum=2, color='r')[0] + line_ws2d_histo_spec_3 = self.ax.plot(plot_data, specNum=3, color='r')[0] + + plot_data = CreateWorkspace(DataX=[20, 30, 40, 20, 30, 40], + DataY=[3, 4, 3, 4], + DataE=[1, 2, 1, 2], + NSpec=2) + self.ax.replace_workspace_artists(plot_data) + self.assertAlmostEqual(25, line_ws2d_histo_spec_2.get_xdata()[0]) + self.assertAlmostEqual(35, line_ws2d_histo_spec_2.get_xdata()[-1]) + self.assertEqual('r', line_ws2d_histo_spec_2.get_color()) + # try deleting + self.ax.remove_workspace_artists(plot_data) + def test_replace_workspace_data_errorbar(self): eb_data = CreateWorkspace(DataX=[10, 20, 30, 10, 20, 30, 10, 20, 30], DataY=[3, 4, 5, 3, 4, 5], @@ -312,7 +332,7 @@ class Plots__init__Test(unittest.TestCase): self.assertTrue(ws_artists[1].is_normalized) self.assertTrue(ws_artists[2].is_normalized) - def test_artists_normalization_state_labeled_correctly_for_2d_plots_of_non_dist_workspace(self): + def test_artists_normalization_labeled_correctly_for_2d_plots_of_non_dist_workspace_and_dist_argument_false(self): plot_funcs = ['imshow', 'pcolor', 'pcolormesh', 'pcolorfast', 'tripcolor', 'contour', 'contourf', 'tricontour', 'tricontourf'] non_dist_2d_ws = CreateWorkspace(DataX=[10, 20, 10, 20], @@ -327,10 +347,32 @@ class Plots__init__Test(unittest.TestCase): self.assertTrue(self.ax.tracked_workspaces[non_dist_2d_ws.name()][0].is_normalized) del self.ax.tracked_workspaces[non_dist_2d_ws.name()] + def test_artists_normalization_labeled_correctly_for_2d_plots_of_non_dist_workspace_and_dist_argument_true(self): + plot_funcs = ['imshow', 'pcolor', 'pcolormesh', 'pcolorfast', 'tripcolor', + 'contour', 'contourf', 'tricontour', 'tricontourf'] + non_dist_2d_ws = CreateWorkspace(DataX=[10, 20, 10, 20], + DataY=[2, 3, 2, 3], + DataE=[1, 2, 1, 2], + NSpec=2, + 
Distribution=False, + OutputWorkspace='non_dist_workpace') + for plot_func in plot_funcs: + func = getattr(self.ax, plot_func) func(non_dist_2d_ws, distribution=True) self.assertFalse(self.ax.tracked_workspaces[non_dist_2d_ws.name()][0].is_normalized) del self.ax.tracked_workspaces[non_dist_2d_ws.name()] + def test_artists_normalization_labeled_correctly_for_2d_plots_of_non_dist_workspace_and_global_setting_on(self): + plot_funcs = ['imshow', 'pcolor', 'pcolormesh', 'pcolorfast', 'tripcolor', + 'contour', 'contourf', 'tricontour', 'tricontourf'] + non_dist_2d_ws = CreateWorkspace(DataX=[10, 20, 10, 20], + DataY=[2, 3, 2, 3], + DataE=[1, 2, 1, 2], + NSpec=2, + Distribution=False, + OutputWorkspace='non_dist_workpace') + for plot_func in plot_funcs: + func = getattr(self.ax, plot_func) func(non_dist_2d_ws) auto_dist = (config['graph1d.autodistribution'] == 'On') self.assertEqual(auto_dist, self.ax.tracked_workspaces[non_dist_2d_ws.name()][0].is_normalized) diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/AbinsBasicTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/AbinsBasicTest.py index 3aae7bdf0211bd9509578e55c348a3f6032c1713..7ee4639f12047ff38ddf12f6571197cdcac9e7e2 100644 --- a/Framework/PythonInterface/test/python/plugins/algorithms/AbinsBasicTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/AbinsBasicTest.py @@ -35,7 +35,7 @@ class AbinsBasicTest(unittest.TestCase): def tearDown(self): AbinsTestHelpers.remove_output_files(list_of_names=["explicit", "default", "total", "squaricn_sum_Abins", "squaricn_scale", "benzene_exp", "benzene_Abins", - "experimental"]) + "experimental", "numbered"]) mtd.clear() def test_wrong_input(self): diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/ExportExperimentLogTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/ExportExperimentLogTest.py index 926026df7bda6f68ec81d21ad74568739ac3e6ea..8b30a75b4a1050e1e4ee1149354b20b086400085 100644 --- 
a/Framework/PythonInterface/test/python/plugins/algorithms/ExportExperimentLogTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/ExportExperimentLogTest.py @@ -75,6 +75,7 @@ class ExportExperimentLogTest(unittest.TestCase): self.assertAlmostEqual(avgpcharge, v4) # Remove generated files + os.remove(outfilename) AnalysisDataService.remove("TestMatrixWS") @@ -696,7 +697,8 @@ class ExportExperimentLogTest(unittest.TestCase): self.assertAlmostEqual(avgpcharge, v5) # Remove generated files - # os.remove(outfilename) + os.remove(outfilename) + AnalysisDataService.remove("TestMatrixWS") return diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/LoadNMoldyn3AsciiTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/LoadNMoldyn3AsciiTest.py index e5756067c010616e269651bd473a167cb9ac7f1f..994e5f7d8c6a22ef006d9c3932ff4b154ba04744 100644 --- a/Framework/PythonInterface/test/python/plugins/algorithms/LoadNMoldyn3AsciiTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/LoadNMoldyn3AsciiTest.py @@ -21,6 +21,7 @@ class LoadNMoldyn3AsciiTest(unittest.TestCase): """ Load an F(Q, t) function from an nMOLDYN 3 .cdl file """ + moldyn_group = LoadNMoldyn3Ascii(Filename=self._cdl_filename, Functions=['Fqt-total'], OutputWorkspace='__LoadNMoldyn3Ascii_test') @@ -82,6 +83,12 @@ class LoadNMoldyn3AsciiTest(unittest.TestCase): self.assertTrue(isinstance(moldyn_ws, MatrixWorkspace)) self.assertTrue(moldyn_ws.getNumberHistograms(), 12) + workdir = config['defaultsave.directory'] + filename = 'MolDyn_angles.txt' + path = os.path.join(workdir, filename) + if os.path.exists(path): + os.remove(path) + def test_function_validation_cdl(self): """ diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/MergeCalFilesTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/MergeCalFilesTest.py index 717a5d4dd742d36fb77bc5b60f2c035c425724e8..1a568a8ce5c049bc126430d137fd0627e29ff00a 100644 --- 
a/Framework/PythonInterface/test/python/plugins/algorithms/MergeCalFilesTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/MergeCalFilesTest.py @@ -9,6 +9,7 @@ from __future__ import (absolute_import, division, print_function) import unittest import os import tempfile +import shutil from mantid.kernel import * from mantid.api import * from mantid.simpleapi import * @@ -53,10 +54,11 @@ A helper resource managing wrapper over a new calfile object. Creates cal file a class DisposableCalFileObject: _fullpath = None + _dirpath = None def __init__(self, name): - dirpath = tempfile.mkdtemp() - self._fullpath = os.path.join(dirpath, name) + self._dirpath = tempfile.mkdtemp() + self._fullpath = os.path.join(self._dirpath, name) file = open(self._fullpath, 'w') file.close() @@ -67,6 +69,8 @@ class DisposableCalFileObject: def __del__(self): os.remove(self._fullpath) + if os.path.exists(self._dirpath): + shutil.rmtree(self._dirpath) def getPath(self): return self._fullpath @@ -76,19 +80,20 @@ A helper resource managing wrapper over an existing cal file for reading. 
Dispos class ReadableCalFileObject: _fullpath = None + _dirpath = None - def __init__(self, fullpath): + def __init__(self, dirpath, filename): + fullpath = os.path.join(dirpath, filename) if not os.path.exists(fullpath): raise RuntimeError("No readable cal file at location: " + fullpath) else: self._fullpath = fullpath + self._dirpath = dirpath def __del__(self): pass os.remove(self._fullpath) - - def getPath(self): - return _fullpath + shutil.rmtree(self._dirpath) def readline(self): result = None @@ -127,8 +132,9 @@ class MergeCalFilesTest(unittest.TestCase): OutputFile=outputfilestring, MergeOffsets=mergeOffsets, MergeSelections=mergeSelect, MergeGroups=mergeGroups) # Read the results file and return the first line as a CalFileEntry - outputfile = ReadableCalFileObject(outputfilestring) + outputfile = ReadableCalFileObject(dirpath,"product.cal") firstLineOutput = outputfile.readline() + return firstLineOutput def test_replace_nothing(self): diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/SaveNexusPDTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/SaveNexusPDTest.py index 0bab9a2a6acfaabd11b9159ea6096274de7b1938..3a2fccde45d167820914bc888be717dfd900083e 100644 --- a/Framework/PythonInterface/test/python/plugins/algorithms/SaveNexusPDTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/SaveNexusPDTest.py @@ -26,7 +26,7 @@ class SaveNexusPDTest(unittest.TestCase): dataDir = mantid.config.getString('defaultsave.directory') return os.path.join(dataDir, wkspname+'.h5') - def cleanup(self, wkspname, filename): + def cleanup(self, filename, wkspname): if os.path.exists(filename): os.remove(filename) if mantid.mtd.doesExist(wkspname): diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/ApplyPaalmanPingsCorrectionTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/ApplyPaalmanPingsCorrectionTest.py index 
35f1fbd6c60d70e60c109153225a742caee321b6..ccd23502c03cb7919597ff9eebe953c4805c9fc5 100644 --- a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/ApplyPaalmanPingsCorrectionTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/ApplyPaalmanPingsCorrectionTest.py @@ -53,7 +53,6 @@ class ApplyPaalmanPingsCorrectionTest(unittest.TestCase): self._corrections_ws = corrections - def tearDown(self): """ Remove workspaces from ADS. @@ -63,7 +62,6 @@ class ApplyPaalmanPingsCorrectionTest(unittest.TestCase): DeleteWorkspace(mtd['can_ws']) DeleteWorkspace(self._corrections_ws) - def _verify_workspace(self, ws, correction_type): """ Do validation on a correction workspace. @@ -85,11 +83,10 @@ class ApplyPaalmanPingsCorrectionTest(unittest.TestCase): log_correction_type = logs['corrections_type'].value self.assertEqual(log_correction_type, correction_type) - def _create_group_of_factors(self, corrections, factors): def is_factor(workspace, factor): if factor == "ass": - return factor in workspace.name() and not 'assc' in workspace.name() + return factor in workspace.name() and 'assc' not in workspace.name() else: return factor in workspace.name() @@ -103,14 +100,12 @@ class ApplyPaalmanPingsCorrectionTest(unittest.TestCase): correction_names = [correction.name() for correction in cloned_corr] return GroupWorkspaces(InputWorkspaces=correction_names, OutputWorkspace="factor_group") - def test_can_subtraction(self): corr = ApplyPaalmanPingsCorrection(SampleWorkspace=self._sample_ws, CanWorkspace=self._can_ws) self._verify_workspace(corr, 'can_subtraction') - def test_can_subtraction_with_can_scale(self): corr = ApplyPaalmanPingsCorrection(SampleWorkspace=self._sample_ws, CanWorkspace=self._can_ws, @@ -125,14 +120,12 @@ class ApplyPaalmanPingsCorrectionTest(unittest.TestCase): self._verify_workspace(corr, 'can_subtraction') - def test_sample_corrections_only(self): corr = 
ApplyPaalmanPingsCorrection(SampleWorkspace=self._sample_ws, CorrectionsWorkspace=self._corrections_ws) self._verify_workspace(corr, 'sample_corrections_only') - def test_sample_and_can_corrections(self): corr = ApplyPaalmanPingsCorrection(SampleWorkspace=self._sample_ws, CorrectionsWorkspace=self._corrections_ws, @@ -140,7 +133,6 @@ class ApplyPaalmanPingsCorrectionTest(unittest.TestCase): self._verify_workspace(corr, 'sample_and_can_corrections') - def test_sample_and_can_corrections_with_can_scale(self): corr = ApplyPaalmanPingsCorrection(SampleWorkspace=self._sample_ws, CorrectionsWorkspace=self._corrections_ws, @@ -225,5 +217,29 @@ class ApplyPaalmanPingsCorrectionTest(unittest.TestCase): DeleteWorkspace(sample_1) DeleteWorkspace(container_1) + def test_fixed_window_scan(self): + Load(Filename='ILL/IN16B/mc-abs-corr-q.nxs', OutputWorkspace='fws_corrections_ass') + GroupWorkspaces(InputWorkspaces=['fws_corrections_ass'], OutputWorkspace='fws_corrections') + in_ws = Load(Filename='ILL/IN16B/fapi-fws-q.nxs', OutputWorkspace='fapi') + out_ws = ApplyPaalmanPingsCorrection(SampleWorkspace=in_ws[0], CorrectionsWorkspace='fws_corrections', + OutputWorkspace='wsfapicorr') + self.assertTrue(out_ws) + self.assertTrue(isinstance(out_ws, MatrixWorkspace)) + self.assertEquals(out_ws.blocksize(), in_ws[0].blocksize()) + self.assertEquals(out_ws.getNumberHistograms(), in_ws[0].getNumberHistograms()) + self.assertEquals(out_ws.getAxis(0).getUnit().unitID(), in_ws[0].getAxis(0).getUnit().unitID()) + + def test_group_raise(self): + Load(Filename='ILL/IN16B/mc-abs-corr-q.nxs', OutputWorkspace='fws_corrections_ass') + GroupWorkspaces(InputWorkspaces=['fws_corrections_ass'], OutputWorkspace='fws_corrections') + in_ws = Load(Filename='ILL/IN16B/fapi-fws-q.nxs', OutputWorkspace='fapi') + kwargs = { + 'SampleWorkspace': in_ws, + 'CorrectionsWorkspace': 'fws_corrections', + 'OutputWorkspace': 'wsfapicorr' + } + self.assertRaises(RuntimeError, ApplyPaalmanPingsCorrection, **kwargs) + + 
if __name__=="__main__": unittest.main() diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/MatchAndMergeWorkspacesTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/MatchAndMergeWorkspacesTest.py new file mode 100644 index 0000000000000000000000000000000000000000..74ea740da32c09332e0b66a9cf968e6ffdbc9521 --- /dev/null +++ b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/MatchAndMergeWorkspacesTest.py @@ -0,0 +1,94 @@ +# Mantid Repository : https://github.com/mantidproject/mantid +# +# Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +# NScD Oak Ridge National Laboratory, European Spallation Source +# & Institut Laue - Langevin +# SPDX - License - Identifier: GPL - 3.0 + +from __future__ import (absolute_import, division, print_function) + +import unittest +import numpy as np +from mantid.api import MatrixWorkspace +from mantid.simpleapi import (AnalysisDataService, ConjoinWorkspaces, CreateWorkspace, MatchAndMergeWorkspaces, + GroupWorkspaces, DeleteWorkspace) + + +class MatchAndMergeWorkspacesTest(unittest.TestCase): + + def setUp(self): + ws_list = [] + for i in range(5): + ws_name = 'ws_' + str(i+1) + data_x = np.arange(i, (i+1)*10+0.1, 0.1) + data_y = np.arange(i, (i+1)*10, 0.1) + data_e = np.arange(i, (i+1)*10, 0.1) + CreateWorkspace(OutputWorkspace=ws_name, DataX=data_x, DataY=data_y, DataE=data_e) + ws_list.append(ws_name) + GroupWorkspaces(InputWorkspaces=ws_list, OutputWorkspace='ws_group') + + def test_MatchAndMergeWorkspaces_executes(self): + x_min = np.array([0, 5, 10, 15, 20]) + x_max = np.array([10, 20, 30, 40, 50]) + ws_merged = MatchAndMergeWorkspaces(InputWorkspaces='ws_group', XMin=x_min, XMax=x_max) + self.assertIsInstance(ws_merged, MatrixWorkspace) + self.assertEqual(ws_merged.getNumberHistograms(), 1) + self.assertAlmostEqual(min(ws_merged.dataX(0)), 0, places=0) + self.assertAlmostEqual(max(ws_merged.dataX(0)), 50, places=0) + + 
def test_MatchAndMergeWorkspaces_produces_correct_range(self): + x_min = np.array([2, 5, 10, 15, 20]) + x_max = np.array([10, 20, 30, 40, 45]) + ws_merged = MatchAndMergeWorkspaces(InputWorkspaces='ws_group', XMin=x_min, XMax=x_max) + self.assertIsInstance(ws_merged, MatrixWorkspace) + self.assertEqual(ws_merged.getNumberHistograms(), 1) + self.assertAlmostEqual(min(ws_merged.dataX(0)), 2, places=0) + self.assertAlmostEqual(max(ws_merged.dataX(0)), 45, places=0) + + def test_MatchAndMergeWorkspaces_accepts_a_list_of_workspaces(self): + x_min = np.array([2, 5, 10]) + x_max = np.array([10, 20, 30]) + ws_group = AnalysisDataService.retrieve('ws_group') + ws_list = [ws_group[0], ws_group[1], ws_group[2]] + ws_merged = MatchAndMergeWorkspaces(InputWorkspaces=ws_list, XMin=x_min, XMax=x_max) + self.assertIsInstance(ws_merged, MatrixWorkspace) + self.assertEqual(ws_merged.getNumberHistograms(), 1) + self.assertAlmostEqual(min(ws_merged.dataX(0)), 2, places=0) + self.assertAlmostEqual(max(ws_merged.dataX(0)), 30, places=0) + + def test_MatchAndMergeWorkspaces_accepts_a_mixture_of_ws_size(self): + x_min = np.array([2, 5, 10, 15, 20]) + x_max = np.array([10, 20, 30, 40, 45]) + ws_group = AnalysisDataService.retrieve('ws_group') + ConjoinWorkspaces(InputWorkspace1=ws_group[3], + InputWorkspace2=ws_group[4], + CheckOverlapping=False) + ws_list = [ws_group[0], ws_group[1], ws_group[2], ws_group[3]] + ws_merged = MatchAndMergeWorkspaces(InputWorkspaces=ws_list, XMin=x_min, XMax=x_max) + self.assertIsInstance(ws_merged, MatrixWorkspace) + self.assertEqual(ws_merged.getNumberHistograms(), 1) + self.assertAlmostEqual(min(ws_merged.dataX(0)), 2, places=0) + self.assertAlmostEqual(max(ws_merged.dataX(0)), 45, places=0) + + def test_MatchAndMergeWorkspaces_fails_with_wrong_number_min_limits(self): + x_min = np.array([0]) + x_max = np.array([10, 20, 30, 40, 50]) + self.assertRaises(RuntimeError, MatchAndMergeWorkspaces, InputWorkspaces='ws_group', XMin=x_min, XMax=x_max) + + def 
test_MatchAndMergeWorkspaces_fails_with_wrong_number_max_limits(self): + x_min = np.array([0, 5, 10, 15, 20]) + x_max = np.array([10]) + self.assertRaises(RuntimeError, MatchAndMergeWorkspaces, InputWorkspaces='ws_group', XMin=x_min, XMax=x_max) + + def test_MatchAndMergeWorkspaces_fails_with_wrong_ws_input(self): + x_min = np.array([0, 5, 10, 15, 20]) + x_max = np.array([10, 20, 30, 40, 50]) + self.assertRaises(ValueError, MatchAndMergeWorkspaces, InputWorkspaces='fake_group', XMin=x_min, XMax=x_max) + + def test_MatchAndMergeWorkspaces_fails_with_min_larger_than_max(self): + x_min = np.array([10, 20, 30, 40, 50]) + x_max = np.array([0, 5, 10, 15, 20]) + self.assertRaises(ValueError, MatchAndMergeWorkspaces, InputWorkspaces='ws_group', XMin=x_min, XMax=x_max) + + +if __name__ == "__main__": + unittest.main() diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/SANSILLIntegrationTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/SANSILLIntegrationTest.py index 6a63f164ecc8a37701a5591598e936ee25d96327..241640eb6115b0c11ec8b3d502e1c6ad814bdcd3 100644 --- a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/SANSILLIntegrationTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/SANSILLIntegrationTest.py @@ -7,6 +7,7 @@ from __future__ import (absolute_import, division, print_function) import unittest + from mantid.api import MatrixWorkspace, WorkspaceGroup from mantid.simpleapi import SANSILLIntegration, SANSILLReduction, config, mtd @@ -17,13 +18,15 @@ class SANSILLIntegrationTest(unittest.TestCase): def setUp(self): self._facility = config['default.facility'] + self._data_search_dirs = config.getDataSearchDirs() config.appendDataSearchSubDir('ILL/D11/') config.appendDataSearchSubDir('ILL/D33/') - config['default.facility'] = 'ILL' + config.setFacility("ILL") SANSILLReduction(Run='010569', ProcessAs='Sample', OutputWorkspace='sample') 
def tearDown(self): - config['default.facility'] = self._facility + config.setFacility(self._facility) + config.setDataSearchDirs(self._data_search_dirs) mtd.clear() def test_monochromatic(self): diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/SaveVulcanGSSTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/SaveVulcanGSSTest.py index dd3bc732e90897e711c56e3d79b2d57d2ae08a77..8f36146d5574c1d65ca610599144ea960a97436b 100644 --- a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/SaveVulcanGSSTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/SaveVulcanGSSTest.py @@ -34,11 +34,14 @@ class SaveVulcanGSSTest(unittest.TestCase): bin_table = self._create_simple_binning_table(bin_ws_name) # Execute + import tempfile + tempDir = tempfile.gettempdir() + filename=os.path.join(tempDir, "tempout.gda") alg_test = run_algorithm("SaveVulcanGSS", InputWorkspace=data_ws_name, BinningTable=bin_ws_name, OutputWorkspace=data_ws_name + "_rebinned", - GSSFilename="/tmp/tempout.gda", + GSSFilename=filename, IPTS=12345, GSSParmFileName='test.prm') @@ -56,6 +59,8 @@ class SaveVulcanGSSTest(unittest.TestCase): AnalysisDataService.remove(bin_ws_name) AnalysisDataService.remove(data_ws_name+"_rebinned") + os.remove(filename) + return def test_save_gss_vdrive(self): diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/TotScatCalculateSelfScatteringTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/TotScatCalculateSelfScatteringTest.py new file mode 100644 index 0000000000000000000000000000000000000000..1b909da5fa0a27f4857bed6445efd6bf4d2f8514 --- /dev/null +++ b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/TotScatCalculateSelfScatteringTest.py @@ -0,0 +1,46 @@ +# Mantid Repository : https://github.com/mantidproject/mantid +# +# Copyright © 2019 ISIS 
Rutherford Appleton Laboratory UKRI, +# NScD Oak Ridge National Laboratory, European Spallation Source +# & Institut Laue - Langevin +# SPDX - License - Identifier: GPL - 3.0 + +from __future__ import (absolute_import, division, print_function) + +import unittest +from mantid.simpleapi import TotScatCalculateSelfScattering, Load +from isis_powder import SampleDetails + + +class TotScatCalculateSelfScatteringTest(unittest.TestCase): + + def setUp(self): + sample_details = SampleDetails(height=4.0, radius=0.2985, center=[0, 0, 0], shape='cylinder') + sample_details.set_material(chemical_formula='Si') + self.geometry = {'Shape': 'Cylinder', + 'Height': sample_details.height(), + 'Radius': sample_details.radius(), + 'Center': sample_details.center()} + + material = sample_details.material_object + material_json = {'ChemicalFormula': material.chemical_formula} + if material.number_density: + material_json["SampleNumberDensity"] = material.number_density + if material.absorption_cross_section: + material_json["AttenuationXSection"] = material.absorption_cross_section + if material.scattering_cross_section: + material_json["ScatteringXSection"] = material.scattering_cross_section + self.material = material_json + + self.cal_file_path = "polaris_grouping_file.cal" + + def test_TotScatCalculateSelfScattering_executes(self): + raw_ws = Load(Filename='POLARIS98533.nxs') + correction_ws = TotScatCalculateSelfScattering(InputWorkspace=raw_ws, + CalFileName=self.cal_file_path, + SampleGeometry=self.geometry, + SampleMaterial=self.material) + self.assertEqual(correction_ws.getNumberHistograms(), 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/sans/SANSConvertToWavelengthAndRebinTest.py b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/sans/SANSConvertToWavelengthAndRebinTest.py index 
6e6e72d9e48b53d58016e02b503e711f4228d218..a7b7a0e33430c9ec4f1359e1c37c06f7ecf68839 100644 --- a/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/sans/SANSConvertToWavelengthAndRebinTest.py +++ b/Framework/PythonInterface/test/python/plugins/algorithms/WorkflowAlgorithms/sans/SANSConvertToWavelengthAndRebinTest.py @@ -43,7 +43,7 @@ class SANSSConvertToWavelengthImplementationTest(unittest.TestCase): "WavelengthLow": 1.0, "WavelengthHigh": 3.0, "WavelengthStep": 1.5, - "WavelengthStepType": RangeStepType.to_string(RangeStepType.Lin)} + "WavelengthStepType": RangeStepType.LIN.value} convert_alg = create_unmanaged_algorithm("SANSConvertToWavelengthAndRebin", **convert_options) had_run_time_error = False try: @@ -60,7 +60,7 @@ class SANSSConvertToWavelengthImplementationTest(unittest.TestCase): "WavelengthLow": -1.0, "WavelengthHigh": 3.0, "WavelengthStep": 1.5, - "WavelengthStepType": RangeStepType.to_string(RangeStepType.Log)} + "WavelengthStepType": RangeStepType.LOG.value} convert_alg = create_unmanaged_algorithm("SANSConvertToWavelengthAndRebin", **convert_options) had_run_time_error = False try: @@ -77,7 +77,7 @@ class SANSSConvertToWavelengthImplementationTest(unittest.TestCase): "WavelengthLow": 4.0, "WavelengthHigh": 3.0, "WavelengthStep": 1.5, - "WavelengthStepType": RangeStepType.to_string(RangeStepType.Log)} + "WavelengthStepType": RangeStepType.LOG.value} convert_alg = create_unmanaged_algorithm("SANSConvertToWavelengthAndRebin", **convert_options) had_run_time_error = False try: @@ -94,7 +94,7 @@ class SANSSConvertToWavelengthImplementationTest(unittest.TestCase): "WavelengthLow": 1.0, "WavelengthHigh": 10.0, "WavelengthStep": 1.0, - "WavelengthStepType": RangeStepType.to_string(RangeStepType.Lin)} + "WavelengthStepType": RangeStepType.LIN.value} convert_alg = create_unmanaged_algorithm("SANSConvertToWavelengthAndRebin", **convert_options) convert_alg.execute() self.assertTrue(convert_alg.isExecuted()) @@ -117,7 +117,7 @@ class 
SANSSConvertToWavelengthImplementationTest(unittest.TestCase): "RebinMode": "Rebin", "WavelengthLow": 1.0, "WavelengthStep": 1.0, - "WavelengthStepType": RangeStepType.to_string(RangeStepType.Lin)} + "WavelengthStepType": RangeStepType.LIN.value} convert_alg = create_unmanaged_algorithm("SANSConvertToWavelengthAndRebin", **convert_options) convert_alg.execute() self.assertTrue(convert_alg.isExecuted()) diff --git a/MantidPlot/CMakeLists.txt b/MantidPlot/CMakeLists.txt index 583c50db2fa1c6f5d8e58cd405ad6db96cb2777e..394f3d0978aa7a4a17082783172f362e8021efc7 100644 --- a/MantidPlot/CMakeLists.txt +++ b/MantidPlot/CMakeLists.txt @@ -813,6 +813,10 @@ target_include_directories(MantidPlot ../qt/widgets/factory/inc ../Framework/PythonInterface/inc) +# GCC 8 onwards needs to disable functional casting at the Python interface +target_compile_options( MantidPlot PRIVATE + $<$<AND:$<CXX_COMPILER_ID:GNU>,$<VERSION_GREATER_EQUAL:$<CXX_COMPILER_VERSION>,8.0>>:-Wno-cast-function-type> ) + # Library dependencies target_link_libraries(MantidPlot LINK_PRIVATE diff --git a/Testing/Data/UnitTest/HB3A_data.nxs.md5 b/Testing/Data/UnitTest/HB3A_data.nxs.md5 new file mode 100644 index 0000000000000000000000000000000000000000..d3279946092339cdf048f3068f8ee771fe869166 --- /dev/null +++ b/Testing/Data/UnitTest/HB3A_data.nxs.md5 @@ -0,0 +1 @@ +f524fad52590c9f68172293d05d01d52 diff --git a/Testing/Data/UnitTest/ILL/IN16B/fapi-fws-q.nxs.md5 b/Testing/Data/UnitTest/ILL/IN16B/fapi-fws-q.nxs.md5 new file mode 100644 index 0000000000000000000000000000000000000000..ab4624df37c9b5fc371614535d0a79d0577ae5bc --- /dev/null +++ b/Testing/Data/UnitTest/ILL/IN16B/fapi-fws-q.nxs.md5 @@ -0,0 +1 @@ +4fcc17adc2dc7a327f632e3c7e18e290 diff --git a/Testing/Data/UnitTest/ILL/IN16B/mc-abs-corr-q.nxs.md5 b/Testing/Data/UnitTest/ILL/IN16B/mc-abs-corr-q.nxs.md5 new file mode 100644 index 0000000000000000000000000000000000000000..6ac51276a385496057e4c1b5e9761a5183cbb446 --- /dev/null +++ 
b/Testing/Data/UnitTest/ILL/IN16B/mc-abs-corr-q.nxs.md5 @@ -0,0 +1 @@ +b43c45ad1e28124b30c6fd9cc4246ddf diff --git a/Testing/Data/UnitTest/polaris_grouping_file.cal.md5 b/Testing/Data/UnitTest/polaris_grouping_file.cal.md5 new file mode 100644 index 0000000000000000000000000000000000000000..a53859b91f89949d8cee8fc9fa54cef1293c6305 --- /dev/null +++ b/Testing/Data/UnitTest/polaris_grouping_file.cal.md5 @@ -0,0 +1 @@ +075b0a53eead49322b8e6e73f78df112 diff --git a/Testing/SystemTests/tests/analysis/ConvertHFIRSCDtoMDETest.py b/Testing/SystemTests/tests/analysis/ConvertHFIRSCDtoMDETest.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3385c1a090652e637186336d0031a2c670f82b --- /dev/null +++ b/Testing/SystemTests/tests/analysis/ConvertHFIRSCDtoMDETest.py @@ -0,0 +1,89 @@ +# Mantid Repository : https://github.com/mantidproject/mantid +# +# Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI, +# NScD Oak Ridge National Laboratory, European Spallation Source +# & Institut Laue - Langevin +# SPDX - License - Identifier: GPL - 3.0 + +import systemtesting +import numpy as np +from mantid.simpleapi import * + + +class ConvertHFIRSCDtoMDETest(systemtesting.MantidSystemTest): + def requiredMemoryMB(self): + return 4000 + + def runTest(self): + LoadMD('HB2C_WANDSCD_data.nxs', OutputWorkspace='ConvertHFIRSCDtoMDETest_data') + + ConvertHFIRSCDtoMDETest_Q=ConvertHFIRSCDtoMDE(InputWorkspace='ConvertHFIRSCDtoMDETest_data', Wavelength=1.488) + + self.assertEqual(ConvertHFIRSCDtoMDETest_Q.getNEvents(), 18022177) + + ConvertHFIRSCDtoMDETest_peaks=FindPeaksMD(InputWorkspace=ConvertHFIRSCDtoMDETest_Q, PeakDistanceThreshold=2.2, + CalculateGoniometerForCW=True, Wavelength=1.488) + + self.assertEqual(ConvertHFIRSCDtoMDETest_peaks.getNumberPeaks(), 14) + + peak = ConvertHFIRSCDtoMDETest_peaks.getPeak(0) + np.testing.assert_allclose(peak.getQSampleFrame(), [0.09446778,0.001306865,2.180508], rtol=1e-05) + self.assertDelta(peak.getWavelength(), 1.488, 1e-5) + + peak 
= ConvertHFIRSCDtoMDETest_peaks.getPeak(13) + np.testing.assert_allclose(peak.getQSampleFrame(), [6.754011,0.001306865,1.918834], rtol=1e-05) + self.assertDelta(peak.getWavelength(), 1.488, 1e-5) + + +class ConvertHFIRSCDtoMDE_HB3A_Test(systemtesting.MantidSystemTest): + def requiredMemoryMB(self): + return 1000 + + def runTest(self): + LoadMD('HB3A_data.nxs', OutputWorkspace='ConvertHFIRSCDtoMDE_HB3ATest_data') + + SetGoniometer('ConvertHFIRSCDtoMDE_HB3ATest_data', + Axis0='omega,0,1,0,-1', + Axis1='chi,0,0,1,-1', + Axis2='phi,0,1,0,-1') + + ConvertHFIRSCDtoMDETest_Q=ConvertHFIRSCDtoMDE(InputWorkspace='ConvertHFIRSCDtoMDE_HB3ATest_data', Wavelength=1.008) + + self.assertEqual(ConvertHFIRSCDtoMDETest_Q.getNEvents(), 9038) + + ConvertHFIRSCDtoMDETest_peaks=FindPeaksMD(InputWorkspace='ConvertHFIRSCDtoMDETest_Q', + PeakDistanceThreshold=0.25, + DensityThresholdFactor=20000, + CalculateGoniometerForCW=True, + Wavelength=1.008, + FlipX=True, + InnerGoniometer=False) + + IndexPeaks(ConvertHFIRSCDtoMDETest_peaks) + + self.assertEqual(ConvertHFIRSCDtoMDETest_peaks.getNumberPeaks(), 1) + + peak = ConvertHFIRSCDtoMDETest_peaks.getPeak(0) + self.assertDelta(peak.getWavelength(), 1.008, 1e-7) + np.testing.assert_allclose(peak.getQSampleFrame(), [-0.417683, 1.792265, 2.238072], rtol=1e-5) + np.testing.assert_array_equal(peak.getHKL(), [0, 0, 6]) + + def validate(self): + results = 'ConvertHFIRSCDtoMDETest_Q' + reference = 'ConvertHFIRSCDtoMDE_HB3A_Test.nxs' + + Load(Filename=reference,OutputWorkspace=reference) + + checker = AlgorithmManager.create("CompareMDWorkspaces") + checker.setLogging(True) + checker.setPropertyValue("Workspace1",results) + checker.setPropertyValue("Workspace2",reference) + checker.setPropertyValue("Tolerance", "1e-5") + + checker.execute() + if checker.getPropertyValue("Equals") != "1": + print(" Workspaces do not match, result: ",checker.getPropertyValue("Result")) + print(self.__class__.__name__) + 
SaveMD(InputWorkspace=results,Filename=self.__class__.__name__+'-mismatch.nxs') + return False + + return True diff --git a/Testing/SystemTests/tests/analysis/ILLIndirectEnergyTransferBATS.py b/Testing/SystemTests/tests/analysis/ILLIndirectEnergyTransferBATS.py index e72360b6450bb82c730292702bf21075a15dc1dc..0b024f4d95aa3d47adf1cea2e1773a67498faf30 100644 --- a/Testing/SystemTests/tests/analysis/ILLIndirectEnergyTransferBATS.py +++ b/Testing/SystemTests/tests/analysis/ILLIndirectEnergyTransferBATS.py @@ -32,7 +32,7 @@ class ILLIndirectEnergyTransferBATSTest(systemtesting.MantidSystemTest): # parameters file evolves quite often, so this is not checked self.disableChecking = ['Instrument'] - def tearDown(self): + def cleanup(self): config['default.facility'] = self.facility config['default.instrument'] = self.instrument config['datasearch.directories'] = self.datadirs diff --git a/Testing/SystemTests/tests/analysis/ILLPowderD2BEfficiencyTest.py b/Testing/SystemTests/tests/analysis/ILLPowderD2BEfficiencyTest.py index 33b45cde221a0f7c362e44f0b9232b19eaf53da1..de9c0e5baebbc989452fc42a8694d50835077c96 100644 --- a/Testing/SystemTests/tests/analysis/ILLPowderD2BEfficiencyTest.py +++ b/Testing/SystemTests/tests/analysis/ILLPowderD2BEfficiencyTest.py @@ -26,7 +26,7 @@ class ILLPowderD2BEfficiencyTest(systemtesting.MantidSystemTest): def requiredFiles(self): return ['532008.nxs', '532009.nxs'] - def tearDown(self): + def cleanup(self): mtd.clear() def testAutoMasking(self): diff --git a/Testing/SystemTests/tests/analysis/ILLPowderDetectorScanTest.py b/Testing/SystemTests/tests/analysis/ILLPowderDetectorScanTest.py index d09d1080853983fe44f05fb9e637502e6b75031d..d80a358640ce1fadd81c94cd210dba3c7ecee50e 100644 --- a/Testing/SystemTests/tests/analysis/ILLPowderDetectorScanTest.py +++ b/Testing/SystemTests/tests/analysis/ILLPowderDetectorScanTest.py @@ -32,7 +32,7 @@ class _DiffReductionTest(systemtesting.MantidSystemTest): config.appendDataSearchSubDir('ILL/D2B/') 
config.appendDataSearchSubDir('ILL/D20/') - def tearDown(self): + def cleanup(self): config['default.facility'] = self._facility config['default.instrument'] = self._instrument config['datasearch.directories'] = self._directories diff --git a/Testing/SystemTests/tests/analysis/ILLPowderEfficiencyClosureTest.py b/Testing/SystemTests/tests/analysis/ILLPowderEfficiencyClosureTest.py index 01ecb9df79a0c92ea39dfa33750bd2e8cd46e693..b7fb45a975d8b8093fa9b596aac5d2c4e27d45a6 100644 --- a/Testing/SystemTests/tests/analysis/ILLPowderEfficiencyClosureTest.py +++ b/Testing/SystemTests/tests/analysis/ILLPowderEfficiencyClosureTest.py @@ -30,7 +30,7 @@ class ILLPowderEfficiencyClosureTest(systemtesting.MantidSystemTest): def requiredFiles(self): return ['967076.nxs'] - def tearDown(self): + def cleanup(self): mtd.clear() remove(self._m_tmp_file) diff --git a/Testing/SystemTests/tests/analysis/ILLPowderEfficiencyTest.py b/Testing/SystemTests/tests/analysis/ILLPowderEfficiencyTest.py index 040178ec2c10524eed43ddd953cfad7f3c995949..70b7a7636382f86aa15878de7184c8e0c983f64a 100644 --- a/Testing/SystemTests/tests/analysis/ILLPowderEfficiencyTest.py +++ b/Testing/SystemTests/tests/analysis/ILLPowderEfficiencyTest.py @@ -25,7 +25,7 @@ class ILLPowderEfficiencyTest(systemtesting.MantidSystemTest): def requiredFiles(self): return ['967076.nxs'] - def tearDown(self): + def cleanup(self): mtd.clear() def runTest(self): diff --git a/Testing/SystemTests/tests/analysis/ILLPowderParameterScanTest.py b/Testing/SystemTests/tests/analysis/ILLPowderParameterScanTest.py index 492b752f03c23e8b16504a96a35c82a010c776fc..b2fa7870b4d20e643fd1b0595e0927fa6af99e70 100644 --- a/Testing/SystemTests/tests/analysis/ILLPowderParameterScanTest.py +++ b/Testing/SystemTests/tests/analysis/ILLPowderParameterScanTest.py @@ -25,7 +25,7 @@ class ILLPowderParameterScanTest(systemtesting.MantidSystemTest): def requiredFiles(self): return ['967087.nxs', '967088.nxs'] - def tearDown(self): + def cleanup(self): mtd.clear() 
def runTest(self): diff --git a/Testing/SystemTests/tests/analysis/ISIS_PowderPolarisTest.py b/Testing/SystemTests/tests/analysis/ISIS_PowderPolarisTest.py index 441acda6accba04ad6c22301641c7933a6d35fd9..75ca0e0c026b352aeb40c8d8a17cb4580f2260d9 100644 --- a/Testing/SystemTests/tests/analysis/ISIS_PowderPolarisTest.py +++ b/Testing/SystemTests/tests/analysis/ISIS_PowderPolarisTest.py @@ -190,7 +190,7 @@ class TotalScatteringMergedTest(systemtesting.MantidSystemTest): # Whilst total scattering is in development, the validation will avoid using reference files as they will have # to be updated very frequently. In the meantime, the expected peak in the PDF at ~3.9 Angstrom will be checked. # After rebin this is at X index 37 - self.assertAlmostEqual(self.pdf_output.dataY(0)[37], 0.8055205, places=3) + self.assertAlmostEqual(self.pdf_output.dataY(0)[37], 0.7376667, places=3) def run_total_scattering(run_number, merge_banks, q_lims=None): diff --git a/Testing/SystemTests/tests/analysis/SANSBatchReductionTest.py b/Testing/SystemTests/tests/analysis/SANSBatchReductionTest.py index dbe37b713a365906889657d93a214d6f2cf20e18..a8ae41c2e84a34f0da13b45432f791e5ac117424 100644 --- a/Testing/SystemTests/tests/analysis/SANSBatchReductionTest.py +++ b/Testing/SystemTests/tests/analysis/SANSBatchReductionTest.py @@ -13,7 +13,7 @@ from mantid.api import AnalysisDataService from sans.sans_batch import SANSBatchReduction from sans.user_file.state_director import StateDirectorISIS from sans.state.data import get_data_builder -from sans.common.enums import (SANSFacility, ISISReductionMode, OutputMode) +from sans.common.enums import (SANSFacility, ReductionMode, OutputMode) from sans.common.constants import EMPTY_NAME from sans.common.general_functions import create_unmanaged_algorithm from sans.common.file_information import SANSFileInformationFactory @@ -26,7 +26,7 @@ class SANSBatchReductionTest(unittest.TestCase): def _run_batch_reduction(self, states, use_optimizations=False): 
batch_reduction_alg = SANSBatchReduction() - batch_reduction_alg(states, use_optimizations, OutputMode.PublishToADS) + batch_reduction_alg(states, use_optimizations, OutputMode.PUBLISH_TO_ADS) def _compare_workspace(self, workspace, reference_file_name): # Load the reference file @@ -81,7 +81,7 @@ class SANSBatchReductionTest(unittest.TestCase): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.LAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.LAB) # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # COMPATIBILITY BEGIN -- Remove when appropriate # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -120,7 +120,7 @@ class SANSBatchReductionTest(unittest.TestCase): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("MASKSANS2Doptions.091A") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.LAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.LAB) state = user_file_director.construct() # Act @@ -160,7 +160,7 @@ class SANSBatchReductionTest(unittest.TestCase): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.LAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.LAB) # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
# COMPATIBILITY BEGIN -- Remove when appropriate # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -213,7 +213,7 @@ class SANSBatchReductionTest(unittest.TestCase): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.LAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.LAB) # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # COMPATIBILITY BEGIN -- Remove when appropriate # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -272,7 +272,7 @@ class SANSBatchReductionTest(unittest.TestCase): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("MASKSANS2Doptions.091A") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.LAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.LAB) user_file_director.set_slice_event_builder_start_time([1.0, 3.0]) user_file_director.set_slice_event_builder_end_time([3.0, 5.0]) diff --git a/Testing/SystemTests/tests/analysis/SANSBeamCentreFinderCoreTest.py b/Testing/SystemTests/tests/analysis/SANSBeamCentreFinderCoreTest.py index 4ce9c33cd66c0b2dc1c870c842c4fa6f4a757023..99dfdb2e7a537cf766ad88bf0bb92a4c37e48196 100644 --- a/Testing/SystemTests/tests/analysis/SANSBeamCentreFinderCoreTest.py +++ b/Testing/SystemTests/tests/analysis/SANSBeamCentreFinderCoreTest.py @@ -58,8 +58,8 @@ class SANSBeamCentreFinderCoreTest(unittest.TestCase): return sample_scatter, sample_scatter_monitor_workspace, transmission_workspace, direct_workspace def _run_beam_centre_core(self, state, workspace, monitor, 
transmission=None, direct=None, - detector_type=DetectorType.LAB, component=DataType.Sample, centre_1 = 0.1, centre_2 = -0.1 - ,r_min = 0.06, r_max = 0.26): + detector_type=DetectorType.LAB, component=DataType.SAMPLE, centre_1 = 0.1, centre_2 = -0.1 + , r_min = 0.06, r_max = 0.26): beam_centre_core_alg = AlgorithmManager.createUnmanaged("SANSBeamCentreFinderCore") beam_centre_core_alg.setChild(True) beam_centre_core_alg.initialize() @@ -75,8 +75,8 @@ class SANSBeamCentreFinderCoreTest(unittest.TestCase): if direct: beam_centre_core_alg.setProperty("DirectWorkspace", direct) - beam_centre_core_alg.setProperty("Component", DetectorType.to_string(detector_type)) - beam_centre_core_alg.setProperty("DataType", DataType.to_string(component)) + beam_centre_core_alg.setProperty("Component", detector_type.value) + beam_centre_core_alg.setProperty("DataType", component.value) beam_centre_core_alg.setProperty("Centre1", centre_1) beam_centre_core_alg.setProperty("Centre2", centre_2) beam_centre_core_alg.setProperty("RMax", r_max) diff --git a/Testing/SystemTests/tests/analysis/SANSILLAutoProcessTest.py b/Testing/SystemTests/tests/analysis/SANSILLAutoProcessTest.py index 712b336efd46c9bcf90bcbb510ee79ecb6ed9a37..51c969b2bbbac0cce5430e9b8f177146c9d3bb9d 100644 --- a/Testing/SystemTests/tests/analysis/SANSILLAutoProcessTest.py +++ b/Testing/SystemTests/tests/analysis/SANSILLAutoProcessTest.py @@ -22,7 +22,7 @@ class D11_AutoProcess_Test(systemtesting.MantidSystemTest): config['logging.loggers.root.level'] = 'Warning' config.appendDataSearchSubDir('ILL/D11/') - def tearDown(self): + def cleanup(self): mtd.clear() def validate(self): diff --git a/Testing/SystemTests/tests/analysis/SANSILLReductionTest.py b/Testing/SystemTests/tests/analysis/SANSILLReductionTest.py index 22b6b4f64fb5473f741c57ba451a2318d9d58faa..d2b7a6b6bb5a69f133ad1c00f9e7218a1db7d537 100644 --- a/Testing/SystemTests/tests/analysis/SANSILLReductionTest.py +++ 
b/Testing/SystemTests/tests/analysis/SANSILLReductionTest.py @@ -21,7 +21,7 @@ class ILL_D11_Test(systemtesting.MantidSystemTest): config['default.instrument'] = 'D11' config.appendDataSearchSubDir('ILL/D11/') - def tearDown(self): + def cleanup(self): mtd.clear() def validate(self): @@ -81,7 +81,7 @@ class ILL_D22_Test(systemtesting.MantidSystemTest): config['default.instrument'] = 'D22' config.appendDataSearchSubDir('ILL/D22/') - def tearDown(self): + def cleanup(self): mtd.clear() def validate(self): @@ -142,7 +142,7 @@ class ILL_D33_VTOF_Test(systemtesting.MantidSystemTest): config['default.instrument'] = 'D33' config.appendDataSearchSubDir('ILL/D33/') - def tearDown(self): + def cleanup(self): mtd.clear() def validate(self): @@ -186,7 +186,7 @@ class ILL_D33_LTOF_Test(systemtesting.MantidSystemTest): config['default.instrument'] = 'D33' config.appendDataSearchSubDir('ILL/D33/') - def tearDown(self): + def cleanup(self): mtd.clear() def validate(self): @@ -231,7 +231,7 @@ class ILL_D33_Test(systemtesting.MantidSystemTest): config['default.instrument'] = 'D33' config.appendDataSearchSubDir('ILL/D33/') - def tearDown(self): + def cleanup(self): mtd.clear() def validate(self): diff --git a/Testing/SystemTests/tests/analysis/SANSReductionCoreTest.py b/Testing/SystemTests/tests/analysis/SANSReductionCoreTest.py index 878ac1768d357e56b11364f9c19106964a7b9978..e6c8fe65ad38b2468e43875a259f346ea6ffc848 100644 --- a/Testing/SystemTests/tests/analysis/SANSReductionCoreTest.py +++ b/Testing/SystemTests/tests/analysis/SANSReductionCoreTest.py @@ -58,7 +58,7 @@ class SANSReductionCoreTest(unittest.TestCase): return sample_scatter, sample_scatter_monitor_workspace, transmission_workspace, direct_workspace def _run_reduction_core(self, state, workspace, monitor, transmission=None, direct=None, - detector_type=DetectorType.LAB, component=DataType.Sample): + detector_type=DetectorType.LAB, component=DataType.SAMPLE): reduction_core_alg = 
AlgorithmManager.createUnmanaged("SANSReductionCore") reduction_core_alg.setChild(True) reduction_core_alg.initialize() @@ -74,8 +74,8 @@ class SANSReductionCoreTest(unittest.TestCase): if direct: reduction_core_alg.setProperty("DirectWorkspace", direct) - reduction_core_alg.setProperty("Component", DetectorType.to_string(detector_type)) - reduction_core_alg.setProperty("DataType", DataType.to_string(component)) + reduction_core_alg.setProperty("Component", detector_type.value) + reduction_core_alg.setProperty("DataType", component.value) reduction_core_alg.setProperty("OutputWorkspace", EMPTY_NAME) diff --git a/Testing/SystemTests/tests/analysis/SANSSingleReductionTest.py b/Testing/SystemTests/tests/analysis/SANSSingleReductionTest.py index 778cd65c78ea659454ffe1a69410d0c0c59975b6..bac2bbb2603bd4f56aaaab5e27ee6fa44f7900eb 100644 --- a/Testing/SystemTests/tests/analysis/SANSSingleReductionTest.py +++ b/Testing/SystemTests/tests/analysis/SANSSingleReductionTest.py @@ -15,7 +15,7 @@ import mantid # noqa from mantid.api import AlgorithmManager from sans.user_file.state_director import StateDirectorISIS from sans.state.data import get_data_builder -from sans.common.enums import (SANSFacility, ISISReductionMode, ReductionDimensionality, FitModeForMerge) +from sans.common.enums import (SANSFacility, ReductionMode, ReductionDimensionality, FitModeForMerge) from sans.common.constants import EMPTY_NAME from sans.common.general_functions import create_unmanaged_algorithm from sans.common.file_information import SANSFileInformationFactory @@ -162,7 +162,7 @@ class SANSSingleReductionTest(SingleReductionTest): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.LAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.LAB) # 
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # COMPATIBILITY BEGIN -- Remove when appropriate @@ -231,7 +231,7 @@ class SANSSingleReductionTest(SingleReductionTest): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.HAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.HAB) # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # COMPATIBILITY BEGIN -- Remove when appropriate @@ -285,8 +285,8 @@ class SANSSingleReductionTest(SingleReductionTest): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.Merged) - user_file_director.set_reduction_builder_merge_fit_mode(FitModeForMerge.Both) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.MERGED) + user_file_director.set_reduction_builder_merge_fit_mode(FitModeForMerge.BOTH) user_file_director.set_reduction_builder_merge_scale(1.0) user_file_director.set_reduction_builder_merge_shift(0.0) # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
@@ -350,9 +350,9 @@ class SANSSingleReductionTest(SingleReductionTest): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.LAB) - user_file_director.set_reduction_builder_reduction_dimensionality(ReductionDimensionality.TwoDim) - user_file_director.set_convert_to_q_builder_reduction_dimensionality(ReductionDimensionality.TwoDim) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.LAB) + user_file_director.set_reduction_builder_reduction_dimensionality(ReductionDimensionality.TWO_DIM) + user_file_director.set_convert_to_q_builder_reduction_dimensionality(ReductionDimensionality.TWO_DIM) # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # COMPATIBILITY BEGIN -- Remove when appropriate # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
@@ -417,7 +417,7 @@ class SANSSingleReduction2Test(SingleReductionTest): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to HAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.HAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.HAB) user_file_director.set_compatibility_builder_use_compatibility_mode(False) # Add some event slices @@ -542,7 +542,7 @@ class SANSSingleReduction2Test(SingleReductionTest): user_file_director = StateDirectorISIS(data_info, file_information) user_file_director.set_user_file("USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger.txt") # Set the reduction mode to LAB - user_file_director.set_reduction_builder_reduction_mode(ISISReductionMode.LAB) + user_file_director.set_reduction_builder_reduction_mode(ReductionMode.LAB) user_file_director.set_compatibility_builder_use_compatibility_mode(False) # Add some event slices diff --git a/Testing/SystemTests/tests/analysis/reference/ConvertHFIRSCDtoMDE_HB3A_Test.nxs.md5 b/Testing/SystemTests/tests/analysis/reference/ConvertHFIRSCDtoMDE_HB3A_Test.nxs.md5 new file mode 100644 index 0000000000000000000000000000000000000000..345043a2d302beb00456c55174c73ae7032678d6 --- /dev/null +++ b/Testing/SystemTests/tests/analysis/reference/ConvertHFIRSCDtoMDE_HB3A_Test.nxs.md5 @@ -0,0 +1 @@ +46276778996c4d44a2078d8bd3a6467f diff --git a/Testing/SystemTests/tests/analysis/reference/D2B_LowCounts.nxs.md5 b/Testing/SystemTests/tests/analysis/reference/D2B_LowCounts.nxs.md5 index eae7af78ee9af1163c6da5169abac167da85ec56..4b705b14ec59cf87621eec30b15156b305b4d43b 100644 --- a/Testing/SystemTests/tests/analysis/reference/D2B_LowCounts.nxs.md5 +++ b/Testing/SystemTests/tests/analysis/reference/D2B_LowCounts.nxs.md5 @@ -1 +1 @@ -36aa0c84472a689ab52ae40e80d62c93 +f6dc6ff9985bcb92af86b5cba0c2460e diff --git 
a/Testing/SystemTests/tests/analysis/reference/ILL_IN4_SofQW.nxs.md5 b/Testing/SystemTests/tests/analysis/reference/ILL_IN4_SofQW.nxs.md5 index 990b6ec5b6b4b4a905051ef4ed3d52ca4f9e24a1..8d279153ff2d013672c7bbb5c451999d82b4ee75 100644 --- a/Testing/SystemTests/tests/analysis/reference/ILL_IN4_SofQW.nxs.md5 +++ b/Testing/SystemTests/tests/analysis/reference/ILL_IN4_SofQW.nxs.md5 @@ -1 +1 @@ -635875c6e0271e5c3bf50e8adeabd65a +f074b58b9a4f1afbd81d1ec830e13d53 diff --git a/buildconfig/CMake/Bootstrap.cmake b/buildconfig/CMake/Bootstrap.cmake index a2a97dc6da51531eefe07c87c8cc121df5eba6ab..12f00143779c1c881e44b43c3a986c2c9ad36976 100644 --- a/buildconfig/CMake/Bootstrap.cmake +++ b/buildconfig/CMake/Bootstrap.cmake @@ -10,7 +10,7 @@ if( MSVC ) include ( ExternalProject ) set( EXTERNAL_ROOT ${PROJECT_SOURCE_DIR}/external CACHE PATH "Location to clone third party dependencies to" ) set( THIRD_PARTY_GIT_URL "https://github.com/mantidproject/thirdparty-msvc2015.git" ) - set ( THIRD_PARTY_GIT_SHA1 df6c47608066dc639d9313df5b15e7fd493895ac ) + set ( THIRD_PARTY_GIT_SHA1 622c6d0aa7d4480cd4d338e153f1f09e52b8fb09 ) set ( THIRD_PARTY_DIR ${EXTERNAL_ROOT}/src/ThirdParty ) # Generates a script to do the clone/update in tmp set ( _project_name ThirdParty ) diff --git a/buildconfig/CMake/CommonSetup.cmake b/buildconfig/CMake/CommonSetup.cmake index 22857fed68e568c00d5067d8229b5f2b32d1476a..77cefcf267855375405401abad037b903646d809 100644 --- a/buildconfig/CMake/CommonSetup.cmake +++ b/buildconfig/CMake/CommonSetup.cmake @@ -89,12 +89,13 @@ find_package(Doxygen) # optional if(CMAKE_HOST_WIN32) find_package(ZLIB REQUIRED CONFIGS zlib-config.cmake) - set(HDF5_DIR "${THIRD_PARTY_DIR}/cmake") + set(HDF5_DIR "${THIRD_PARTY_DIR}/cmake/hdf5") find_package( HDF5 COMPONENTS CXX HL REQUIRED CONFIGS hdf5-config.cmake ) + set (HDF5_LIBRARIES hdf5::hdf5_cpp-shared hdf5::hdf5_hl-shared) else() find_package(ZLIB REQUIRED) find_package( diff --git a/buildconfig/CMake/Eigen.cmake 
b/buildconfig/CMake/Eigen.cmake index d5babb128d0382c813b4d9993b52e0d8160741ab..8f611ac18027e1f70dd21211107f4fed2d1df741 100644 --- a/buildconfig/CMake/Eigen.cmake +++ b/buildconfig/CMake/Eigen.cmake @@ -15,10 +15,6 @@ else() # Download and unpack Eigen at configure time configure_file(${CMAKE_SOURCE_DIR}/buildconfig/CMake/Eigen.in ${CMAKE_BINARY_DIR}/extern-eigen/CMakeLists.txt) - # The OLD behavior for this policy is to ignore the visibility properties - # for static libraries, object libraries, and executables without exports. - cmake_policy(SET CMP0063 "OLD") - execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" -DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION} . WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/extern-eigen ) execute_process(COMMAND ${CMAKE_COMMAND} --build . WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/extern-eigen ) diff --git a/buildconfig/CMake/Eigen.in b/buildconfig/CMake/Eigen.in index 8af9e7cce8afe9b1c5d97cd11a62818a8b5b8e0d..cd0523c778ad7c7a36686cd77aafd117b345edd5 100644 --- a/buildconfig/CMake/Eigen.in +++ b/buildconfig/CMake/Eigen.in @@ -1,4 +1,7 @@ cmake_minimum_required ( VERSION 3.5 ) + +project(eigen-download NONE) + include( ExternalProject ) ExternalProject_Add(eigen diff --git a/buildconfig/CMake/SipQtTargetFunctions.cmake b/buildconfig/CMake/SipQtTargetFunctions.cmake index fb6bc5562b26d46e48be413eede143955f8f4431..610c776bc56f7b820ca1c57341ec42b83cae5803 100644 --- a/buildconfig/CMake/SipQtTargetFunctions.cmake +++ b/buildconfig/CMake/SipQtTargetFunctions.cmake @@ -81,6 +81,11 @@ function ( mtd_add_sip_module ) ) add_library ( ${PARSED_TARGET_NAME} MODULE ${_sip_generated_cpp} ${_sip_include_deps} ) + # Suppress Warnings about sip bindings have PyObject -> PyFunc casts which + # is a valid pattern GCC8 onwards detects + # GCC 8 onwards needs to disable functional casting at the Python interface + target_compile_options( ${PARSED_TARGET_NAME} PRIVATE + 
$<$<AND:$<CXX_COMPILER_ID:GNU>,$<VERSION_GREATER_EQUAL:$<CXX_COMPILER_VERSION>,8.0>>:-Wno-cast-function-type> ) target_include_directories ( ${PARSED_TARGET_NAME} SYSTEM PRIVATE ${SIP_INCLUDE_DIR} ) target_include_directories ( ${PARSED_TARGET_NAME} PRIVATE ${PARSED_INCLUDE_DIRS} ) target_include_directories ( ${PARSED_TARGET_NAME} SYSTEM PRIVATE ${PARSED_SYSTEM_INCLUDE_DIRS} ) diff --git a/buildconfig/CMake/Span.cmake b/buildconfig/CMake/Span.cmake index 5d17bfd9d5c641151206a8d7c3cfade3924426a1..120f837fcf1bf8eeca6611b21d1b8b739052b4c4 100644 --- a/buildconfig/CMake/Span.cmake +++ b/buildconfig/CMake/Span.cmake @@ -5,10 +5,6 @@ message(STATUS "Using tcbrindle/span in ExternalProject") # Download and unpack Eigen at configure time configure_file(${CMAKE_SOURCE_DIR}/buildconfig/CMake/Span.in ${CMAKE_BINARY_DIR}/extern-span/CMakeLists.txt) -# The OLD behavior for this policy is to ignore the visibility properties -# for static libraries, object libraries, and executables without exports. -cmake_policy(SET CMP0063 "OLD") - execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" -DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION} . WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/extern-span ) execute_process(COMMAND ${CMAKE_COMMAND} --build . WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/extern-span ) diff --git a/buildconfig/CMake/Span.in b/buildconfig/CMake/Span.in index 4185571ec22db24a35755d1b67ce4434e4026066..60f79cc7919325210497817e8ed77b198bea2e79 100644 --- a/buildconfig/CMake/Span.in +++ b/buildconfig/CMake/Span.in @@ -1,4 +1,7 @@ cmake_minimum_required ( VERSION 3.5 ) + +project(span-download NONE) + include( ExternalProject ) ExternalProject_Add(span diff --git a/dev-docs/source/GettingStarted.rst b/dev-docs/source/GettingStarted.rst index 6a02cd77d9db523db73154e8eb0195bf8b5c95ce..5b5977d510535af36c9c52cda99d36c4286620f8 100644 --- a/dev-docs/source/GettingStarted.rst +++ b/dev-docs/source/GettingStarted.rst @@ -84,7 +84,7 @@ Ubuntu .. 
code-block:: sh apt install gdebi-core - apt install ~/Downloads/mantid-developer.X.Y.Z.deb + gdebi ~/Downloads/mantid-developer.X.Y.Z.deb where ``X.Y.Z`` should be replaced with the version that was downloaded. diff --git a/dev-docs/source/ISISReflectometryInterface.rst b/dev-docs/source/ISISReflectometryInterface.rst new file mode 100644 index 0000000000000000000000000000000000000000..0df4316763407393120e04095e055e073a0d6acf --- /dev/null +++ b/dev-docs/source/ISISReflectometryInterface.rst @@ -0,0 +1,325 @@ +.. _ISISReflectometryInterface: + +============================ +ISIS Reflectometry Interface +============================ + +This document gives a brief overview of the `ISIS Reflectometry Interface <https://docs.mantidproject.org/nightly/interfaces/ISIS%20Reflectometry.html>`_ design and things that you should be aware of when working on this interface. If you need to work on this interface, please make sure you are familiar with the `Development guidelines`_ below as a minimum. + +Overview +-------- + +The `ISIS Reflectometry Interface <https://docs.mantidproject.org/nightly/interfaces/ISIS%20Reflectometry.html>`_ provides a way for users to easily run a reduction on a *batch* of runs. A batch of runs is entered into a table, which is actually a tree structure with two levels - this allows sets of runs to be grouped so that their outputs are post-processed (stitched) together. Various default settings can be specified on the tabs. A tab is also provided to make exporting the results for a set of workspaces easy. + +The reduction for each row is done via :ref:`algm-ReflectometryISISLoadAndProcess` (which includes any pre-processing). Post-processing for a group is done via :ref:`algm-Stitch1DMany`. + +The GUI provides a lot of other functionality as well. Because it is quite complex, it is important to keep to the established guidelines, in particular sticking to the MVP design pattern, to avoid the code becoming difficult to work with. 
+ +Structure +--------- + +:code:`GUI` +########### + +This directory contains all of the GUI code for the interface. Each separate component e.g. the Experiment tab, has its own subdirectory. Each of these components has its own view and presenter. There is also a :code:`Common` subdirectory for GUI components/interfaces common to more than one widget (e.g. the :code:`IMessageHandler` interface). + +Briefly the structure is as follows: + +- The top level is the :code:`MainWindow`. +- This can have one or more vertical :code:`Batch` tabs. +- Each :code:`Batch` has several horizontal tabs: + + - The :code:`Runs` tab is where the user specifies which runs to process. The actual runs list is specified in the embedded :code:`RunsTable` component, which comprises of the generic `JobTreeView <BatchWidget/API/JobTreeView.html>`_ table along with a reflectometry-specific toolbar. The :code:`Runs` tab also contains various other operations to do with finding and processing runs, such as searching and autoprocessing by investigation ID and a live data monitor. Note that a *table* here actually refers to a two-level tree, due to the way sets of runs can be grouped together for post-processing. + - The :code:`Experiment` tab allows the user to enter default settings related to a particular experiment. + - The :code:`Instrument` tab allows the user to enter default settings related to the current instrument. + - The :code:`Save` tab allows easy saving of a batch of outputs in ASCII format. It essentially just works on the ADS so this might not be necessary longer term if similar batch-saving functionality could be provided from the workspaces list. + +.. figure:: images/ISISReflectometryInterface_structure.png + :figwidth: 70% + :align: center + +:code:`Reduction` +################# + +This directory contains models which centre around the *reduction configuration*. 
This is a representation of all of the runs and settings that have been entered in a particular Batch tab in the GUI. The top level :code:`Batch` model therefore provides everything needed to perform a reduction for a particular set of runs. + +Additionally, these models also contain state information, e.g. the :code:`Row` and :code:`Group` contain information about whether processing has been performed and what the output workspaces are. + +:code:`Common` +############## + +This directory contains non-GUI-specific utility files useful in more than one component of the reflectometry interface but that are still specific to this interface, e.g. :code:`Parse.h` contains parsing utility functions that are specific for parsing reflectometry input strings such as lists of run numbers. More generic utilities should be put elsewhere, e.g. generic string handling functions might go in :code:`Framework/Kernel`. + +:code:`TestHelpers` +################### + +This directory contains components specific for unit testing. The actual tests are in :code:`../test/ISISReflectometry/`. + +Reduction back-end +------------------ + +The back-end is primarily a set of algorithms, with the entry points from the GUI being :ref:`algm-ReflectometryISISLoadAndProcess` (for reducing a row) and :ref:`algm-Stitch1DMany` (for post-processing a group). Any additional processing should be added to these algorithms, or a new wrapper algorithm could be added if appropriate (this might be necessary in future if post-processing will involve more than just stitching). + +The :code:`BatchPresenter` is the main coordinator for executing a reduction. It uses the :code:`BatchJobRunner`, which converts the reduction configuration to a format appropriate for the algorithms. The conversion functions are in files called :code:`RowProcessingAlgorithm` and :code:`GroupProcessingAlgorithm`, and any algorithm-specific code should be kept to these files. 
+ +Unfortunately the whole batch cannot be farmed off to a single algorithm because we need to update the GUI after each row completes, and we must be able to interrupt processing so that we can cancel a large batch operation. We also need to know whether rows completed successfully before we can set up the group post-processing algorithms. Some queue management is therefore done by the :code:`BatchPresenter`, with the help of the :code:`BatchAlgorithmRunner`. + +Development guidelines +---------------------- + +The following design principles should be adhered to when developing the GUI. If the current design does not seem appropriate for additional feature requests, do consult with a senior developer to work out the best way forward rather than proceeding in a non-optimal way. + +Adhere to MVP +############# + +To ensure the GUI can be easily tested we follow the MVP design pattern. There is general guidance on this `here <https://developer.mantidproject.org/GUIDesignGuidelines.html>`_. + +The view cannot easily be tested, so the aim of MVP is to keep the view as simple as possible so that testing it is not necessary. Typically any user action on the view results in a notification to the presenter and is handled from there (even if that is just an update back to the view). Even simple things like which buttons are enabled on startup are controlled via the presenter rather than setting defaults in the view itself. + +It can be tempting to add one line to toggle or update something in the view without wiring up the presenter. But these quick fixes can quickly introduce bugs as they accumulate. The first question to ask yourself before making any change is: how will I unit test it? In fact, we recommend you follow `test driven development <https://www.mantidproject.org/TDD>`_ and write the unit tests first. 
+ +Note that the views should not have a direct pointer to their presenters, so the notification is done via a subscriber interface (see `Subscriber pattern`_ for an example). The only exception is the :code:`QtMainWindowView` (see `Dependency inversion`_), but notifications should still be done via the subscriber interface. This helps to avoid accidentally introducing logic into the view about what should happen on an event and instead just notify that an event happened. It could also be easily extended to support multiple notifyees of different types, such as different subscribed presenters. + +Dependency inversion +#################### + +Dependency inversion has been introduced in an effort to simplify some aspects of the design and to make the code more modular. Objects that a class depends on are "injected", rather than being created directly within the class that requires them. This makes testing easier, because the real objects can easily be replaced with mocks. Most injection is currently performed using constructors and takes place at the 'entry-point' for the Reflectometry GUI, in :code:`QtMainWindowView`. See the `Dependency injection`_ example below. + +It is not normal in MVP for a view to have ownership of its presenter. However since the whole of mantid does not use Dependency Injection, and due to the way interfaces get instantiated this is currently necessary for :code:`QtMainWindowView`. This pointer should only be used for ownership and all other usage should be avoided, so ensure you use the :code:`MainWindowSubscriber` interface to send notifications to the presenter - i.e. use :code:`m_notifyee` instead of :code:`m_presenter`. + +Coordinate via presenters +######################### + +Although the components are largely self-contained, there are occasions where communication between them is required. 
For example, when processing is running, we do not want the user to be able to edit any settings, because this would change the model that the reduction is running on. We therefore disable all inputs that would affect the reduction when processing is running, and re-enable them when it stops. + +Although enabling/disabling inputs in this example affects the views, coordination between components is done via the presenters. This is to ensure that all of these interactions can be unit tested. Each presenter owns its child presenters, and also has a pointer to its parent presenter (which is set by its parent calling :code:`acceptMainPresenter` on the child and passing a pointer to itself). + +In the example mentioned, processing is initiated from e.g. the button on the :code:`RunsView`. This sends a notification to the :code:`RunsPresenter` via the subscriber interface. However, the :code:`RunsPresenter` cannot initiate processing because information is needed from the other tabs, and the other tabs need to be updated after it starts. Processing therefore needs to be coordinated at the Batch level. The :code:`RunsPresenter` therefore simply notifies its parent :code:`BatchPresenter` that the user requested to start processing. The :code:`BatchPresenter` then does the work to initiate processing. Once it has started (assuming it started successfully) it then notifies all of its child presenters (including the :code:`RunsPresenter`) that processing is in progress. + +Communication between different Batch components is also occasionally required. For example, for usability reasons, only one autoprocessing operation is allowed to be running at any one time. This means that when autoprocessing is running, we need to disable the :code:`AutoProcess` button on all of the other Batch tabs as well. This must be coordinated via the MainWindow component, which is the only component that has access to all of the Batch tabs. 
It then simply notifies its parent :code:`MainWindowPresenter` that autoprocessing is in progress (again, assuming that it started successfully).
processed state for rows/groups and output workspace names) should also be kept up to date. For example, if a row's output workspace has been deleted, then its state is reset. If settings have changed that would affect the reduction output, then the state is also reset. + +Perform all processing in algorithms +#################################### + +When adding new functionality, where possible this should be done by extending the algorithms rather than by adding logic to the GUI. The aim is that there is a single algorithm that will be run for each entry in the table (albeit a different algorithm for Rows and Groups). + +Consider adding new wrapper algorithms if appropriate. :ref:`algm-ReflectometryISISLoadAndProcess` is an algorithm that has been added specifically for this purpose and can usually be extended or modified quite easily because it is designed for use with this GUI. The post-processing algorithm, :ref:`algm-Stitch1DMany`, is more generic so it is likely in future that we would want to add a wrapper for this algorithm rather than changing it directly. + +Design pattern examples +----------------------- + +Subscriber pattern +################## + +Let's take the :code:`Event` component as an example. + +- The view is constructed first and is passed to the presenter. The presenter then immediately subscribes to the view. + + .. code-block:: c++ + + EventPresenter::EventPresenter(IEventView *view) + : m_view(view) { + m_view->subscribe(this); + } + +- This sets the notifyee in the view, using a subscriber interface. + + .. code-block:: c++ + + void QtEventView::subscribe(EventViewSubscriber *notifyee) { + m_notifyee = notifyee; + } + +- The subscriber interface defines the set of notifications that the view needs to send. + + .. code-block:: c++ + + class MANTIDQT_ISISREFLECTOMETRY_DLL EventViewSubscriber { + public: + virtual void notifySliceTypeChanged(SliceType newSliceType) = 0; + virtual void notifyUniformSliceCountChanged(int sliceCount) = 0; + ... 
+ }; + + Note that :code:`MANTIDQT_ISISREFLECTOMETRY_DLL` is used to expose classes/functions so they can be used in different modules. In this case, it is needed in order for this class to be used in the tests, because the tests are not part of the ISISReflectometry library. If you get linker errors, this is one thing to check. + +- The presenter implements the subscriber interface. + + .. code-block:: c++ + + class MANTIDQT_ISISREFLECTOMETRY_DLL EventPresenter + : public IEventPresenter, + public EventViewSubscriber + +- It overrides the notification functions to perform the relevant actions. + + .. code-block:: c++ + + void EventPresenter::notifyUniformSliceCountChanged(int) { + setUniformSlicingByNumberOfSlicesFromView(); + m_mainPresenter->notifySettingsChanged(); + } + +- When a user interacts with the view, all the view needs to do is send the appropriate notification. By using an interface, the view does not know anything about the concrete type that it is notifying. This helps to avoid accidentally introducing logic into the view about what should happen on an event and instead just notify that an event happened. It could also be easily extended to support multiple notifyees of different types, such as different subscribed presenters. + + .. code-block:: c++ + + void QtEventView::onUniformEvenChanged(int numberOfSlices) { + m_notifyee->notifyUniformSliceCountChanged(numberOfSlices); + } + + +Dependency injection +#################### + +A simple example of `Dependency inversion`_ is in the use of an :code:`IMessageHandler` interface to provide a service to display messages to the user. These messages must be displayed by a Qt view. Rather than each view having to implement this, we use one object (in this case the :code:`QtMainWindowView`) to implement this functionality and inject it as an :code:`IMessageHandler` to all of the presenters that need it. + +- The :code:`IMessageHandler` interface defines the functions for displaying messages: + + .. 
code-block:: c++ + + class IMessageHandler { + public: + virtual void giveUserCritical(const std::string &prompt, + const std::string &title) = 0; + ... + }; + +- The :code:`QtMainWindowView` implements these: + + .. code-block:: c++ + + void QtMainWindowView::giveUserCritical(const std::string &prompt, + const std::string &title) { + QMessageBox::critical(this, QString::fromStdString(title), + QString::fromStdString(prompt), QMessageBox::Ok, + QMessageBox::Ok); + } + +- The :code:`QtMainWindowView` creates a concrete instance of the interface (actually just a pointer to itself) and passes it in the construction of anything that needs it, e.g. the :code:`RunsPresenter` (in this case using a factory to perform the construction - more about the `Factory pattern`_ below): + + .. code-block:: c++ + + auto messageHandler = this; + auto makeRunsPresenter = RunsPresenterFactory(..., messageHandler); + +- The :code:`RunsPresenter` then has a simple service it can use to display messages without needing to know anything about the :code:`QtMainWindowView`: + + .. code-block:: c++ + + m_messageHandler->giveUserCritical("Catalog login failed", "Error"); + +- Our unit tests can then ensure that a notification is sent to Qt in a known critical situation, e.g. in :code:`RunsPresenterTest`: + + .. code-block:: c++ + + void testSearchCatalogLoginFails() { + ... + EXPECT_CALL(m_messageHandler, + giveUserCritical("Catalog login failed", "Error")) + .Times(1); + ... + } + +Factory pattern +############### + +The :code:`MainWindowPresenter` constructs the child Batch presenters on demand. This prevents us injecting them in its constructor. In order to follow `Dependency inversion`_, we therefore need to use factories to create the child presenters. Let's use the :code:`MainWindow` -> :code:`Batch` -> :code:`Event` components as an example. + +- As mentioned, the :code:`QtMainWindowView` is our entry point. This creates (and owns) the :code:`MainWindowPresenter`. 
This prevents us from injecting them in its constructor.
- The test then uses :code:`EXPECT_CALL` to check that the method was called.
+ presenter.notifyUniformSliceCountChanged(expectedSliceCount); + auto const &sliceValues = + boost::get<UniformSlicingByNumberOfSlices>(presenter.slicing()); + TS_ASSERT(sliceValues == + UniformSlicingByNumberOfSlices(expectedSliceCount)); + verifyAndClear(); + } + + .. code-block:: c++ + + void testChangingSliceCountNotifiesMainPresenter() { + auto presenter = makePresenter(); + EXPECT_CALL(m_mainPresenter, notifySettingsChanged()).Times(AtLeast(1)); + presenter.notifyUniformSliceCountChanged(1); + verifyAndClear(); + } + +- Testing outcomes separately like this speeds up future development because it makes it easier to see where and why failures happen. It also makes it easier to maintain the tests as the code develops - e.g. if a functional change deliberately changes the expected action on the main presenter then we only need to update that test. The test that checks the model should not be affected (and if it is, we know we've broken something!). + +- Note that although the :code:`EventPresenter` tests currently check the model directly, the model could (and should) be mocked out and tested separately if it was more complex. diff --git a/dev-docs/source/ISISSANSReductionBackend.rst b/dev-docs/source/ISISSANSReductionBackend.rst index f0f5d0ee112cbc015ce372d45578b8a0ba00c3c5..c8c6ea13ce5a307f9a94daffc3a7d20b56c5a14f 100644 --- a/dev-docs/source/ISISSANSReductionBackend.rst +++ b/dev-docs/source/ISISSANSReductionBackend.rst @@ -177,21 +177,12 @@ build custom types. The current list of types are: - *PositiveFloatParameter* - *PositiveIntegerParameter* - *DictParameter* -- *ClassTypeParameter* - *FloatWithNoneParameter* - *StringWithNoneParameter* - *PositiveFloatWithNoneParameter* - *FloatListParameter* - *StringListParameter* - *PositiveIntegerListParameter* -- *ClassTypeListParameter* - -Most of the typed parameters are self-descriptive. The *ClassTypeParameter* -refers to the enum-like class definitions in *enum.py*. 
Note that if a parameter -is not set by the state builder, then it will return *None* when it is queried. -If it is a mandatory parameter on a state object, then this needs to be enforced -in the *validate* method of the state. - Individual states ^^^^^^^^^^^^^^^^^ @@ -256,7 +247,7 @@ can_direct_period The period to use for the can direct * calibration The path to the calibration file *StringParameter* Y N sample_scatter_run_number Run number of the sample scatter file *PositiveIntegerParameter* - Y sample_scatter_is_multi_period If the sample scatter is multi-period *BoolParameter* - Y -instrument Enum for the SANS instrument *ClassTypeParameter(SANSInstrument)* - Y +instrument Enum for the SANS instrument *Enum (SANSInstrument)* - Y idf_file_path Path to the IDF file *StringParameter* - Y ipf_file_path Path to the IPF file *StringParameter* - Y =============================== ============================================== ===================================== ========= =============== @@ -309,9 +300,9 @@ reduction. It contains the following parameters: =============================== ===================================================== ============================================== ========= =============== =========================================== Name Comment Type Optional? Auto-generated? Default value =============================== ===================================================== ============================================== ========= =============== =========================================== -reduction_mode The type of reduction, i.e. LAB, HAB, merged or both *ClassTypeParameter(ReductionMode)* N N *ISISReductionMode.LAB* enum value -reduction_dimensionality If 1D or 2D reduction *ClassTypeParameter(ReductionDimensionality)* N N *ReductionDimensionality.OneDim* enum value -merge_fit_mode The fit mode for merging *ClassTypeParameter(FitModeForMerge)* Y N *FitModeForMerge.NoFit* enum value +reduction_mode The type of reduction, i.e. 
LAB, HAB, merged or both *Enum(ReductionMode)* N N *ReductionMode.LAB* enum value +reduction_dimensionality If 1D or 2D reduction *Enum(ReductionDimensionality)* N N *ReductionDimensionality.OneDim* enum value +merge_fit_mode The fit mode for merging *Enum(FitModeForMerge)* Y N *FitModeForMerge.NoFit* enum value merge_shift The shift value for merging *FloatParameter* Y N 0.0 merge_scale The scale value for merging *FloatParameter* Y N 1.0 merge_range_min The min q value for merging *FloatWithNoneParameter* Y N *None* @@ -405,11 +396,11 @@ from time-of-flight to wavelength units. The parameters are: ===================== ==================================== =================================== ========= =============== Name Comment Type Optional? Auto-generated? ===================== ==================================== =================================== ========= =============== -rebin_type The type of rebinning *ClassTypeParameter(RebinType)* N N +rebin_type The type of rebinning *Enum(RebinType)* N N wavelength_low The lower wavelength boundary *PositiveFloatParameter* N N wavelength_high The upper wavelength boundary *PositiveFloatParameter* N N wavelength_step The wavelength step *PositiveFloatParameter* N N -wavelength_step_type This is either linear or logarithmic *ClassTypeParameter(RangeStepType)* N N +wavelength_step_type This is either linear or logarithmic *Enum(RangeStepType)* N N ===================== ==================================== =================================== ========= =============== The validation ensures that all entries are specified and that the lower wavelength boundary is smaller than the upper wavelength boundary. @@ -424,7 +415,7 @@ the required information about saving the reduced data. The relevant parameters Name Comment Type Optional? Auto-generated? 
Default ================================== ================================================== =================================== ========= =============== ======= zero_free_correction If zero error correction (inflation) should happen *BoolParameter* Y N True -file_format A list of file formats to save into *ClassTypeListParameter(SaveType)* Y N - +file_format A list of file formats to save into *EnumList(SaveType)* Y N - user_specified_output_name A custom user-specified name for the saved file *StringWithNoneParameter* Y N - user_specified_output_name_suffix A custom user-specified suffix for the saved file *StringParameter* Y N - use_reduction_mode_as_suffix If the reduction mode should be used as a suffix *BoolParameter* Y N - @@ -441,12 +432,12 @@ and the volume information. The parameters are: ===================== ======================================== ================================== ========= =============== Name Comment Type Optional? Auto-generated? ===================== ======================================== ================================== ========= =============== -shape The user-specified shape of the sample *ClassTypeParameter(SampleShape)* N Y +shape The user-specified shape of the sample *Enum(SampleShape)* N Y thickness The user-specified sample thickness *PositiveFloatParameter* N Y width The user-specified sample width *PositiveFloatParameter* N Y height The user-specified sample height *PositiveFloatParameter* N Y scale The user-specified absolute scale *PositiveFloatParameter* N Y -shape_from_file The file-extracted shape of the sample *ClassTypeParameter(SampleShape)* N Y +shape_from_file The file-extracted shape of the sample *Enum(SampleShape)* N Y thickness_from_file The file-extracted sample thickness *PositiveFloatParameter* N Y width_from_file The file-extracted sample width *PositiveFloatParameter* N Y height_from_file The file-extracted sample height *PositiveFloatParameter* N Y @@ -501,11 +492,11 @@ incident_monitor The 
incident monitor prompt_peak_correction_min The start time of a prompt peak correction *PositiveFloatParameter* Y N - prompt_peak_correction_max The stop time of a prompt peak correction *PositiveFloatParameter* Y N - prompt_peak_correction_enabled If the prompt peak correction should occur *BoolParameter* Y N True -rebin_type The type of wavelength rebinning, i.e. standard or interpolating *ClassTypeParameter(RebinType)* Y N - +rebin_type The type of wavelength rebinning, i.e. standard or interpolating *Enum(RebinType)* Y N - wavelength_low The lower wavelength boundary *PositiveFloatParameter* Y N - wavelength_high The upper wavelength boundary *PositiveFloatParameter* Y N - wavelength_step The wavelength step *PositiveFloatParameter* Y N - -wavelength_step_type The wavelength step type, i.e. lin or log *ClassTypeParameter(RebinType)* Y N - +wavelength_step_type The wavelength step type, i.e. lin or log *Enum(RebinType)* Y N - use_full_wavelength_range If the full wavelength range of the instrument should be used *BoolParameter* Y N - wavelength_full_range_low The lower wavelength boundary of the full wavelength range *PositiveFloatParameter* Y N - wavelength_full_range_high The upper wavelength boundary of the full wavelength range *PositiveFloatParameter* Y N - @@ -532,7 +523,7 @@ fit information. The set of parameters describing this fit are: ================= ================================================================= ================================ ========= =============== ======================== Name Comment Type Optional? Auto-generated? Default ================= ================================================================= ================================ ========= =============== ======================== -fit_type The type of fitting, i.e. lin, log or poly *ClassTypeParameter(FitType)* Y N *FitType.Log* enum value +fit_type The type of fitting, i.e. 
lin, log or poly *Enum(FitType)* Y N *FitType.Log* enum value polynomial_order Polynomial order when poly fit type has been selected *PositiveIntegerParameter* Y N 0 wavelength_low Lower wavelength bound for fitting (*None* means no lower bound) *PositiveFloatWithNoneParameter* Y N - wavelength_high Upper wavelength bound for fitting (*None* means no upper bound) *PositiveFloatWithNoneParameter* Y N - @@ -557,11 +548,11 @@ incident_monitor The incident monitor prompt_peak_correction_min The start time of a prompt peak correction *PositiveFloatParameter* Y N - prompt_peak_correction_max The stop time of a prompt peak correction *PositiveFloatParameter* Y N - prompt_peak_correction_enabled If the prompt peak correction should occur *BoolParameter* Y N False -rebin_type The type of wavelength rebinning, i.e. standard or interpolating *ClassTypeParameter(RebinType)* Y N *RebinType.Rebin* enum value +rebin_type The type of wavelength rebinning, i.e. standard or interpolating *Enum(RebinType)* Y N *RebinType.Rebin* enum value wavelength_low The lower wavelength boundary *PositiveFloatParameter* Y N - wavelength_high The upper wavelength boundary *PositiveFloatParameter* Y N - wavelength_step The wavelength step *PositiveFloatParameter* Y N - -wavelength_step_type The wavelength step type, i.e. lin or log *ClassTypeParameter(RangeStepType)* Y N - +wavelength_step_type The wavelength step type, i.e. lin or log *Enum(RangeStepType)* Y N - background_TOF_general_start General lower boundary for background correction *FloatParameter* Y N - background_TOF_general_stop General upper boundary for background correction *FloatParameter* Y N - background_TOF_monitor_start Monitor specific lower boundary for background correction (monitor vs. 
start value) *DictParameter* Y N - @@ -582,7 +573,7 @@ Name Comment wavelength_low The lower bound of the for the wavelength range *PositiveFloatParameter* N N wavelength_high The upper bound of the for the wavelength range *PositiveFloatParameter* N N wavelength_step The wavelength step *PositiveFloatParameter* N N -wavelength_step_type The wavelength step type, i.e. lin or log *ClassTypeParameter(RangeStepType)* N N +wavelength_step_type The wavelength step type, i.e. lin or log *Enum(RangeStepType)* N N adjustment_files Dict to adjustment files; detector type vs *StateAdjustmentFiles* object *DictParamter* N Y idf_path Path to the IDF file *StringParameter* N Y ====================== ========================================================================== =================================== ========= =============== @@ -611,7 +602,7 @@ The parameters are: ================================ ============================================= ============================================= =============================== =============== =========================================== Name Comment Type Optional? Auto-generated? 
Default ================================ ============================================= ============================================= =============================== =============== =========================================== -reduction_dimensionality 1D or 2D *ClassTypeParameter(ReductionDimensionality)* N N *ReductionDimensionality.OneDim* enum value +reduction_dimensionality 1D or 2D *Enum(ReductionDimensionality)* N N *ReductionDimensionality.OneDim* enum value use_gravity If gravity correction should be applied *BoolParameter* Y N False gravity_extra_length Extra length for gravity correction *PositiveFloatParameter* Y N 0 radius_cuto-off Radius above which pixels are not considered *PositiveFloatParameter* Y N 0 @@ -621,7 +612,7 @@ q_max Max momentum transfer value for 1D reduction * q_1d_rebin_string Rebin string for Q1D *StringParameter* N, if 1D N - q_xy_max Max momentum transfer value for 2D reduction *PositiveFloatParameter* N, if 2D N - q_xy_step Momentum transfer step for 2D reduction *PositiveFloatParameter* N, if 2D N - -q_xy_step_type The step type, i.e. lin or log *ClassTypeParameter(RangeStepType)* N, if 2D N - +q_xy_step_type The step type, i.e. lin or log *Enum(RangeStepType)* N, if 2D N - use_q_resolution If should perform a q resolution calculation *BoolParameter* Y N False q_resolution_collimation_length Collimation length *PositiveFloatParameter* N, if performing q resolution N - q_resolution_delta_r Virtual ring width on the detector *PositiveFloatParameter* N, if performing q resolution N - diff --git a/dev-docs/source/UserSupport.rst b/dev-docs/source/UserSupport.rst new file mode 100644 index 0000000000000000000000000000000000000000..95e3342648bbdb7a87564e3648b3374a143aa302 --- /dev/null +++ b/dev-docs/source/UserSupport.rst @@ -0,0 +1,38 @@ +.. _UserSupport: + +============ +User Support +============ + +.. 
This can be in many circumstances and through different avenues; therefore, our support procedures are detailed below. + +The main purpose of user support for the Mantid project is to aid contact between the users and developers.
+ + diff --git a/dev-docs/source/conf.py b/dev-docs/source/conf.py index 13d67fcecaa3062d0e883867f2e220d7c37e89f8..8f10e560e127bd0aff5d1485099c70eec5884ac9 100644 --- a/dev-docs/source/conf.py +++ b/dev-docs/source/conf.py @@ -168,7 +168,7 @@ epub_uid = "Mantid Reference: " + version # -- Link to other projects ---------------------------------------------------- intersphinx_mapping = { - 'h5py': ('http://docs.h5py.org/en/latest/', None), + 'h5py': ('https://h5py.readthedocs.io/en/stable/', None), 'matplotlib': ('http://matplotlib.org', None), 'numpy': ('https://docs.scipy.org/doc/numpy/', None), 'python': ('https://docs.python.org/3/', None), diff --git a/dev-docs/source/images/ISISReflectometryInterface_structure.png b/dev-docs/source/images/ISISReflectometryInterface_structure.png new file mode 100644 index 0000000000000000000000000000000000000000..beb6ca428cb49538166ee45e57e42f557a2c7e8a Binary files /dev/null and b/dev-docs/source/images/ISISReflectometryInterface_structure.png differ diff --git a/dev-docs/source/images/errorReporter.png b/dev-docs/source/images/errorReporter.png new file mode 100644 index 0000000000000000000000000000000000000000..db44632ba3d945510058c4fbf84bec27f280acb6 Binary files /dev/null and b/dev-docs/source/images/errorReporter.png differ diff --git a/dev-docs/source/index.rst b/dev-docs/source/index.rst index 5a40bbd3a9ba58f3b2280824730eeaa72c923833..8c43f6fb93a35abb7ab04da79d05039f1a43131c 100644 --- a/dev-docs/source/index.rst +++ b/dev-docs/source/index.rst @@ -62,6 +62,7 @@ Development Process DevelopmentAndReleaseCycle Communication IssueTracking + UserSupport GitWorkflow AutomatedBuildProcess JenkinsConfiguration @@ -79,6 +80,9 @@ Development Process :doc:`IssueTracking` Describes how issues are tracked over the project. +:doc:`UserSupport` + Procedures for User Problems to be tested and passed to the Development Team. + :doc:`GitWorkflow` Details the workflow used development with git and GitHub. 
@@ -187,6 +191,7 @@ GUI Development MVPTutorial/index QtDesignerForPython MantidUsedIconsTable + ISISReflectometryInterface :doc:`GUIDesignGuidelines` Gives some guidelines to consider when developing a new graphical user interface. @@ -200,6 +205,9 @@ GUI Development :doc:`MantidUsedIconsTable` The currently used Icons in Mantid and what they are used for. +:doc:`ISISReflectometryInterface` + An example of a complex C++ interface that uses MVP. + ========= Workbench ========= @@ -229,6 +237,7 @@ Component Overviews HandlingXML IndexProperty InstrumentViewer + ISISReflectometryInterface ISISSANSReductionBackend LoadAlgorithmHook Logging diff --git a/docs/source/algorithms/ConvertHFIRSCDtoMDE-v1.rst b/docs/source/algorithms/ConvertHFIRSCDtoMDE-v1.rst new file mode 100644 index 0000000000000000000000000000000000000000..891c8c50ee1e306ca6a328f11c94389ed0f3c64f --- /dev/null +++ b/docs/source/algorithms/ConvertHFIRSCDtoMDE-v1.rst @@ -0,0 +1,44 @@ + +.. algorithm:: + +.. summary:: + +.. relatedalgorithms:: + +.. properties:: + +Description +----------- + +This algorithm will convert the output of :ref:`algm-LoadWANDSCD` or +the autoreduced data from DEMAND (HB3A) into a :ref:`MDEventWorkspace +<MDWorkspace>` in Q-sample, where every pixel at every scan point is +converted to a MDEvent. This is similar to +:ref:`algm-ConvertWANDSCDtoQ` except that it doesn't histogram the +data or do normalization. :ref:`algm-FindPeaksMD` can be run on the +output Q sample space, then the UB can be found and used to then +convert to HKL using :ref:`algm-ConvertWANDSCDtoQ` +. :ref:`algm-IntegratePeaksMD` will also work on the output of this +algorithm. + +Usage +----- + +**Example - ConvertHFIRSCDtoMDE** + +.. code-block:: python + + LoadWANDSCD(IPTS=7776, RunNumbers='26640-27944', OutputWorkspace='data',Grouping='4x4') + ConvertHFIRSCDtoMDE(InputWorkspace='data', + Wavelength=1.488, + OutputWorkspace='Q') + + +Output: + +.. figure:: /images/ConvertHFIRSCDtoMDE.png + +.. 
categories:: + +.. sourcelink:: + diff --git a/docs/source/algorithms/EstimateResolutionDiffraction-v1.rst b/docs/source/algorithms/EstimateResolutionDiffraction-v1.rst index 4072fc364b89a7e4e982bf53fac9147b2bace63b..f89aa6ada3c17344bf050db975c25043bf193b80 100644 --- a/docs/source/algorithms/EstimateResolutionDiffraction-v1.rst +++ b/docs/source/algorithms/EstimateResolutionDiffraction-v1.rst @@ -96,9 +96,9 @@ Output: .. testoutput:: ExHistSimple Size of workspace 'PG3_Resolution' = 1000 - Estimated resolution of detector of spectrum 0 = 0.00323913250277 - Estimated resolution of detector of spectrum 100 = 0.00323608373204 - Estimated resolution of detector of spectrum 999 = 0.00354849279137 + Estimated resolution of detector of spectrum 0 = 0.00323913137315 + Estimated resolution of detector of spectrum 100 = 0.00323608260137 + Estimated resolution of detector of spectrum 999 = 0.00354849176520 .. seealso :: Algorithms :ref:`algm-EstimateDivergence`, :ref:`algm-CalibrateRectangularDetectors` and :ref:`algm-GetDetOffsetsMultiPeaks` diff --git a/docs/source/algorithms/HRPDSlabCanAbsorption-v1.rst b/docs/source/algorithms/HRPDSlabCanAbsorption-v1.rst index 845656a7c5c6da6f60d712574e2662714374c5e1..bb96eea9d53a7abbe715d5e573a06a5a0c773a25 100644 --- a/docs/source/algorithms/HRPDSlabCanAbsorption-v1.rst +++ b/docs/source/algorithms/HRPDSlabCanAbsorption-v1.rst @@ -55,7 +55,7 @@ Usage ws = Rebin(ws,Params=[1]) SetSampleMaterial(ws,ChemicalFormula="V") - wsOut = HRPDSlabCanAbsorption (ws,Thickness='0.2',ElementSize=3) + wsOut = HRPDSlabCanAbsorption (ws,Thickness=0.2,ElementSize=2) print('The created workspace has one entry for each spectra: {:d}'.format(wsOut.getNumberHistograms())) diff --git a/docs/source/algorithms/MatchAndMergeWorkspaces-v1.rst b/docs/source/algorithms/MatchAndMergeWorkspaces-v1.rst new file mode 100644 index 0000000000000000000000000000000000000000..b6e64b2a8409e535b3f5d000a1f4991fc001787e --- /dev/null +++ 
b/docs/source/algorithms/MatchAndMergeWorkspaces-v1.rst @@ -0,0 +1,67 @@ +.. algorithm:: + +.. summary:: + +.. relatedalgorithms:: + +.. properties:: + +Description +----------- + +This is a workflow algorithm that merges down a workspace, workspace +group, or list of workspaces, matches the spectra using :ref:`MatchSpectra <algm-MatchSpectra>` +and sums the spectra using a weighted mean from ranges for each +spectrum. For each workspace an XMin and XMax is supplied, the +workspace is then cropped to that range and all workspaces are +rebinned to have common bin edges. The output is the mean value +for each bin for all workspaces with a value in that bin. +This is done by executing several sub-algorithms as +listed below. + +#. :ref:`algm-Rebin` To rebin all spectra to have common bins. +#. :ref:`algm-ConjoinWorkspaces` repeated for every workspace in the workspace group. +#. :ref:`algm-MatchSpectra` Matched against the spectra with the largest original x range. +#. :ref:`algm-CropWorkspaceRagged` to cut spectra to match the X limits given. +#. :ref:`algm-Rebin` To rebin all spectra to have common bins. +#. :ref:`algm-SumSpectra` using `WeightedSum=True` and `MultiplyBySpectra=False`. + +Usage +----- + +.. 
code-block:: python + + from mantid.simpleapi import * + + import numpy as np + + from isis_powder import polaris, SampleDetails + config_file_path = r"C:/Users/wey38795/Documents/polaris-calculate-pdf/polaris_config_example.yaml" + polaris = polaris.Polaris(config_file=config_file_path, user_name="test", mode="PDF") + + sample_details = SampleDetails(height=4.0, radius=0.2985, center=[0, 0, 0], shape='cylinder') + sample_details.set_material(chemical_formula='Si') + polaris.set_sample_details(sample=sample_details) + + polaris.create_vanadium(first_cycle_run_no="98532", + multiple_scattering=False) + polaris.focus(run_number="98533", input_mode='Summed') + + polaris.create_total_scattering_pdf(run_number="98533", merge_banks=False) + x_min = np.array([0.5, 3, 4, 6, 7]) + x_max = np.array([3.5, 5, 7, 11, 20]) + merged_ws = MatchAndMergeWorkspaces(WorkspaceGroup='focused_ws', XMin=x_min, XMax=x_max, CalculateScale=False) + + fig, ax = plt.subplots(subplot_kw={'projection':'mantid'}) + ax.plot(merged_ws) + ax.legend() + fig.show() + +This will produce a plot that looks like this: + +.. figure:: ../images/MatchAndMergeWorkspaces.png + +Workflow +######## + +.. diagram:: MatchAndMergeWorkspaces-v1_wkflw.dot diff --git a/docs/source/algorithms/SaveAscii-v1.rst b/docs/source/algorithms/SaveAscii-v1.rst index 2aa848631a306eaa4a2fcbb544e883bb742d22f7..59a39b9d55f917b8cc4effb3cae050892ed31979 100644 --- a/docs/source/algorithms/SaveAscii-v1.rst +++ b/docs/source/algorithms/SaveAscii-v1.rst @@ -14,6 +14,9 @@ contains the X-values, followed by pairs of Y and E values. Columns are separated by commas. The resulting file can normally be loaded into a workspace by the :ref:`algm-LoadAscii` algorithm. +As far as we are aware, this algorithm is only used by the ISIS Indirect group (including within the Indirect Diffraction GUI). +Please see the new version: :ref:`algm-SaveAscii`. 
+ Limitations ########### diff --git a/docs/source/algorithms/TotScatCalculateSelfScattering-v1.rst b/docs/source/algorithms/TotScatCalculateSelfScattering-v1.rst new file mode 100644 index 0000000000000000000000000000000000000000..507ee64a7fe5225c6997c19648dfb3d676689807 --- /dev/null +++ b/docs/source/algorithms/TotScatCalculateSelfScattering-v1.rst @@ -0,0 +1,31 @@ +.. algorithm:: + +.. summary:: + +.. relatedalgorithms:: + +.. properties:: + +Description +----------- + +This is a workflow algorithm that calculates the placzek self scattering +factor focused into detector banks. This is done by executing several +sub-algorithms as listed below. + +#. :ref:`algm-SetSample` Sets sample data for the run that is to be corrected to the raw workspace. +#. :ref:`algm-ExtractSpectra` Extracts the monitor spectrum closest to the sample (incident spectrum). +#. :ref:`algm-ConvertUnits` Converts incident spectrum to wavelength. +#. :ref:`algm-FitIncidentSpectrum` Fit a curve to the incident spectrum. +#. :ref:`algm-CalculatePlaczekSelfScattering` Calculate the Placzek self scattering factor for each pixel. +#. :ref:`algm-LoadCalFile` Loads the detector calibration. +#. :ref:`algm-DiffractionFocussing` Focus the Placzek self scattering factor into detector banks. +#. :ref:`algm-CreateWorkspace` Create a workspace containing the number of pixels in each detector bank. +#. :ref:`algm-Divide` Normalize the Placzek correction by pixel number in bank +#. :ref:`algm-ConvertToDistribution` Change the workspace into a format that can be subtracted. +#. :ref:`algm-ConvertUnits` Converts correction into MomentumTransfer. + +Workflow +######## + +.. 
diagram:: TotScatCalculateSelfScattering-v1_wkflw.dot diff --git a/docs/source/algorithms/VesuvioCalculateGammaBackground-v1.rst b/docs/source/algorithms/VesuvioCalculateGammaBackground-v1.rst index 635e2a8cf6d88507625be070eb0a9068816e628b..639677d263a8e09679ac134538f9a51ed3b220f6 100644 --- a/docs/source/algorithms/VesuvioCalculateGammaBackground-v1.rst +++ b/docs/source/algorithms/VesuvioCalculateGammaBackground-v1.rst @@ -66,8 +66,8 @@ Output: .. testoutput:: First 5 values of input: [ 1. 1. 1. 1.] - First 5 values of background: [ 1.00053361 1.00054704 1.00056072 1.00057468] - First 5 values of corrected: [-0.00053361 -0.00054704 -0.00056072 -0.00057468] + First 5 values of background: [ 1.00053363 1.00054706 1.00056074 1.0005747 ] + First 5 values of corrected: [-0.00053363 -0.00054706 -0.00056074 -0.0005747 ] **Example: Correcting all spectra** diff --git a/docs/source/concepts/PeaksWorkspace.rst b/docs/source/concepts/PeaksWorkspace.rst index 0be01d1fd8395bafff928769fd5a4c2cf554e968..4aa71c0c6343dc8b8ede71129763bc5b15db1bc4 100644 --- a/docs/source/concepts/PeaksWorkspace.rst +++ b/docs/source/concepts/PeaksWorkspace.rst @@ -82,7 +82,7 @@ PeaksWorkspace Python Interface .. 
code-block:: python pws = mtd['name_of_peaks_workspace'] - pws.getNumberOfPeaks() + pws.getNumberPeaks() p = pws.getPeak(12) pws.removePeak(34) diff --git a/docs/source/conf.py b/docs/source/conf.py index bf8bd1eaf32269c9c06c0a16cf6bf1d15fa27c31..4e630d95a24457d8bdf6f5569800ef32aad82fbc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -208,7 +208,7 @@ exec(compile(open(html_theme_cfg).read(), html_theme_cfg, 'exec')) # -- Link to other projects ---------------------------------------------------- intersphinx_mapping = { - 'h5py': ('https://docs.h5py.org/en/latest/', None), + 'h5py': ('https://h5py.readthedocs.io/en/stable/', None), 'matplotlib': ('https://matplotlib.org', None), 'numpy': ('https://docs.scipy.org/doc/numpy/', None), 'python': ('https://docs.python.org/3/', None), diff --git a/docs/source/diagrams/MatchAndMergeWorkspaces-v1_wkflw.dot b/docs/source/diagrams/MatchAndMergeWorkspaces-v1_wkflw.dot new file mode 100644 index 0000000000000000000000000000000000000000..99d01bd1ed07cecc8aae80d6e4b786cd99a8d7f1 --- /dev/null +++ b/docs/source/diagrams/MatchAndMergeWorkspaces-v1_wkflw.dot @@ -0,0 +1,52 @@ +digraph MatchAndMergeWorkspaces { +label = "MatchAndMergeWorkspaces Workflow Diagram" +$global_style + +subgraph params { + $param_style + inputWorkspace [label="InputWorkspaces"] + outputWorkspace [label="OutputWorkspace"] + xMin [label="XMin"] + xMax [label="XMax"] + calculateScale [label="CalculateScale"] + calculateOffset [label="CalculateOffset"] + matchSpectraIndex [label="Index of spectra with largest X range"] +} + +subgraph decisions { + $decision_style + isGroupWsNotConjoined [label="Are there\n un-conjoined workspaces\n in the workspace list"] +} + +subgraph algorithms { + $algorithm_style + Rebin1 [label="Rebin v1"] + ConjoinWorkspaces [label="ConjoinWorkspaces v1"] + MatchSpectra [label="MatchSpectra v1"] + CropWorkspaceRagged [label="CropWorkspaceRagged v1"] + Rebin2 [label="Rebin v1"] + SumSpectra [label="SumSpectra v1"] +} + 
+subgraph process { + $process_style + findLargestRange [label="Find detector data with the largest X range"] +} + +inputWorkspace -> findLargestRange +findLargestRange -> Rebin1 +findLargestRange -> matchSpectraIndex +Rebin1 -> isGroupWsNotConjoined +isGroupWsNotConjoined -> ConjoinWorkspaces [label="Yes"] +isGroupWsNotConjoined -> MatchSpectra [label="No"] +ConjoinWorkspaces -> isGroupWsNotConjoined +matchSpectraIndex -> MatchSpectra +calculateScale -> MatchSpectra +calculateOffset -> MatchSpectra +MatchSpectra -> CropWorkspaceRagged +xMin -> CropWorkspaceRagged +xMax -> CropWorkspaceRagged +CropWorkspaceRagged -> Rebin2 +Rebin2 -> SumSpectra +SumSpectra -> outputWorkspace +} \ No newline at end of file diff --git a/docs/source/diagrams/TotScatCalculateSelfScattering-v1_wkflw.dot b/docs/source/diagrams/TotScatCalculateSelfScattering-v1_wkflw.dot new file mode 100644 index 0000000000000000000000000000000000000000..ab7309a24588ea3a8e79fc6eb49aa205f8f1970f --- /dev/null +++ b/docs/source/diagrams/TotScatCalculateSelfScattering-v1_wkflw.dot @@ -0,0 +1,54 @@ +digraph TotScatCalculateSelfScattering { +label = "TotScatCalculateSelfScattering Workflow Diagram" +$global_style + +subgraph params { + $param_style + inputWorkspace [label="InputWorkspace"] + outputWorkspace [label="OutputWorkspace"] + calFileName [label="CalFileName"] + sampleGeometry [label="SampleGeometry"] + sampleMaterial [label="SampleMaterial"] +} + +subgraph algorithms { + $algorithm_style + SetSample [label="SetSample v1"] + ExtractSpectra [label="ExtractSpectra v1"] + ConvertUnits1 [label="ConvertUnits v1"] + FitIncidentSpectrum [label="FitIncidentSpectrum v1"] + CalculatePlaczekSelfScattering [label="CalculatePlaczekSelfScattering v1"] + LoadCalFile [label="LoadCalFile v1"] + DiffractionFocussing [label="DiffractionFocussing v2"] + CreateWorkspace [label="CreateWorkspace v1"] + Divide [label="Divide v1"] + ConvertToDistribution [label="ConvertToDistribution v1"] + ConvertUnits2 [label="ConvertUnits 
v1"] +} + +subgraph process { + $process_style + FindMonitorSpectra [label="Find the Monitor\n spectra closest to the sample"] + GetPixelNumberInDetector [label="Count the number\n of pixels in each detector"] +} + +inputWorkspace -> SetSample +sampleGeometry -> SetSample +sampleMaterial -> SetSample +SetSample -> FindMonitorSpectra +SetSample -> ExtractSpectra +FindMonitorSpectra -> ExtractSpectra +ExtractSpectra -> ConvertUnits1 +ConvertUnits1 -> FitIncidentSpectrum +FitIncidentSpectrum -> CalculatePlaczekSelfScattering +calFileName -> LoadCalFile +LoadCalFile -> DiffractionFocussing +CalculatePlaczekSelfScattering -> DiffractionFocussing +LoadCalFile -> GetPixelNumberInDetector +GetPixelNumberInDetector -> CreateWorkspace +CreateWorkspace -> Divide +DiffractionFocussing -> Divide +Divide -> ConvertToDistribution +ConvertToDistribution -> ConvertUnits2 +ConvertUnits2 -> outputWorkspace +} \ No newline at end of file diff --git a/docs/source/images/ColorMapCustomPlot.PNG b/docs/source/images/ColorMapCustomPlot.PNG new file mode 100644 index 0000000000000000000000000000000000000000..5849c625cacddd5a6bf361e03d4b5483f63c4f17 Binary files /dev/null and b/docs/source/images/ColorMapCustomPlot.PNG differ diff --git a/docs/source/images/ConvertHFIRSCDtoMDE.png b/docs/source/images/ConvertHFIRSCDtoMDE.png new file mode 100644 index 0000000000000000000000000000000000000000..45ef0bd164d921fda9a7fae9be41a5d7ba5b5f8e Binary files /dev/null and b/docs/source/images/ConvertHFIRSCDtoMDE.png differ diff --git a/docs/source/images/MatchAndMergeWorkspaces.png b/docs/source/images/MatchAndMergeWorkspaces.png new file mode 100644 index 0000000000000000000000000000000000000000..ebd6668fe52e985dac54d5534355fe3b2920f28c Binary files /dev/null and b/docs/source/images/MatchAndMergeWorkspaces.png differ diff --git a/docs/source/interfaces/Engineering Diffraction 2.rst b/docs/source/interfaces/Engineering Diffraction 2.rst index 
d31def2177d22f0e47a3ef69012d948327595110..c2c7bf1d40ea2ee7fe43f9d43d47c121220a92c9 100644 --- a/docs/source/interfaces/Engineering Diffraction 2.rst +++ b/docs/source/interfaces/Engineering Diffraction 2.rst @@ -72,4 +72,33 @@ Calibration Sample Number experiment runs. Path - The path to the GSAS parameter file to be loaded. \ No newline at end of file + The path to the GSAS parameter file to be loaded. + +Focus +----- + +This tab allows for the focusing of data files, by providing a run number or selecting the files +manually using the browse button, by making use of the :ref:`EnggFocus<algm-EnggFocus>` algorithm. + +In order to use the tab, a new or existing calibration must be created or loaded. + +Currently, the focusing tab only supports one focusing mode: + +- **Normal Focusing:** + The user is able to select which banks they want to focus, and all the spectra from those banks will be considered. + The output workspaces will have a suffix denoting which bank they are for. + +The focused data files are saved in NeXus format to: + +`Engineering_Mantid/Focus/` + +If an RB number has been specified the files will also be saved to a user directory +in the base directory: + +`Engineering_Mantid/User/RBNumber/Focus/` + +Parameters +^^^^^^^^^^ + +Sample Run Number + The run number or file path to the data file to be focused. diff --git a/docs/source/plotting/index.rst b/docs/source/plotting/index.rst index b310eea03403d5762b806a321cb057e687b33750..d9b6f339e831bdaf5e281f7be6844ad2f71651b2 100644 --- a/docs/source/plotting/index.rst +++ b/docs/source/plotting/index.rst @@ -1,11 +1,14 @@ .. _plotting: +==================== +Matplotlib in Mantid +==================== + .. contents:: Table of contents :local: -==================================== -Introduction to Matplotlib in Mantid -==================================== +Introduction +------------ Mantid can now use `Matplotlib <https://matplotlib.org/>`_ to produce figures. 
There are several advantages of using this software package: @@ -173,9 +176,8 @@ Here are some of the highlights: One can have multiple Axes objects in one Figure * **Axis** is the container for the ticks and labels for the x and y axis of the plot -====================== Showing/saving figures -====================== +---------------------- There are two main ways that one can visualize images produced by matplotlib. The first one is to pop up a window with the required graph. For that, we use the `show()` function of the figure. @@ -221,9 +223,9 @@ Sometimes one wants to save a multi-page pdf document. Here is how to do this: pdf.savefig(fig) -============ + Simple plots -============ +------------ For matrix workspaces, if we use the `mantid` projection, one can plot the data in a similar fashion as the plotting of arrays in matplotlib. Moreover, one can combine the two in the same figure @@ -402,10 +404,299 @@ One can do twin axes as well: fig.tight_layout() #fig.show() +Custom Colors +------------- + +Custom Color Cycle (Line / 1D plots) +#################################### + +The Default Color Cycle doesn't have to be used. Here is an example where a Custom Color Cycle is chosen. Make sure to fill the list `custom_colors` with either the HTML hex codes (eg. #b3457f) or recognised names for the desired colours. +Both can be found `online <https://www.rapidtables.com/web/color/html-color-codes.html>`_. + +.. 
plot:: + :include-source: + + from __future__ import (absolute_import, division, print_function, unicode_literals) + import matplotlib.pyplot as plt + from mantid import plots + from mantid.simpleapi import * + + ws=Load('GEM40979.raw') + Number = 12 # How many Spectra to Plot + + prop_cycle = plt.rcParams['axes.prop_cycle'] + colors = prop_cycle.by_key()['color'] # 10 colors in default cycle + + '''Change the following two parameters as you wish''' + custom_colors = ['#0000ffff', 'salmon','#00ff00ff'] # I've chosen Blue, Salmon, Green + + fig = plt.figure(figsize = (10,10)) + ax1 = plt.subplot(211,projection='mantid') + for i in range(Number): + ax1.plot(ws, specNum = i+1, color=colors[i%len(colors)]) + ax1.set_title('Default') + ax1.legend() + + ax2 = plt.subplot(212,projection='mantid') + for i in range(Number): + ax2.plot(ws, specNum= i+1, color=custom_colors[i%len(custom_colors)]) + ax2.set_title('Custom') + ax2.legend() + + fig.suptitle('Line Plots: Color Cycle', fontsize='x-large') + #fig.show() + +Custom Colormap (MantidPlot) +############################ + +In MantidPlot, a Custom Colormap (256 entries of Red, Green and Blue values [0-255 for each]) can be created and saved with: + +.. code-block:: python + + from __future__ import (absolute_import, division, print_function, unicode_literals) + from mantid.simpleapi import * + import matplotlib.pyplot as plt + import numpy as np + + r = np.zeros(256) + g = np.zeros(256) + b = np.zeros(256) + for i in range(256): + '''Control how the RGB values change throughout the Colormap''' + r[i] = i #linear increase in Red + g[i] = 255 - i #linear decrease in Green + + f = open("C:\MantidInstall\colormaps\GreenRed.map","w+") #Change the .map filename as you wish! 
+ for i in range(256): + f.write(str(int(r[i]))) + f.write(' ') + f.write(str(int(g[i]))) + f.write(' ') + f.write(str(int(b[i]))) + f.write('\n') + f.close() + +Then open up any dataset (such as EMU00020884.nxs from the `TrainingCourseData <https://sourceforge.net/projects/mantid/files/Sample%20Data/TrainingCourseData.zip/download>`_) and produce a Colorfill plot. Change the Colormap by following `these instructions <https://docs.mantidproject.org/nightly/tutorials/mantid_basic_course/loading_and_displaying_data/04_displaying_2D_data.html#changing-the-colour-map>`_ and selecting the newly created `Greenred.map`. + +.. figure:: ../images/ColorMapCustomPlot.PNG + :class: screenshot + :width: 500px + :align: center + +This New Colormap is saved within the MantidInstall folder so it can be used without re-running this script! + + +Custom Colormap (MantidWorkbench) +################################# + +You can view the premade Colormaps `here <https://matplotlib.org/2.2.3/gallery/color/colormap_reference.html?highlight=colormap>`_. +These Colormaps can be registered and remain for the current session, but need to be rerun if Mantid has been reopened. Choose the location to Save your Colormap file wisely, outside of your MantidInstall folder! + +The following methods show how to Load, Convert from MantidPlot format, Create from Scratch and Visualise a Custom Colormap. + +- If you already have a Colormap file in an (N by 4) format, with all values between 0 and 1, then use: + +*1a. Load Colormap and Register* + +.. 
code-block:: python + + import matplotlib.pyplot as plt + import numpy as np + from matplotlib.colors import ListedColormap, LinearSegmentedColormap + + Cmap_Name = 'Beach' # Colormap name + Loaded_Cmap = np.loadtxt("C:\Path\to\File\Filename.txt") + # Register the Loaded Colormap + Listed_CustomCmap = ListedColormap(Loaded_Cmap, name=Cmap_Name) + plt.register_cmap(name=Cmap_Name, cmap= Listed_CustomCmap) + + # Create and register the reverse colormap + Res = len(Loaded_Cmap) + Reverse = np.zeros((Res,4)) + for i in range(Res): + for j in range(4): + Reverse[i][j] = Loaded_Cmap[Res-(i+1)][j] + + Listed_CustomCmap_r = ListedColormap(Reverse, name=(Cmap_Name + '_r') ) + plt.register_cmap(name=(Cmap_Name + '_r'), cmap= Listed_CustomCmap_r) + +- If you have a Colormap file in a MantidPlot format (N by 3) with all values between 0 and 255, firstly **rename the file extension from .map to .txt**, then use: + +*1b. Convert MantidPlot Colormap and Register* + +.. code-block:: python + + import matplotlib.pyplot as plt + import numpy as np + from matplotlib.colors import ListedColormap, LinearSegmentedColormap + + Cmap_Name = 'Beach' + Loaded_Cmap = np.loadtxt("/Path/to/file/Beach.txt") + + Res = len(Loaded_Cmap) + Cmap = np.zeros((Res,4)) + for i in range(Res): + '''Normalise RGB values, Add 4th column alpha set to 1''' + for j in range(3): + Cmap[i][j] = float(Loaded_Cmap[i][j]) / 255 + Cmap[i][3] = 1 + '''Checks all values b/w 0 and 1''' + for j in range(4): + if Cmap[i][j] > 1: + print Cmap[i] + raise ValueError('Values must be between 0 and 1, one of the above is > 1') + if Cmap[i][j] < 0: + print Cmap[i] + raise ValueError('Values must be between 0 and 1, one of the above is negative') + else: + pass + + #np.savetxt("C:\Path\to\File\Filename.txt",Cmap) #uncomment to save to file + + # Register the Loaded Colormap + Listed_CustomCmap = ListedColormap(Cmap, name=Cmap_Name) + plt.register_cmap(name=Cmap_Name, cmap= Listed_CustomCmap) + + # Create and register the reverse 
colormap + Reverse = np.zeros((Res,4)) + for i in range(Res): + for j in range(4): + Reverse[i][j] = Cmap[Res-(i+1)][j] + + Listed_CustomCmap_r = ListedColormap(Reverse, name=(Cmap_Name + '_r') ) + plt.register_cmap(name=(Cmap_Name + '_r'), cmap= Listed_CustomCmap_r) + +- To Create a Colormap from scratch, use: + +*1c. Create and Register* + +.. code-block:: python + + import matplotlib.pyplot as plt + from matplotlib.colors import ListedColormap, LinearSegmentedColormap + import numpy as np + + Cmap_Name = 'Beach' # Colormap name + Res = 500 # Resolution of your Colormap (number of steps in colormap) + + Re = Res-1 + Cmap = np.zeros((Res,4)) + for i in range(Res): + '''Input functions inside float(), Divide by Res to normalise''' + Cmap[i][0] = float(Res) / Res #Red #just 1 + Cmap[i][1] = float(i) / Re #Green #+ve i divisible by Res-1 = Re + Cmap[i][2] = float(Res-i)**2 / Res**2 #Blue #Make sure Norm_factor correct + Cmap[i][3] = 1 + '''Checks all values b/w 0 and 1''' + for j in range(4): + if Cmap[i][j] > 1: + print Cmap[i] + raise ValueError('Values must be between 0 and 1, one of the above is > 1') + if Cmap[i][j] < 0: + print Cmap[i] + raise ValueError('Values must be between 0 and 1, one of the above is Negative') + else: + pass + + #np.savetxt("C:\Path\to\File\Filename.txt",Cmap) #uncomment to save to file + + Listed_CustomCmap = ListedColormap(Cmap, name = Cmap_Name) + plt.register_cmap(name = Cmap_Name, cmap = Listed_CustomCmap) + + # Create and register the reverse colormap + Reverse = np.zeros((Res,4)) + for i in range(Res): + for j in range(4): + Reverse[i][j] = Cmap[Res-(i+1)][j] + + Listed_CustomCmap_r = ListedColormap(Reverse, name=(Cmap_Name + '_r') ) + plt.register_cmap(name=(Cmap_Name + '_r'), cmap= Listed_CustomCmap_r) + +Now the Custom Colormap has been registered, right-click on a workspace and produce a colorfill plot. 
In Figure Options (Gear Icon in Plot Figure), under the Images Tab, you can use the drop down-menu to select the new Colormap, and use the check-box to select its Reverse! + +- Otherwise, use a script like this (from above in Section "Simple Plots") to plot with your new Colormap: + +*2. Plot New Colormap* (change the "cmap" name in line 12 accordingly) + +.. code-block:: python + + from mantid.simpleapi import Load, ConvertToMD, BinMD, ConvertUnits, Rebin + from mantid import plots + import matplotlib.pyplot as plt + from matplotlib.colors import LogNorm + data = Load('CNCS_7860') + data = ConvertUnits(InputWorkspace=data,Target='DeltaE', EMode='Direct', EFixed=3) + data = Rebin(InputWorkspace=data, Params='-3,0.025,3', PreserveEvents=False) + md = ConvertToMD(InputWorkspace=data,QDimensions='|Q|',dEAnalysisMode='Direct') + sqw = BinMD(InputWorkspace=md,AlignedDim0='|Q|,0,3,100',AlignedDim1='DeltaE,-3,3,100') + + fig, ax = plt.subplots(subplot_kw={'projection':'mantid'}) + c = ax.pcolormesh(sqw, cmap='Beach', norm=LogNorm()) + cbar=fig.colorbar(c) + cbar.set_label('Intensity (arb. units)') #add text to colorbar + #fig.show() + +.. 
plot:: + + import matplotlib.pyplot as plt + from matplotlib.colors import ListedColormap, LinearSegmentedColormap + import numpy as np + + Cmap_Name = 'Beach' # Colormap name + Res = 500 # Resolution of your Colormap (number of steps in colormap) + + Re = Res-1 + Cmap = np.zeros((Res,4)) + for i in range(Res): + '''Input functions inside float(), Divide by Res to normalise''' + Cmap[i][0] = float(Res) / Res #Red #just 1 + Cmap[i][1] = float(i) / Re #Green #+ve i divisible by Res-1 = Re + Cmap[i][2] = float(Res-i)**2 / Res**2 #Blue #Make sure Norm_factor correct + Cmap[i][3] = 1 + '''Checks all values b/w 0 and 1''' + for j in range(4): + if Cmap[i][j] > 1: + print Cmap[i] + raise ValueError('Values must be between 0 and 1, one of the above is > 1') + if Cmap[i][j] < 0: + print Cmap[i] + raise ValueError('Values must be between 0 and 1, one of the above is Negative') + else: + pass + + #np.savetxt("C:\Path\to\File\Filename.txt",Cmap) #uncomment to save to file + + Listed_CustomCmap = ListedColormap(Cmap, name = Cmap_Name) + plt.register_cmap(name = Cmap_Name, cmap = Listed_CustomCmap) + + # Create and register the reverse colormap + Reverse = np.zeros((Res,4)) + for i in range(Res): + for j in range(4): + Reverse[i][j] = Cmap[Res-(i+1)][j] + + Listed_CustomCmap_r = ListedColormap(Reverse, name=(Cmap_Name + '_r') ) + plt.register_cmap(name=(Cmap_Name + '_r'), cmap= Listed_CustomCmap_r) + + from mantid.simpleapi import Load, ConvertToMD, BinMD, ConvertUnits, Rebin + from mantid import plots + from matplotlib.colors import LogNorm + data = Load('CNCS_7860') + data = ConvertUnits(InputWorkspace=data,Target='DeltaE', EMode='Direct', EFixed=3) + data = Rebin(InputWorkspace=data, Params='-3,0.025,3', PreserveEvents=False) + md = ConvertToMD(InputWorkspace=data,QDimensions='|Q|',dEAnalysisMode='Direct') + sqw = BinMD(InputWorkspace=md,AlignedDim0='|Q|,0,3,100',AlignedDim1='DeltaE,-3,3,100') + + fig, ax = plt.subplots(subplot_kw={'projection':'mantid'}) + c = 
ax.pcolormesh(sqw, cmap='Beach', norm=LogNorm()) + cbar=fig.colorbar(c) + cbar.set_label('Intensity (arb. units)') #add text to colorbar + #fig.show() + +Colormaps can also be created with the `colormap package <https://colormap.readthedocs.io/en/latest/>`_ or by `concatenating existing colormaps <https://matplotlib.org/3.1.0/tutorials/colors/colormap-manipulation.html>`_. -==================== Plotting Sample Logs -==================== +-------------------- The :func:`mantid.plots.MantidAxes.plot<mantid.plots.MantidAxes.plot>` function can show sample logs. By default, the time axis represents the time since the first proton charge pulse (the @@ -446,10 +737,8 @@ So one needs to use :func:`mantid.plots.plotfunctions.plot<mantid.plots.plotfunc plots.plotfunctions.plot(axt,w,LogName='ChopperStatus5', FullTime=True) #fig.show() - -============= Complex plots -============= +------------- One common type of a slightly more complex figure involves drawing an inset. @@ -599,9 +888,8 @@ Plotting dispersion curves on multiple panels can also be done using matplotlib .. _mplDefaults: -========================== Change Matplotlib Defaults -========================== +-------------------------- It is possible to alter the default appearance of Matplotlib plots, e.g. linewidths, label sizes, colour cycles etc. This is most readily achieved by setting the ``rcParams`` at the start of a diff --git a/docs/source/release/v4.3.0/diffraction.rst b/docs/source/release/v4.3.0/diffraction.rst index b703b345cc620c351baa7cd8c8a9e4c8bb63b4df..04f25c089a719ede44dcebb04b9673ca7e98df45 100644 --- a/docs/source/release/v4.3.0/diffraction.rst +++ b/docs/source/release/v4.3.0/diffraction.rst @@ -18,6 +18,7 @@ Powder Diffraction - The create_total_scattering_pdf merging banks now matches spectra to the spectrum with the largest x range. - The create_total_scattering_pdf merging banks no longer matches spectra with scale, it now only matches with offset. 
+- :ref:`HRPDSlabCanAbsorption <algm-HRPDSlabCanAbsorption-v1>` now accepts any thickness parameter and not those in a specified list. Engineering Diffraction ----------------------- @@ -25,6 +26,9 @@ Engineering Diffraction Single Crystal Diffraction -------------------------- +- :ref:`PredictFractionalPeaks <algm-PredictFractionalPeaks-v1>` now accepts the same set of modulation vector properties as :ref:`IndexPeaks <algm-IndexPeaks-v1>`. +- New algorithm :ref:`ConvertHFIRSCDtoMDE <algm-ConvertHFIRSCDtoMDE-v1>` for converting HFIR single crystal data (from WAND and DEMAND) into MDEventWorkspace in units Q_sample. + Imaging ------- @@ -39,6 +43,8 @@ Powder Diffraction Engineering Diffraction ----------------------- +- Fixed a bug where `SaveGSS <algm-SaveGSS-v1>` could crash when attempting to pass a group workspace into it. + Single Crystal Diffraction -------------------------- diff --git a/docs/source/release/v4.3.0/direct_geometry.rst b/docs/source/release/v4.3.0/direct_geometry.rst index 2a4a9b09cfe3d55ddd8decd3430d8215539703de..2250c429fca2c8f8c403172999cd002a925b9a11 100644 --- a/docs/source/release/v4.3.0/direct_geometry.rst +++ b/docs/source/release/v4.3.0/direct_geometry.rst @@ -9,4 +9,6 @@ Direct Geometry Changes putting new features at the top of the section, followed by improvements, followed by bug fixes. -:ref:`Release 4.3.0 <v4.3.0>` \ No newline at end of file +* New ``NOW4`` instrument definition for SNS + +:ref:`Release 4.3.0 <v4.3.0>` diff --git a/docs/source/release/v4.3.0/framework.rst b/docs/source/release/v4.3.0/framework.rst index ac8d1bd62ac5af649b67c3b123ad97cc81c20446..9df25b4238994512a2a77704cc902667db2f9565 100644 --- a/docs/source/release/v4.3.0/framework.rst +++ b/docs/source/release/v4.3.0/framework.rst @@ -9,6 +9,9 @@ Framework Changes putting new features at the top of the section, followed by improvements, followed by bug fixes. 
+New Features +############ + Concepts -------- @@ -16,13 +19,29 @@ Improvements ############ - Fixed a bug in :ref:`LoadNGEM <algm-LoadNGEM>` where precision was lost due to integer arithmetic. +- Prevent units that are not suitable for :ref:`ConvertUnits <algm-ConvertUnits>` being entered as the target unit. Algorithms ---------- +- :ref:`TotScatCalculateSelfScattering <algm-TotScatCalculateSelfScattering>` will calculate a normalized self scattering correction for foccues total scattering data. +- :ref:`MatchAndMergeWorkspaces <algm-MatchAndMergeWorkspaces>` will merge workspaces in a workspace group withing weighting from a set of limits for each workspace and using `MatchSpectra <algm-MatchSpectra>`. + Data Objects ------------ + + +Geometry +-------- + +Improvements +############ + +- Increased numerical accuracy when calculating the bounding box of mili-meter sized cylindrical detector pixels. + + + Python ------ diff --git a/docs/source/release/v4.3.0/indirect_geometry.rst b/docs/source/release/v4.3.0/indirect_geometry.rst index 1e7455c3e72a4378db5959788c5dec1d3f50cd0d..a452189d03aa202c8f3841b411727de841efafa4 100644 --- a/docs/source/release/v4.3.0/indirect_geometry.rst +++ b/docs/source/release/v4.3.0/indirect_geometry.rst @@ -14,5 +14,6 @@ BugFixes ######## - The Abins file parser no longer fails to read data from non-periodic vibration calculations performed with CRYSTAL17. +- :ref:`ApplyPaalmanPingsCorrection <algm-ApplyPaalmanPingsCorrection>` will now run also for fixed window scan reduced data and will not crash on workspace groups. 
:ref:`Release 4.3.0 <v4.3.0>` diff --git a/docs/source/release/v4.3.0/mantidworkbench.rst b/docs/source/release/v4.3.0/mantidworkbench.rst index 842aaa12c2dff54bd22271a577c0f6f21a722e76..e786e07127bcdf470bc6eea1986328a365dd0936 100644 --- a/docs/source/release/v4.3.0/mantidworkbench.rst +++ b/docs/source/release/v4.3.0/mantidworkbench.rst @@ -20,13 +20,16 @@ Improvements :width: 500px :align: left -- Normalization option have been added to 2d plots. +- Normalization options have been added to 2d plots and sliceviewer. +- The images tab in figure options no longer forces the max value to be greater than the min value. Bugfixes ######## +- Colorbar scale no longer vanish on colorfill plots with a logarithmic scale - Figure options no longer causes a crash when using 2d plots created from a script. +- Running an algorithm that reduces the number of spectra on an active plot (eg SumSpectra) no longer causes an error - Fix crash when loading a script with syntax errors -:ref:`Release 4.3.0 <v4.3.0>` \ No newline at end of file +:ref:`Release 4.3.0 <v4.3.0>` diff --git a/instrument/Facilities.xml b/instrument/Facilities.xml index bb12995ac2ebe0ec0d622f328d6b902355ef5ba9..02d0256f645e89714bb49f2e2695b28f93ce015c 100644 --- a/instrument/Facilities.xml +++ b/instrument/Facilities.xml @@ -567,6 +567,11 @@ </livedata> </instrument> + <instrument name="NOW4" shortname="NOW4" beamline="14Q"> + <technique>Neutron Spectroscopy</technique> + <technique>TOF Direct Geometry Spectroscopy</technique> + </instrument> + <instrument name="VISION" shortname="VIS" beamline="16B"> <technique>Neutron Spectroscopy</technique> <technique>TOF Indirect Geometry Spectroscopy</technique> @@ -902,7 +907,7 @@ <!-- HZB --> <facility name="HZB" FileExtensions=".nxs"> <timezone>Europe/Berlin</timezone> - + <instrument name="TEST" shortname="TEST"> <zeropadding size="8" /> <technique>ESS Test Beamline</technique> @@ -946,14 +951,14 @@ <!-- Test Facility to allow example usage of Live listeners against 
"Fake" instrument sources --> <facility name="TEST_LIVE" FileExtensions=".nxs,.raw"> <timezone>UTC</timezone> - + <instrument name="LOKI"> <technique>SANS Test</technique> <livedata> <connection name="event" address="hinata:9092" listener="KafkaEventListener" /> </livedata> </instrument> - + <instrument name="ISIS_Histogram"> <technique>Test Listener</technique> <livedata> diff --git a/instrument/NOW4_Definition.xml b/instrument/NOW4_Definition.xml new file mode 100644 index 0000000000000000000000000000000000000000..f5606937349a1241cf47f30ac5fcc97ca6085f50 --- /dev/null +++ b/instrument/NOW4_Definition.xml @@ -0,0 +1,351 @@ +<?xml version='1.0' encoding='ASCII'?> +<instrument xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.mantidproject.org/IDF/1.0" last-modified="2019-06-25 14:11:19.918208" name="NOW4" valid-from="2019-06-02 00:00:00" valid-to="2100-01-31 23:59:59" xsi:schemaLocation="http://www.mantidproject.org/IDF/1.0 http://schema.mantidproject.org/IDF/1.0/IDFSchema.xsd"> + <!--Created by Andrei Savici--> + <defaults> + <length unit="metre"/> + <angle unit="degree"/> + <reference-frame> + <along-beam axis="z"/> + <pointing-up axis="y"/> + <handedness val="right"/> + <theta-sign axis="x"/> + </reference-frame> + </defaults> + <!--SOURCE AND SAMPLE POSITION--> + <component type="moderator"> + <location z="-36.262"/> + </component> + <type is="Source" name="moderator"/> + <component type="sample-position"> + <location x="0.0" y="0.0" z="0.0"/> + </component> + <type is="SamplePos" name="sample-position"/> + <component idlist="detectors" type="detectors"> + <location/> + </component> + <type name="detectors"> + <component type="bank1"> + <location/> + </component> + </type> + <type name="bank1"> + <component type="eightpack"> + <location x="2.64623477185" y="-0.00904297322071" z="-2.3031829704"> + <rot axis-x="0" axis-y="1" axis-z="0" val="311.035082035"/> + </location> + </component> + </type> + <!--STANDARD 8-PACK--> + <type 
name="eightpack"> + <properties/> + <component type="tube"> + <location name="tube1" x="-0.096012"/> + <location name="tube2" x="-0.06858"/> + <location name="tube3" x="-0.041148"/> + <location name="tube4" x="-0.013716"/> + <location name="tube5" x="0.013716"/> + <location name="tube6" x="0.041148"/> + <location name="tube7" x="0.06858"/> + <location name="tube8" x="0.096012"/> + </component> + </type> + <!--STANDARD 2m 128 PIXEL TUBE--> + <type name="tube" outline="yes"> + <properties/> + <component type="pixel"> + <location name="pixel1" y="-0.2"/> + <location name="pixel2" y="-0.1984375"/> + <location name="pixel3" y="-0.196875"/> + <location name="pixel4" y="-0.1953125"/> + <location name="pixel5" y="-0.19375"/> + <location name="pixel6" y="-0.1921875"/> + <location name="pixel7" y="-0.190625"/> + <location name="pixel8" y="-0.1890625"/> + <location name="pixel9" y="-0.1875"/> + <location name="pixel10" y="-0.1859375"/> + <location name="pixel11" y="-0.184375"/> + <location name="pixel12" y="-0.1828125"/> + <location name="pixel13" y="-0.18125"/> + <location name="pixel14" y="-0.1796875"/> + <location name="pixel15" y="-0.178125"/> + <location name="pixel16" y="-0.1765625"/> + <location name="pixel17" y="-0.175"/> + <location name="pixel18" y="-0.1734375"/> + <location name="pixel19" y="-0.171875"/> + <location name="pixel20" y="-0.1703125"/> + <location name="pixel21" y="-0.16875"/> + <location name="pixel22" y="-0.1671875"/> + <location name="pixel23" y="-0.165625"/> + <location name="pixel24" y="-0.1640625"/> + <location name="pixel25" y="-0.1625"/> + <location name="pixel26" y="-0.1609375"/> + <location name="pixel27" y="-0.159375"/> + <location name="pixel28" y="-0.1578125"/> + <location name="pixel29" y="-0.15625"/> + <location name="pixel30" y="-0.1546875"/> + <location name="pixel31" y="-0.153125"/> + <location name="pixel32" y="-0.1515625"/> + <location name="pixel33" y="-0.15"/> + <location name="pixel34" y="-0.1484375"/> + <location name="pixel35" 
y="-0.146875"/> + <location name="pixel36" y="-0.1453125"/> + <location name="pixel37" y="-0.14375"/> + <location name="pixel38" y="-0.1421875"/> + <location name="pixel39" y="-0.140625"/> + <location name="pixel40" y="-0.1390625"/> + <location name="pixel41" y="-0.1375"/> + <location name="pixel42" y="-0.1359375"/> + <location name="pixel43" y="-0.134375"/> + <location name="pixel44" y="-0.1328125"/> + <location name="pixel45" y="-0.13125"/> + <location name="pixel46" y="-0.1296875"/> + <location name="pixel47" y="-0.128125"/> + <location name="pixel48" y="-0.1265625"/> + <location name="pixel49" y="-0.125"/> + <location name="pixel50" y="-0.1234375"/> + <location name="pixel51" y="-0.121875"/> + <location name="pixel52" y="-0.1203125"/> + <location name="pixel53" y="-0.11875"/> + <location name="pixel54" y="-0.1171875"/> + <location name="pixel55" y="-0.115625"/> + <location name="pixel56" y="-0.1140625"/> + <location name="pixel57" y="-0.1125"/> + <location name="pixel58" y="-0.1109375"/> + <location name="pixel59" y="-0.109375"/> + <location name="pixel60" y="-0.1078125"/> + <location name="pixel61" y="-0.10625"/> + <location name="pixel62" y="-0.1046875"/> + <location name="pixel63" y="-0.103125"/> + <location name="pixel64" y="-0.1015625"/> + <location name="pixel65" y="-0.1"/> + <location name="pixel66" y="-0.0984375"/> + <location name="pixel67" y="-0.096875"/> + <location name="pixel68" y="-0.0953125"/> + <location name="pixel69" y="-0.09375"/> + <location name="pixel70" y="-0.0921875"/> + <location name="pixel71" y="-0.090625"/> + <location name="pixel72" y="-0.0890625"/> + <location name="pixel73" y="-0.0875"/> + <location name="pixel74" y="-0.0859375"/> + <location name="pixel75" y="-0.084375"/> + <location name="pixel76" y="-0.0828125"/> + <location name="pixel77" y="-0.08125"/> + <location name="pixel78" y="-0.0796875"/> + <location name="pixel79" y="-0.078125"/> + <location name="pixel80" y="-0.0765625"/> + <location name="pixel81" y="-0.075"/> + 
<location name="pixel82" y="-0.0734375"/> + <location name="pixel83" y="-0.071875"/> + <location name="pixel84" y="-0.0703125"/> + <location name="pixel85" y="-0.06875"/> + <location name="pixel86" y="-0.0671875"/> + <location name="pixel87" y="-0.065625"/> + <location name="pixel88" y="-0.0640625"/> + <location name="pixel89" y="-0.0625"/> + <location name="pixel90" y="-0.0609375"/> + <location name="pixel91" y="-0.059375"/> + <location name="pixel92" y="-0.0578125"/> + <location name="pixel93" y="-0.05625"/> + <location name="pixel94" y="-0.0546875"/> + <location name="pixel95" y="-0.053125"/> + <location name="pixel96" y="-0.0515625"/> + <location name="pixel97" y="-0.05"/> + <location name="pixel98" y="-0.0484375"/> + <location name="pixel99" y="-0.046875"/> + <location name="pixel100" y="-0.0453125"/> + <location name="pixel101" y="-0.04375"/> + <location name="pixel102" y="-0.0421875"/> + <location name="pixel103" y="-0.040625"/> + <location name="pixel104" y="-0.0390625"/> + <location name="pixel105" y="-0.0375"/> + <location name="pixel106" y="-0.0359375"/> + <location name="pixel107" y="-0.034375"/> + <location name="pixel108" y="-0.0328125"/> + <location name="pixel109" y="-0.03125"/> + <location name="pixel110" y="-0.0296875"/> + <location name="pixel111" y="-0.028125"/> + <location name="pixel112" y="-0.0265625"/> + <location name="pixel113" y="-0.025"/> + <location name="pixel114" y="-0.0234375"/> + <location name="pixel115" y="-0.021875"/> + <location name="pixel116" y="-0.0203125"/> + <location name="pixel117" y="-0.01875"/> + <location name="pixel118" y="-0.0171875"/> + <location name="pixel119" y="-0.015625"/> + <location name="pixel120" y="-0.0140625"/> + <location name="pixel121" y="-0.0125"/> + <location name="pixel122" y="-0.0109375"/> + <location name="pixel123" y="-0.009375"/> + <location name="pixel124" y="-0.0078125"/> + <location name="pixel125" y="-0.00625"/> + <location name="pixel126" y="-0.0046875"/> + <location name="pixel127" 
y="-0.003125"/> + <location name="pixel128" y="-0.0015625"/> + <location name="pixel129" y="0.0"/> + <location name="pixel130" y="0.0015625"/> + <location name="pixel131" y="0.003125"/> + <location name="pixel132" y="0.0046875"/> + <location name="pixel133" y="0.00625"/> + <location name="pixel134" y="0.0078125"/> + <location name="pixel135" y="0.009375"/> + <location name="pixel136" y="0.0109375"/> + <location name="pixel137" y="0.0125"/> + <location name="pixel138" y="0.0140625"/> + <location name="pixel139" y="0.015625"/> + <location name="pixel140" y="0.0171875"/> + <location name="pixel141" y="0.01875"/> + <location name="pixel142" y="0.0203125"/> + <location name="pixel143" y="0.021875"/> + <location name="pixel144" y="0.0234375"/> + <location name="pixel145" y="0.025"/> + <location name="pixel146" y="0.0265625"/> + <location name="pixel147" y="0.028125"/> + <location name="pixel148" y="0.0296875"/> + <location name="pixel149" y="0.03125"/> + <location name="pixel150" y="0.0328125"/> + <location name="pixel151" y="0.034375"/> + <location name="pixel152" y="0.0359375"/> + <location name="pixel153" y="0.0375"/> + <location name="pixel154" y="0.0390625"/> + <location name="pixel155" y="0.040625"/> + <location name="pixel156" y="0.0421875"/> + <location name="pixel157" y="0.04375"/> + <location name="pixel158" y="0.0453125"/> + <location name="pixel159" y="0.046875"/> + <location name="pixel160" y="0.0484375"/> + <location name="pixel161" y="0.05"/> + <location name="pixel162" y="0.0515625"/> + <location name="pixel163" y="0.053125"/> + <location name="pixel164" y="0.0546875"/> + <location name="pixel165" y="0.05625"/> + <location name="pixel166" y="0.0578125"/> + <location name="pixel167" y="0.059375"/> + <location name="pixel168" y="0.0609375"/> + <location name="pixel169" y="0.0625"/> + <location name="pixel170" y="0.0640625"/> + <location name="pixel171" y="0.065625"/> + <location name="pixel172" y="0.0671875"/> + <location name="pixel173" y="0.06875"/> + 
<location name="pixel174" y="0.0703125"/> + <location name="pixel175" y="0.071875"/> + <location name="pixel176" y="0.0734375"/> + <location name="pixel177" y="0.075"/> + <location name="pixel178" y="0.0765625"/> + <location name="pixel179" y="0.078125"/> + <location name="pixel180" y="0.0796875"/> + <location name="pixel181" y="0.08125"/> + <location name="pixel182" y="0.0828125"/> + <location name="pixel183" y="0.084375"/> + <location name="pixel184" y="0.0859375"/> + <location name="pixel185" y="0.0875"/> + <location name="pixel186" y="0.0890625"/> + <location name="pixel187" y="0.090625"/> + <location name="pixel188" y="0.0921875"/> + <location name="pixel189" y="0.09375"/> + <location name="pixel190" y="0.0953125"/> + <location name="pixel191" y="0.096875"/> + <location name="pixel192" y="0.0984375"/> + <location name="pixel193" y="0.1"/> + <location name="pixel194" y="0.1015625"/> + <location name="pixel195" y="0.103125"/> + <location name="pixel196" y="0.1046875"/> + <location name="pixel197" y="0.10625"/> + <location name="pixel198" y="0.1078125"/> + <location name="pixel199" y="0.109375"/> + <location name="pixel200" y="0.1109375"/> + <location name="pixel201" y="0.1125"/> + <location name="pixel202" y="0.1140625"/> + <location name="pixel203" y="0.115625"/> + <location name="pixel204" y="0.1171875"/> + <location name="pixel205" y="0.11875"/> + <location name="pixel206" y="0.1203125"/> + <location name="pixel207" y="0.121875"/> + <location name="pixel208" y="0.1234375"/> + <location name="pixel209" y="0.125"/> + <location name="pixel210" y="0.1265625"/> + <location name="pixel211" y="0.128125"/> + <location name="pixel212" y="0.1296875"/> + <location name="pixel213" y="0.13125"/> + <location name="pixel214" y="0.1328125"/> + <location name="pixel215" y="0.134375"/> + <location name="pixel216" y="0.1359375"/> + <location name="pixel217" y="0.1375"/> + <location name="pixel218" y="0.1390625"/> + <location name="pixel219" y="0.140625"/> + <location 
name="pixel220" y="0.1421875"/> + <location name="pixel221" y="0.14375"/> + <location name="pixel222" y="0.1453125"/> + <location name="pixel223" y="0.146875"/> + <location name="pixel224" y="0.1484375"/> + <location name="pixel225" y="0.15"/> + <location name="pixel226" y="0.1515625"/> + <location name="pixel227" y="0.153125"/> + <location name="pixel228" y="0.1546875"/> + <location name="pixel229" y="0.15625"/> + <location name="pixel230" y="0.1578125"/> + <location name="pixel231" y="0.159375"/> + <location name="pixel232" y="0.1609375"/> + <location name="pixel233" y="0.1625"/> + <location name="pixel234" y="0.1640625"/> + <location name="pixel235" y="0.165625"/> + <location name="pixel236" y="0.1671875"/> + <location name="pixel237" y="0.16875"/> + <location name="pixel238" y="0.1703125"/> + <location name="pixel239" y="0.171875"/> + <location name="pixel240" y="0.1734375"/> + <location name="pixel241" y="0.175"/> + <location name="pixel242" y="0.1765625"/> + <location name="pixel243" y="0.178125"/> + <location name="pixel244" y="0.1796875"/> + <location name="pixel245" y="0.18125"/> + <location name="pixel246" y="0.1828125"/> + <location name="pixel247" y="0.184375"/> + <location name="pixel248" y="0.1859375"/> + <location name="pixel249" y="0.1875"/> + <location name="pixel250" y="0.1890625"/> + <location name="pixel251" y="0.190625"/> + <location name="pixel252" y="0.1921875"/> + <location name="pixel253" y="0.19375"/> + <location name="pixel254" y="0.1953125"/> + <location name="pixel255" y="0.196875"/> + <location name="pixel256" y="0.1984375"/> + </component> + </type> + <!--PIXEL FOR STANDARD 2m 128 PIXEL TUBE--> + <type is="detector" name="pixel"> + <cylinder id="cyl-approx"> + <centre-of-bottom-base p="0.0" r="0.0" t="0.0"/> + <axis x="0.0" y="1.0" z="0.0"/> + <radius val="0.006"/> + <height val="0.0015625"/> + </cylinder> + <algebra val="cyl-approx"/> + </type> + <!--MONITOR SHAPE--> + <!--FIXME: Do something real here.--> + <type is="monitor" 
name="monitor"> + <cylinder id="cyl-approx"> + <centre-of-bottom-base p="0.0" r="0.0" t="0.0"/> + <axis x="0.0" y="0.0" z="1.0"/> + <radius val="0.01"/> + <height val="0.03"/> + </cylinder> + <algebra val="cyl-approx"/> + </type> + <!--DETECTOR IDs--> + <idlist idname="detectors"> + <id end="2047" start="0"/> + </idlist> + <!--DETECTOR PARAMETERS--> + <component-link name="detectors"> + <parameter name="tube_pressure"> + <value units="atm" val="6.0"/> + </parameter> + <parameter name="tube_thickness"> + <value units="metre" val="0.0008"/> + </parameter> + <parameter name="tube_temperature"> + <value units="K" val="290.0"/> + </parameter> + </component-link> +</instrument> diff --git a/instrument/NOW4_Parameters.xml b/instrument/NOW4_Parameters.xml new file mode 100644 index 0000000000000000000000000000000000000000..77935f2bfe0b9173a017b941e8ea3a3003de2c27 --- /dev/null +++ b/instrument/NOW4_Parameters.xml @@ -0,0 +1,198 @@ +<?xml version="1.0" encoding="UTF-8"?> +<parameter-file instrument="NOW4" valid-from="2011-10-10T00:00:00"> + + <component-link name="NOW4"> + + <parameter name="deltaE-mode" type="string"> + <value val="direct" /> + </parameter> + + <parameter name="ei-mon1-spec"> + <value val="2" /> + </parameter> + + <parameter name="ei-mon2-spec"> + <value val="3" /> + </parameter> + + <!-- TODO: Update with real vanadium mass --> + <parameter name="vanadium-mass"> + <value val="-1" /> + </parameter> + + <parameter name="bkgd-range-min"> + <value val="30000" /> + </parameter> + + <parameter name="bkgd-range-max"> + <value val="31500" /> + </parameter> + + <parameter name="scale-factor"> + <value val="1.0" /> + </parameter> + + <parameter name="monovan-integr-min"> + <value val="-1" /> + </parameter> + + <parameter name="monovan-integr-max"> + <value val="1" /> + </parameter> + + <!-- Diagnostic test defaults --> + + <!-- Absolute lo threshold for vanadium diag (tiny) --> + <parameter name="diag_tiny"> + <value val="1e-10"/> + </parameter> + + <!-- Absolute 
hi threshold for vanadium diag (huge) --> + <parameter name="diag_huge"> + <value val="1e10"/> + </parameter> + + <!-- Remove zeroes in background (s_zero)--> + <parameter name="diag_samp_zero"> + <value val="0.0"/> + </parameter> + + <!-- Fraction of median to consider counting low for the white beam diag (sv_lo)--> + <parameter name="diag_samp_lo"> + <value val="0.0"/> + </parameter> + + <!-- Fraction of median to consider counting high for the white beam diag (sv_hi)--> + <parameter name="diag_samp_hi"> + <value val="5.0"/> + </parameter> + + <!-- Error criterion as a multiple of error bar for background (sv_sig) --> + <parameter name="diag_samp_sig"> + <value val="3.3"/> + </parameter> + + <!-- Lower bound defining outliers as fraction of median value (v_out_lo)--> + <parameter name="diag_van_out_lo"> + <value val="0.01"/> + </parameter> + + <!-- Upper bound defining outliers as fraction of median value (v_out_hi)--> + <parameter name="diag_van_out_hi"> + <value val="100."/> + </parameter> + + <!-- Fraction of median to consider counting low for the white beam diag (vv_lo)--> + <parameter name="diag_van_lo"> + <value val="0.1"/> + </parameter> + + <!-- Fraction of median to consider counting high for the white beam diag (vv_hi)--> + <parameter name="diag_van_hi"> + <value val="1.5"/> + </parameter> + + <!-- Error criterion as a multiple of error bar for background (vv_sig) --> + <parameter name="diag_van_sig"> + <value val="3.3"/> + </parameter> + + <!-- True if background is to be checked --> + <parameter name="check_background"> + <value val="0.0"/> + </parameter> + + <!-- True if the bleed tests should be run --> + <parameter name="diag_bleed_test"> + <value val="0.0"/> + </parameter> + + <!-- Variation for ratio test with second white beam --> + <parameter name="diag_variation"> + <value val="1.1"/> + </parameter> + + <!-- Absolute units conversion average --> + + <parameter name="monovan_lo_bound"> + <value val="0.01" /> + </parameter> + + <parameter 
name="monovan_hi_bound"> + <value val="100" /> + </parameter> + + <parameter name="monovan_lo_frac"> + <value val="0.8" /> + </parameter> + + <parameter name="monovan_hi_frac"> + <value val="1.2" /> + </parameter> + + <!-- All the following parameters are taken directly from the MARI definition + and are WRONG! They are only here for now to get things working --> + + <parameter name="wb-scale-factor"> + <value val="1.0" /> + </parameter> + + <parameter name="wb-integr-min"> + <value val="0.5" /> + </parameter> + + <parameter name="wb-integr-max"> + <value val="80" /> + </parameter> + + <parameter name="norm-mon1-spec"> + <value val="-3" /> + </parameter> + + <parameter name="norm-mon1-min"> + <value val="1000" /> + </parameter> + + <parameter name="norm-mon1-max"> + <value val="2000" /> + </parameter> + + <parameter name="abs-average-min"> + <value val="1e-10" /> + </parameter> + + <parameter name="abs-average-max"> + <value val="1e10" /> + </parameter> + + <parameter name="abs-median-lbound"> + <value val="0.01" /> + </parameter> + + <parameter name="abs-median-ubound"> + <value val="100" /> + </parameter> + + <parameter name="abs-median-lo-frac"> + <value val="0.8" /> + </parameter> + + <parameter name="abs-median-hi-frac"> + <value val="1.2" /> + </parameter> + + <parameter name="abs-median-signif"> + <value val="3.3" /> + </parameter> + + <!-- formula for t0 calculation. 
See http://muparser.sourceforge.net/mup_features.html#idDef2 for available operators--> + <parameter name="t0_formula" type="string"> + <value val="4.0 + (107.0 / (1.0 + (incidentEnergy / 31.0)^3))" /> + </parameter> + + <parameter name="treat-background-as-events" type="string"> + <value val="yes" /> + </parameter> + + </component-link> +</parameter-file> diff --git a/qt/applications/workbench/workbench/plotting/figuremanager.py b/qt/applications/workbench/workbench/plotting/figuremanager.py index 09cca35cbaaade1d07af4fa142d3351bc9f2da4f..886eb14306c53220d96adeaecc24b711a6522838 100644 --- a/qt/applications/workbench/workbench/plotting/figuremanager.py +++ b/qt/applications/workbench/workbench/plotting/figuremanager.py @@ -122,7 +122,7 @@ class FigureManagerADSObserver(AnalysisDataServiceObserver): continue redraw = redraw | redraw_this if redraw: - self.canvas.draw_idle() + self.canvas.draw() @_catch_exceptions def renameHandle(self, oldName, newName): diff --git a/qt/applications/workbench/workbench/plotting/propertiesdialog.py b/qt/applications/workbench/workbench/plotting/propertiesdialog.py index f06f001022207591a275dfac5b7ddacc5cc86e81..4d2b03e151a2830c9c7acf54c6947a69f7169f93 100644 --- a/qt/applications/workbench/workbench/plotting/propertiesdialog.py +++ b/qt/applications/workbench/workbench/plotting/propertiesdialog.py @@ -12,9 +12,13 @@ from __future__ import (absolute_import, unicode_literals) # std imports # 3rdparty imports +from mantid.kernel import logger from mantidqt.plotting.figuretype import FigureType, figure_type from mantidqt.utils.qt import load_ui from matplotlib.collections import QuadMesh +from matplotlib.colors import LogNorm +from matplotlib.ticker import LogLocator +from numpy import arange from qtpy.QtGui import QDoubleValidator, QIcon from qtpy.QtWidgets import QDialog, QWidget @@ -191,6 +195,14 @@ class ColorbarAxisEditor(AxisEditor): super(ColorbarAxisEditor, self).changes_accepted() cb = self.images[0] cb.set_clim(self.limit_min, 
self.limit_max) + if isinstance(self.images[0].norm, LogNorm): + locator = LogLocator(subs=arange(1, 10)) + if locator.tick_values(vmin=self.limit_min, vmax=self.limit_max).size == 0: + locator = LogLocator() + logger.warning("Minor ticks on colorbar scale cannot be shown " + "as the range between min value and max value is too large") + cb.colorbar.locator = locator + cb.colorbar.update_ticks() def create_model(self): memento = AxisEditorModel() diff --git a/qt/applications/workbench/workbench/projectrecovery/test/test_projectrecovery.py b/qt/applications/workbench/workbench/projectrecovery/test/test_projectrecovery.py index fc1935a8f4d4b86d87672b3ddeb001e73dcd30d3..d6292367fe3da27a0154119e505d46c83cb45ee7 100644 --- a/qt/applications/workbench/workbench/projectrecovery/test/test_projectrecovery.py +++ b/qt/applications/workbench/workbench/projectrecovery/test/test_projectrecovery.py @@ -16,6 +16,7 @@ import sys import tempfile import time import unittest +import datetime from mantid.api import AnalysisDataService as ADS from mantid.kernel import ConfigService @@ -36,6 +37,21 @@ class ProjectRecoveryTest(unittest.TestCase): self.multifileinterpreter = mock.MagicMock() self.pr = ProjectRecovery(self.multifileinterpreter) self.working_directory = tempfile.mkdtemp() + # Make sure there is actually a different modified time on the files + self.firstPath = tempfile.mkdtemp() + self.secondPath = tempfile.mkdtemp() + self.thirdPath = tempfile.mkdtemp() + + # offset the date modified stamps in the past in case future modified dates + # cause any problems + finalFileDateTime = datetime.datetime.fromtimestamp(os.path.getmtime(self.thirdPath)) + + dateoffset = finalFileDateTime - datetime.timedelta(hours=2) + modTime = time.mktime(dateoffset.timetuple()) + os.utime(self.firstPath, (modTime, modTime)) + dateoffset = finalFileDateTime - datetime.timedelta(hours=1) + modTime = time.mktime(dateoffset.timetuple()) + os.utime(self.secondPath, (modTime, modTime)) def 
tearDown(self): ADS.clear() @@ -45,6 +61,15 @@ class ProjectRecoveryTest(unittest.TestCase): if os.path.exists(self.working_directory): shutil.rmtree(self.working_directory) + if os.path.exists(self.firstPath): + shutil.rmtree(self.firstPath) + + if os.path.exists(self.secondPath): + shutil.rmtree(self.secondPath) + + if os.path.exists(self.thirdPath): + shutil.rmtree(self.thirdPath) + def test_constructor_settings_are_set(self): # Test the paths set in the constructor that are generated. self.assertEqual(self.pr.recovery_directory, @@ -101,16 +126,10 @@ class ProjectRecoveryTest(unittest.TestCase): @unittest.skipIf(is_macOS(), "Can be unreliable on macOS and is a test of logic not OS capability") def test_sort_paths_by_last_modified(self): - # Make sure there is actually a different modified time on the files by using sleeps - first = tempfile.mkdtemp() - time.sleep(0.5) - second = tempfile.mkdtemp() - time.sleep(0.5) - third = tempfile.mkdtemp() - paths = [second, third, first] + paths = [self.secondPath, self.thirdPath, self.firstPath] paths = self.pr.sort_by_last_modified(paths) - self.assertListEqual(paths, [first, second, third]) + self.assertListEqual(paths, [self.firstPath, self.secondPath, self.thirdPath]) def test_get_pid_folder_to_be_used_to_load_a_checkpoint_from(self): self.pr._make_process_from_pid = mock.MagicMock() diff --git a/qt/python/mantidqt/project/test/test_project.py b/qt/python/mantidqt/project/test/test_project.py index 1c029e00fd9761247474f685d73a336e08daf3c5..61a6cc60e7b8ec412b4678b49870b85f126f57d8 100644 --- a/qt/python/mantidqt/project/test/test_project.py +++ b/qt/python/mantidqt/project/test/test_project.py @@ -11,6 +11,8 @@ from __future__ import (absolute_import, division, print_function, unicode_liter import os import tempfile import unittest +import shutil +import warnings from qtpy.QtWidgets import QMessageBox @@ -37,6 +39,8 @@ def _raise(exception): @start_qapplication class ProjectTest(unittest.TestCase): + _folders_to_remove 
= set() + def setUp(self): self.fgfm = FakeGlobalFigureManager() self.fgfm.figs = [] @@ -44,6 +48,13 @@ class ProjectTest(unittest.TestCase): def tearDown(self): ADS.clear() + for folder in self._folders_to_remove: + try: + shutil.rmtree(folder) + except OSError as exc: + warnings.warn("Could not remove folder at \"{}\"\n" + "Error message:\n{}".format(folder, exc)) + self._folders_to_remove.clear() def test_save_calls_save_as_when_last_location_is_not_none(self): self.project.save_as = mock.MagicMock() @@ -56,7 +67,9 @@ class ProjectTest(unittest.TestCase): self.assertEqual(self.project.save_as.call_count, 0) def test_save_saves_project_successfully(self): - working_file = os.path.join(tempfile.mkdtemp(), "temp" + ".mtdproj") + temp_file_path = tempfile.mkdtemp() + self._folders_to_remove.add(temp_file_path) + working_file = os.path.join(temp_file_path, "temp" + ".mtdproj") self.project.last_project_location = working_file CreateSampleWorkspace(OutputWorkspace="ws1") self.project._offer_overwriting_gui = mock.MagicMock(return_value=QMessageBox.Yes) @@ -70,7 +83,9 @@ class ProjectTest(unittest.TestCase): self.assertEqual(self.project._offer_overwriting_gui.call_count, 1) def test_save_as_saves_project_successfully(self): - working_file = os.path.join(tempfile.mkdtemp(), "temp" + ".mtdproj") + temp_file_path = tempfile.mkdtemp() + self._folders_to_remove.add(temp_file_path) + working_file = os.path.join(temp_file_path, "temp" + ".mtdproj") working_directory = os.path.dirname(working_file) self.project._save_file_dialog = mock.MagicMock(return_value=working_file) CreateSampleWorkspace(OutputWorkspace="ws1") @@ -86,6 +101,7 @@ class ProjectTest(unittest.TestCase): def test_load_calls_loads_successfully(self): working_directory = tempfile.mkdtemp() + self._folders_to_remove.add(working_directory) return_value_for_load = os.path.join(working_directory, os.path.basename(working_directory) + ".mtdproj") self.project._save_file_dialog = 
mock.MagicMock(return_value=return_value_for_load) CreateSampleWorkspace(OutputWorkspace="ws1") diff --git a/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/images_tab.ui b/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/images_tab.ui index abd1fe71e7e740ee4b27cfd4870e9855618f3944..e88024bbb186179b20ad062ddb4fe3bedd96b328 100644 --- a/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/images_tab.ui +++ b/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/images_tab.ui @@ -353,7 +353,7 @@ </item> </layout> </item> - <item row="2" column="0"> + <item row="3" column="0"> <spacer name="verticalSpacer_3"> <property name="orientation"> <enum>Qt::Vertical</enum> @@ -366,6 +366,28 @@ </property> </spacer> </item> + <item row="2" column="0"> + <widget class="QLabel" name="max_min_value_warning"> + <property name="enabled"> + <bool>true</bool> + </property> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="text"> + <string><html><head/><body><p><span style=" color:#ff0000;">Text</span></p></body></html></string> + </property> + <property name="scaledContents"> + <bool>true</bool> + </property> + <property name="alignment"> + <set>Qt::AlignCenter</set> + </property> + </widget> + </item> </layout> </widget> <tabstops> diff --git a/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/presenter.py b/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/presenter.py index 5c34c27356d122082c040f46c9b9396b4097af1b..549462d5529d809558daf0a17384e7251a337d36 100644 --- a/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/presenter.py +++ b/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/presenter.py @@ -7,7 +7,12 @@ # This file is part of the mantid workbench. 
from __future__ import (absolute_import, unicode_literals) +from numpy import arange +from matplotlib.colors import LogNorm +from matplotlib.ticker import LogLocator + +from mantid.kernel import logger from mantidqt.utils.qt import block_signals from mantidqt.widgets.plotconfigdialog import generate_ax_name, get_images_from_fig from mantidqt.widgets.plotconfigdialog.imagestabwidget import ImageProperties @@ -46,7 +51,28 @@ class ImagesTabWidgetPresenter: if current_axis_images.colorbar: current_axis_images.colorbar.remove() - self.fig.colorbar(image) + locator = None + if SCALES[props.scale] == LogNorm: + locator = LogLocator(subs=arange(1, 10)) + if locator.tick_values(vmin=props.vmin, vmax=props.vmax).size == 0: + locator = LogLocator() + logger.warning("Minor ticks on colorbar scale cannot be shown " + "as the range between min value and max value is too large") + + self.fig.colorbar(image, ticks=locator) + + if props.vmin > props.vmax: + self.view.max_min_value_warning.setVisible(True) + self.view.max_min_value_warning.setText("<html> <head/> <body> <p> <span style=\"color:#ff0000;\">Max " + "value is less than min value so they have been " + "swapped.</span></p></body></html>") + elif props.vmin == props.vmax: + self.view.max_min_value_warning.setVisible(True) + self.view.max_min_value_warning.setText("<html><head/><body><p><span style=\"color:#ff0000;\">Min and max " + "value are the same so they have been " + "adjusted.</span></p></body></html>") + else: + self.view.max_min_value_warning.setVisible(False) def get_selected_image(self): return self.image_names_dict[self.view.get_selected_image_name()] diff --git a/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/view.py b/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/view.py index 5ef11db78d94a408abaff49f94e0e2d7c4ff4de7..d92f22869cb0dc52942d5ff0844c603feea62483 100644 --- a/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/view.py +++ 
b/qt/python/mantidqt/widgets/plotconfigdialog/imagestabwidget/view.py @@ -47,19 +47,7 @@ class ImagesTabWidgetView(QWidget): spin_box = getattr(self, '%s_value_spin_box' % bound) spin_box.setRange(0, np.finfo(np.float32).max) - # Make sure min scale value always less than max - self.min_value_spin_box.valueChanged.connect(self._check_max_min_consistency_min_changed) - self.max_value_spin_box.valueChanged.connect(self._check_max_min_consistency_max_changed) - - def _check_max_min_consistency_min_changed(self): - """Check min value smaller than max value after min_value changed""" - if self.get_min_value() >= self.get_max_value(): - self.set_max_value(self.get_min_value() + 0.01) - - def _check_max_min_consistency_max_changed(self): - """Check min value smaller than max value after max value changed""" - if self.get_min_value() >= self.get_max_value(): - self.set_min_value(self.get_max_value() - 0.01) + self.max_min_value_warning.setVisible(False) def _populate_colormap_combo_box(self): for cmap_name in get_colormap_names(): diff --git a/qt/python/mantidqt/widgets/sliceviewer/model.py b/qt/python/mantidqt/widgets/sliceviewer/model.py index 7b3587a61e8308074be5d1524059500c81abe558..effc709f9aa629d9a2bd4309dfee1285bc879677 100644 --- a/qt/python/mantidqt/widgets/sliceviewer/model.py +++ b/qt/python/mantidqt/widgets/sliceviewer/model.py @@ -109,3 +109,8 @@ class SliceViewerModel(object): return WS_TYPE.MDE else: raise ValueError("Unsupported workspace type") + + def can_normalize_workspace(self): + if self.get_ws_type() == WS_TYPE.MATRIX and not self._get_ws().isDistribution(): + return True + return False diff --git a/qt/python/mantidqt/widgets/sliceviewer/presenter.py b/qt/python/mantidqt/widgets/sliceviewer/presenter.py index 65306e21c55291a545c9d6018e329a4efb682ada..8ab3887b1457c58a29a76b26d449e36e2f690944 100644 --- a/qt/python/mantidqt/widgets/sliceviewer/presenter.py +++ b/qt/python/mantidqt/widgets/sliceviewer/presenter.py @@ -10,6 +10,7 @@ from __future__ 
import (absolute_import, division, print_function) from .model import SliceViewerModel, WS_TYPE from .view import SliceViewerView +import mantid.api class SliceViewer(object): @@ -27,7 +28,13 @@ class SliceViewer(object): self.new_plot = self.new_plot_matrix self.update_plot_data = self.update_plot_data_matrix - self.view = view if view else SliceViewerView(self, self.model.get_dimensions_info(), parent) + self.normalization = mantid.api.MDNormalization.NoNormalization + + self.view = view if view else SliceViewerView(self, self.model.get_dimensions_info(), + self.model.can_normalize_workspace(), parent) + if self.model.can_normalize_workspace(): + self.view.norm_opts.currentTextChanged.connect(self.normalization_changed) + self.view.set_normalization(ws) self.new_plot() @@ -39,7 +46,7 @@ class SliceViewer(object): bin_params=self.view.dimensions.get_bin_params())) def new_plot_matrix(self): - self.view.plot_matrix(self.model.get_ws()) + self.view.plot_matrix(self.model.get_ws(), normalize=self.normalization) def update_plot_data_MDH(self): self.view.update_plot_data(self.model.get_data(self.view.dimensions.get_slicepoint(), self.view.dimensions.transpose)) @@ -56,3 +63,10 @@ class SliceViewer(object): def line_plots(self): self.view.create_axes() self.new_plot() + + def normalization_changed(self, norm_type): + if norm_type == "By bin width": + self.normalization = mantid.api.MDNormalization.VolumeNormalization + else: + self.normalization = mantid.api.MDNormalization.NoNormalization + self.new_plot() diff --git a/qt/python/mantidqt/widgets/sliceviewer/samplingimage.py b/qt/python/mantidqt/widgets/sliceviewer/samplingimage.py index ae4168c6897716e88e222c19a745ed872dc2fd4f..64ef0658a422e31a4dec3cdf4357998da49aa8ee 100644 --- a/qt/python/mantidqt/widgets/sliceviewer/samplingimage.py +++ b/qt/python/mantidqt/widgets/sliceviewer/samplingimage.py @@ -16,6 +16,7 @@ class SamplingImage(mimage.AxesImage): filternorm=1, filterrad=4.0, resample=False, + 
normalize=mantid.api.MDNormalization.NoNormalization, **kwargs): super(SamplingImage, self).__init__( ax, @@ -30,6 +31,7 @@ class SamplingImage(mimage.AxesImage): **kwargs) self.ws = workspace self.transpose = transpose + self.normalization = normalize def _xlim_changed(self, ax): if self._update_extent(): @@ -56,7 +58,7 @@ class SamplingImage(mimage.AxesImage): xy = np.vstack((Y.ravel(),X.ravel())).T else: xy = np.vstack((X.ravel(),Y.ravel())).T - data = self.ws.getSignalAtCoord(xy, mantid.api.MDNormalization.NoNormalization).reshape(X.shape) + data = self.ws.getSignalAtCoord(xy, self.normalization).reshape(X.shape) self.set_data(data) def _update_extent(self): diff --git a/qt/python/mantidqt/widgets/sliceviewer/test/test_sliceviewer_model.py b/qt/python/mantidqt/widgets/sliceviewer/test/test_sliceviewer_model.py index 9d3d1d7fd4cc4919784fe6311075ede56aac92c8..c0b253084ee51f590c252be4db759d919dbb2145 100644 --- a/qt/python/mantidqt/widgets/sliceviewer/test/test_sliceviewer_model.py +++ b/qt/python/mantidqt/widgets/sliceviewer/test/test_sliceviewer_model.py @@ -160,6 +160,28 @@ class SliceViewerModelTest(unittest.TestCase): self.assertEqual(dim_info['units'], 'meV') self.assertEqual(dim_info['type'], 'MATRIX') + def test_matrix_workspace_can_be_normalized_if_not_a_distribution(self): + ws2d = CreateWorkspace(DataX=[10, 20, 30, 10, 20, 30], + DataY=[2, 3, 4, 5], + DataE=[1, 2, 3, 4], + NSpec=2, + Distribution=False, + OutputWorkspace='ws2d') + model = SliceViewerModel(ws2d) + self.assertTrue(model.can_normalize_workspace()) + + def test_matrix_workspace_cannot_be_normalized_if_a_distribution(self): + model = SliceViewerModel(self.ws2d_histo) + self.assertFalse(model.can_normalize_workspace()) + + def test_MD_workspaces_cannot_be_normalized(self): + model = SliceViewerModel(self.ws_MD_3D) + self.assertFalse(model.can_normalize_workspace()) + + def test_MDE_workspaces_cannot_be_normalized(self): + model = SliceViewerModel(self.ws_MDE_3D) + 
self.assertFalse(model.can_normalize_workspace()) + if __name__ == '__main__': unittest.main() diff --git a/qt/python/mantidqt/widgets/sliceviewer/test/test_sliceviewer_presenter.py b/qt/python/mantidqt/widgets/sliceviewer/test/test_sliceviewer_presenter.py index baa182d0e72fdad7eae6c1cc45623771d075d97c..b13c98b0f6d533469e4b80bcba535b270591287d 100644 --- a/qt/python/mantidqt/widgets/sliceviewer/test/test_sliceviewer_presenter.py +++ b/qt/python/mantidqt/widgets/sliceviewer/test/test_sliceviewer_presenter.py @@ -13,6 +13,7 @@ import matplotlib matplotlib.use('Agg') # noqa: E402 import unittest +import mantid.api from mantid.py3compat import mock from mantidqt.widgets.sliceviewer.model import SliceViewerModel, WS_TYPE from mantidqt.widgets.sliceviewer.presenter import SliceViewer @@ -24,6 +25,7 @@ class SliceViewerTest(unittest.TestCase): def setUp(self): self.view = mock.Mock(spec=SliceViewerView) self.view.dimensions = mock.Mock() + self.view.norm_opts = mock.Mock() self.model = mock.Mock(spec=SliceViewerModel) self.model.get_ws = mock.Mock() @@ -108,6 +110,15 @@ class SliceViewerTest(unittest.TestCase): self.assertEqual(self.view.dimensions.get_slicepoint.call_count, 0) self.assertEqual(self.view.plot_matrix.call_count, 1) + def test_normalization_change_set_correct_normalization(self): + self.model.get_ws_type = mock.Mock(return_value=WS_TYPE.MATRIX) + self.view.plot_matrix = mock.Mock() + + presenter = SliceViewer(None, model=self.model, view=self.view) + presenter.normalization_changed("By bin width") + self.view.plot_matrix.assert_called_with(self.model.get_ws(), + normalize=mantid.api.MDNormalization.VolumeNormalization) + if __name__ == '__main__': unittest.main() diff --git a/qt/python/mantidqt/widgets/sliceviewer/view.py b/qt/python/mantidqt/widgets/sliceviewer/view.py index 5cc55c19ace1df7640cadd29d221c7c8cf852123..384b790f6c7c9b40d2c395a332d2d177c81083ed 100644 --- a/qt/python/mantidqt/widgets/sliceviewer/view.py +++ 
b/qt/python/mantidqt/widgets/sliceviewer/view.py @@ -8,21 +8,26 @@ # # from __future__ import (absolute_import, division, print_function) -from qtpy.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout -from qtpy.QtCore import Qt -from mantidqt.MPLwidgets import FigureCanvas -from .toolbar import SliceViewerNavigationToolbar -from matplotlib.figure import Figure from matplotlib import gridspec -from .dimensionwidget import DimensionWidget -from mantidqt.widgets.colorbar.colorbar import ColorbarWidget +from matplotlib.figure import Figure from matplotlib.transforms import Bbox, BboxTransform -import numpy as np + +import mantid.api +from mantid.plots.helperfunctions import get_normalize_by_bin_width +from mantidqt.MPLwidgets import FigureCanvas +from mantidqt.widgets.colorbar.colorbar import ColorbarWidget +from .dimensionwidget import DimensionWidget from .samplingimage import imshow_sampling +from .toolbar import SliceViewerNavigationToolbar + +import numpy as np + +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QComboBox, QLabel, QHBoxLayout, QVBoxLayout, QWidget class SliceViewerView(QWidget): - def __init__(self, presenter, dims_info, parent=None): + def __init__(self, presenter, dims_info, can_normalise, parent=None): super(SliceViewerView, self).__init__(parent) self.presenter = presenter @@ -32,11 +37,27 @@ class SliceViewerView(QWidget): self.setAttribute(Qt.WA_DeleteOnClose, True) self.line_plots = False + self.can_normalise = can_normalise # Dimension widget + self.dimensions_layout = QHBoxLayout() self.dimensions = DimensionWidget(dims_info, parent=self) self.dimensions.dimensionsChanged.connect(self.presenter.new_plot) self.dimensions.valueChanged.connect(self.presenter.update_plot_data) + self.dimensions_layout.addWidget(self.dimensions) + + self.colorbar_layout = QVBoxLayout() + + # normalization options + if can_normalise: + self.norm_layout = QHBoxLayout() + self.norm_label = QLabel("Normalization =") + 
self.norm_layout.addWidget(self.norm_label) + self.norm_opts = QComboBox() + self.norm_opts.addItems(["None", "By bin width"]) + self.norm_opts.setToolTip("Normalization options") + self.norm_layout.addWidget(self.norm_opts) + self.colorbar_layout.addLayout(self.norm_layout) # MPL figure + colorbar self.mpl_layout = QHBoxLayout() @@ -48,9 +69,10 @@ class SliceViewerView(QWidget): self.create_axes() self.mpl_layout.addWidget(self.canvas) self.colorbar = ColorbarWidget(self) + self.colorbar_layout.addWidget(self.colorbar) self.colorbar.colorbarChanged.connect(self.update_data_clim) self.colorbar.colorbarChanged.connect(self.update_line_plot_limits) - self.mpl_layout.addWidget(self.colorbar) + self.mpl_layout.addLayout(self.colorbar_layout) # MPL toolbar self.mpl_toolbar = SliceViewerNavigationToolbar(self.canvas, self) @@ -60,7 +82,7 @@ class SliceViewerView(QWidget): # layout self.layout = QVBoxLayout(self) - self.layout.addWidget(self.dimensions) + self.layout.addLayout(self.dimensions_layout) self.layout.addWidget(self.mpl_toolbar) self.layout.addLayout(self.mpl_layout, stretch=1) @@ -186,6 +208,16 @@ class SliceViewerView(QWidget): if 0 <= j < arr.shape[1]: self.plot_y_line(np.linspace(ymin, ymax, arr.shape[0]), arr[:,j]) + def set_normalization(self, ws, **kwargs): + normalize_by_bin_width, _ = get_normalize_by_bin_width(ws, self.ax, **kwargs) + is_normalized = normalize_by_bin_width or ws.isDistribution() + if is_normalized: + self.presenter.normalization = mantid.api.MDNormalization.VolumeNormalization + self.norm_opts.setCurrentIndex(1) + else: + self.presenter.normalization = mantid.api.MDNormalization.NoNormalization + self.norm_opts.setCurrentIndex(0) + def closeEvent(self, event): self.deleteLater() super(SliceViewerView, self).closeEvent(event) diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/CMakeLists.txt b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/CMakeLists.txt index 
3a55518e6b50857f73d8212c06989f07939ee8f0..161b95953d28f72ffc3e37fdb159405ccdc9b064 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/CMakeLists.txt +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/CMakeLists.txt @@ -7,6 +7,7 @@ set(COMMON_SRC_FILES # Include files aren't required, but this makes them appear in Visual Studio set( COMMON_INC_FILES + IFileHandler.h IMessageHandler.h IPythonRunner.h IPlotter.h diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Decoder.cpp b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Decoder.cpp index 35ec7b3bd5b14b49014e8e0597682bcf962d5503..95da527dc5fe1b726394bcc43635e7955139df3c 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Decoder.cpp +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Decoder.cpp @@ -29,8 +29,10 @@ namespace MantidQt { namespace CustomInterfaces { namespace ISISReflectometry { + BatchPresenter *Decoder::findBatchPresenter(const QtBatchView *gui, - const QtMainWindowView *mwv) { + const IMainWindowView *view) { + auto mwv = dynamic_cast<const QtMainWindowView *>(view); for (auto &ipresenter : mwv->m_presenter->m_batchPresenters) { auto presenter = dynamic_cast<BatchPresenter *>(ipresenter.get()); if (presenter->m_view == gui) { @@ -53,9 +55,8 @@ QWidget *Decoder::decode(const QMap<QString, QVariant> &map, ++batchIndex) { mwv->newBatch(); } - for (auto ii = 0; ii < batches.size(); ++ii) { - decodeBatch(dynamic_cast<QtBatchView *>(mwv->m_batchViews[ii]), mwv, - batches[ii].toMap()); + for (auto batchIndex = 0; batchIndex < batches.size(); ++batchIndex) { + decodeBatch(mwv, batchIndex, batches[batchIndex].toMap()); } return mwv; } @@ -64,13 +65,10 @@ QList<QString> Decoder::tags() { return QList<QString>({QString("ISIS Reflectometry")}); } -void Decoder::decodeBatch(const QtBatchView *gui, const QtMainWindowView *mwv, - const QMap<QString, QVariant> &map, - const BatchPresenter *presenter) { - auto batchPresenter = presenter; - if 
(!batchPresenter) { - batchPresenter = findBatchPresenter(gui, mwv); - } +void Decoder::decodeBatch(const IMainWindowView *mwv, int batchIndex, + const QMap<QString, QVariant> &map) { + auto gui = dynamic_cast<const QtBatchView *>(mwv->batches()[batchIndex]); + auto batchPresenter = findBatchPresenter(gui, mwv); if (!batchPresenter) { throw std::runtime_error( "BatchPresenter could not be found during decode."); @@ -90,14 +88,6 @@ void Decoder::decodeBatch(const QtBatchView *gui, const QtMainWindowView *mwv, map[QString("runsView")].toMap()); } -void Decoder::decodeBatch(const IBatchPresenter *presenter, - const IMainWindowView *mwv, - const QMap<QString, QVariant> &map) { - auto batchPresenter = dynamic_cast<const BatchPresenter *>(presenter); - decodeBatch(dynamic_cast<QtBatchView *>(batchPresenter->m_view), - dynamic_cast<const QtMainWindowView *>(mwv), map, batchPresenter); -} - void Decoder::decodeExperiment(const QtExperimentView *gui, const QMap<QString, QVariant> &map) { gui->m_ui.analysisModeComboBox->setCurrentIndex( diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Decoder.h b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Decoder.h index 4805cb4294ba2f0eed7bb62ad58cd8f2f21cbca0..dbfb0bb8fa56c62db19ed71a637af1eb798b95d4 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Decoder.h +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Decoder.h @@ -9,7 +9,7 @@ #include "../../Common/DllConfig.h" #include "../../Reduction/ReductionOptionsMap.h" -#include "../MainWindow/QtMainWindowView.h" +#include "IDecoder.h" #include "MantidQtWidgets/Common/BaseDecoder.h" #include <QMap> @@ -27,6 +27,7 @@ class ReductionJobs; class ReductionWorkspaces; class Row; class BatchPresenter; +class IMainWindowView; class QtBatchView; class QtExperimentView; class QtInstrumentView; @@ -42,20 +43,18 @@ class RangeInQ; class TransmissionRunPair; class MANTIDQT_ISISREFLECTOMETRY_DLL Decoder - : public MantidQt::API::BaseDecoder { + : 
public MantidQt::API::BaseDecoder, + public IDecoder { public: QWidget *decode(const QMap<QString, QVariant> &map, const std::string &directory) override; QList<QString> tags() override; - void decodeBatch(const QtBatchView *gui, const QtMainWindowView *mwv, - const QMap<QString, QVariant> &map, - const BatchPresenter *presenter = nullptr); - void decodeBatch(const IBatchPresenter *presenter, const IMainWindowView *mwv, - const QMap<QString, QVariant> &map); + void decodeBatch(const IMainWindowView *mwv, int batchIndex, + const QMap<QString, QVariant> &map) override; private: BatchPresenter *findBatchPresenter(const QtBatchView *gui, - const QtMainWindowView *mww); + const IMainWindowView *mww); void decodeExperiment(const QtExperimentView *gui, const QMap<QString, QVariant> &map); void decodePerAngleDefaults(QTableWidget *tab, @@ -97,4 +96,4 @@ private: } // namespace CustomInterfaces } // namespace MantidQt -#endif /* MANTID_ISISREFLECTOMETRY_DECODER_H */ \ No newline at end of file +#endif /* MANTID_ISISREFLECTOMETRY_DECODER_H */ diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Encoder.cpp b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Encoder.cpp index bc049f16d2ddc0481759e60c591b1d2f633038ec..a03f49e72b3ce1b7aa739629fd99905db7f316ea 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Encoder.cpp +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Encoder.cpp @@ -24,8 +24,10 @@ namespace MantidQt { namespace CustomInterfaces { namespace ISISReflectometry { + BatchPresenter *Encoder::findBatchPresenter(const QtBatchView *gui, - const QtMainWindowView *mwv) { + const IMainWindowView *view) { + auto mwv = dynamic_cast<const QtMainWindowView *>(view); for (auto &ipresenter : mwv->m_presenter->m_batchPresenters) { auto presenter = dynamic_cast<BatchPresenter *>(ipresenter.get()); if (presenter->m_view == gui) { @@ -42,9 +44,10 @@ QMap<QString, QVariant> Encoder::encode(const QWidget *gui, QMap<QString, QVariant> map; 
map.insert(QString("tag"), QVariant(QString("ISIS Reflectometry"))); QList<QVariant> batches; - for (const auto &batchView : mwv->m_batchViews) { - batches.append(QVariant( - encodeBatch(dynamic_cast<const QtBatchView *>(batchView), mwv, true))); + for (size_t batchIndex = 0; batchIndex < mwv->batches().size(); + ++batchIndex) { + batches.append( + QVariant(encodeBatch(mwv, static_cast<int>(batchIndex), true))); } map.insert(QString("batches"), QVariant(batches)); return map; @@ -54,14 +57,10 @@ QList<QString> Encoder::tags() { return QList<QString>({QString("ISIS Reflectometry")}); } -QMap<QString, QVariant> Encoder::encodeBatch(const QtBatchView *gui, - const QtMainWindowView *mwv, - bool projectSave, - const BatchPresenter *presenter) { - auto batchPresenter = presenter; - if (!batchPresenter) { - batchPresenter = findBatchPresenter(gui, mwv); - } +QMap<QString, QVariant> Encoder::encodeBatch(const IMainWindowView *mwv, + int batchIndex, bool projectSave) { + auto gui = dynamic_cast<const QtBatchView *>(mwv->batches()[batchIndex]); + auto batchPresenter = findBatchPresenter(gui, mwv); if (!batchPresenter) { throw std::runtime_error( "BatchPresenter could not be found during encode."); @@ -86,15 +85,6 @@ QMap<QString, QVariant> Encoder::encodeBatch(const QtBatchView *gui, return map; } -QMap<QString, QVariant> Encoder::encodeBatch(const IBatchPresenter *presenter, - const IMainWindowView *mwv, - bool projectSave) { - auto batchPresenter = dynamic_cast<const BatchPresenter *>(presenter); - return encodeBatch(dynamic_cast<QtBatchView *>(batchPresenter->m_view), - dynamic_cast<const QtMainWindowView *>(mwv), projectSave, - batchPresenter); -} - QMap<QString, QVariant> Encoder::encodeRuns(const QtRunsView *gui, bool projectSave, const ReductionJobs *redJobs) { @@ -386,4 +376,4 @@ QMap<QString, QVariant> Encoder::encodeSave(const QtSaveView *gui) { } } // namespace ISISReflectometry } // namespace CustomInterfaces -} // namespace MantidQt \ No newline at end of file 
+} // namespace MantidQt diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Encoder.h b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Encoder.h index 0fd9832600e4fe8018e502f9c979cce44a15da44..1be278c10de9bf5109e06f974a8bfc30a6f26406 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Encoder.h +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/Encoder.h @@ -9,7 +9,7 @@ #include "../../Common/DllConfig.h" #include "../../Reduction/ReductionOptionsMap.h" -#include "../MainWindow/QtMainWindowView.h" +#include "IEncoder.h" #include "MantidQtWidgets/Common/BaseEncoder.h" #include <QMap> @@ -26,6 +26,7 @@ class ReductionJobs; class ReductionWorkspaces; class Row; class BatchPresenter; +class IMainWindowView; class QtBatchView; class QtExperimentView; class QtInstrumentView; @@ -41,22 +42,19 @@ class TransmissionRunPair; class QtEventView; class MANTIDQT_ISISREFLECTOMETRY_DLL Encoder - : public MantidQt::API::BaseEncoder { + : public MantidQt::API::BaseEncoder, + public IEncoder { public: QMap<QString, QVariant> encode(const QWidget *window, const std::string &directory) override; QList<QString> tags() override; - QMap<QString, QVariant> - encodeBatch(const QtBatchView *gui, const QtMainWindowView *mwv, - bool projectSave = false, - const BatchPresenter *presenter = nullptr); - QMap<QString, QVariant> encodeBatch(const IBatchPresenter *presenter, - const IMainWindowView *mwv, - bool projectSave = false); + QMap<QString, QVariant> encodeBatch(const IMainWindowView *mwv, + int batchIndex, + bool projectSave = false) override; private: BatchPresenter *findBatchPresenter(const QtBatchView *gui, - const QtMainWindowView *mwv); + const IMainWindowView *mwv); QMap<QString, QVariant> encodeExperiment(const QtExperimentView *gui); QMap<QString, QVariant> encodePerAngleDefaults(const QTableWidget *tab); QList<QVariant> encodePerAngleDefaultsRow(const QTableWidget *tab, @@ -89,4 +87,4 @@ private: } // namespace ISISReflectometry } // 
namespace CustomInterfaces } // namespace MantidQt -#endif /* MANTID_ISISREFLECTOMETRY_ENCODER_H */ \ No newline at end of file +#endif /* MANTID_ISISREFLECTOMETRY_ENCODER_H */ diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IDecoder.h b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IDecoder.h new file mode 100644 index 0000000000000000000000000000000000000000..10e99b47de11b1367214ba979ce5f0441f0bd45c --- /dev/null +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IDecoder.h @@ -0,0 +1,34 @@ +// Mantid Repository : https://github.com/mantidproject/mantid +// +// Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +// NScD Oak Ridge National Laboratory, European Spallation Source +// & Institut Laue - Langevin +// SPDX - License - Identifier: GPL - 3.0 + +#ifndef MANTID_ISISREFLECTOMETRY_IDECODER_H +#define MANTID_ISISREFLECTOMETRY_IDECODER_H + +#include <QMap> +#include <QString> +#include <QVariant> + +namespace MantidQt { +namespace CustomInterfaces { +namespace ISISReflectometry { + +class IMainWindowView; + +/** @class IDecoder + +IDecoder is an interface for decoding the contents of the reflectometry +interface from a map +*/ +class IDecoder { +public: + virtual ~IDecoder(){}; + virtual void decodeBatch(const IMainWindowView *mwv, int batchIndex, + const QMap<QString, QVariant> &map) = 0; +}; +} // namespace ISISReflectometry +} // namespace CustomInterfaces +} // namespace MantidQt +#endif diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IEncoder.h b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IEncoder.h new file mode 100644 index 0000000000000000000000000000000000000000..936e02ad308c05c4725cb4267bdd432f825aed78 --- /dev/null +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IEncoder.h @@ -0,0 +1,35 @@ +// Mantid Repository : https://github.com/mantidproject/mantid +// +// Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +// NScD Oak Ridge National Laboratory, European 
Spallation Source +// & Institut Laue - Langevin +// SPDX - License - Identifier: GPL - 3.0 + +#ifndef MANTID_ISISREFLECTOMETRY_IENCODER_H +#define MANTID_ISISREFLECTOMETRY_IENCODER_H + +#include <QMap> +#include <QString> +#include <QVariant> + +namespace MantidQt { +namespace CustomInterfaces { +namespace ISISReflectometry { + +class IMainWindowView; + +/** @class IEncoder + +IEncoder is an interface for encoding the contents of the reflectometry +interface into a map +*/ +class IEncoder { +public: + virtual ~IEncoder(){}; + virtual QMap<QString, QVariant> encodeBatch(const IMainWindowView *mwv, + int batchIndex, + bool projectSave = false) = 0; +}; +} // namespace ISISReflectometry +} // namespace CustomInterfaces +} // namespace MantidQt +#endif diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IFileHandler.h b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IFileHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..d4ac5f8b24408ab0644e0886a59c5be7eaa6e40a --- /dev/null +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IFileHandler.h @@ -0,0 +1,33 @@ +// Mantid Repository : https://github.com/mantidproject/mantid +// +// Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +// NScD Oak Ridge National Laboratory, European Spallation Source +// & Institut Laue - Langevin +// SPDX - License - Identifier: GPL - 3.0 + +#ifndef MANTID_ISISREFLECTOMETRY_IFILEHANDLER_H +#define MANTID_ISISREFLECTOMETRY_IFILEHANDLER_H + +#include <QMap> +#include <QString> +#include <QVariant> +#include <string> + +namespace MantidQt { +namespace CustomInterfaces { +namespace ISISReflectometry { +/** @class IFileHandler + +IFileHandler is an interface for saving/loading files +*/ +class IFileHandler { +public: + virtual ~IFileHandler(){}; + virtual void saveJSONToFile(std::string const &filename, + QMap<QString, QVariant> const &map) = 0; + virtual QMap<QString, QVariant> + loadJSONFromFile(std::string const &filename) = 0; 
+}; +} // namespace ISISReflectometry +} // namespace CustomInterfaces +} // namespace MantidQt +#endif diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IMessageHandler.h b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IMessageHandler.h index 16bb9666da57b8ae15ad03acc06fd14925ff50f8..7a6d2133db0206e01231d2c2e45803a9aae78e55 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IMessageHandler.h +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/Common/IMessageHandler.h @@ -24,6 +24,8 @@ public: const std::string &title) = 0; virtual bool askUserYesNo(const std::string &prompt, const std::string &title) = 0; + virtual std::string askUserForSaveFileName(std::string const &filter) = 0; + virtual std::string askUserForLoadFileName(std::string const &filter) = 0; }; } // namespace ISISReflectometry } // namespace CustomInterfaces diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/MainWindowPresenter.cpp b/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/MainWindowPresenter.cpp index 725dec695837ea9d31fd2f5b78dbb87be33a923d..74fdef9292ee5e348feae713ddf250b5392479df 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/MainWindowPresenter.cpp +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/MainWindowPresenter.cpp @@ -6,8 +6,9 @@ // SPDX - License - Identifier: GPL - 3.0 + #include "MainWindowPresenter.h" #include "GUI/Batch/IBatchPresenterFactory.h" -#include "GUI/Common/Decoder.h" -#include "GUI/Common/Encoder.h" +#include "GUI/Common/IDecoder.h" +#include "GUI/Common/IEncoder.h" +#include "GUI/Common/IFileHandler.h" #include "GUI/Common/IMessageHandler.h" #include "GUI/Runs/IRunsPresenter.h" #include "IMainWindowView.h" @@ -16,11 +17,8 @@ #include "MantidKernel/ConfigService.h" #include "MantidQtWidgets/Common/HelpWindow.h" #include "MantidQtWidgets/Common/ISlitCalculator.h" -#include "MantidQtWidgets/Common/QtJSONUtils.h" #include "Reduction/Batch.h" -#include 
<QFileDialog> - namespace MantidQt { namespace CustomInterfaces { namespace ISISReflectometry { @@ -38,15 +36,22 @@ Mantid::Kernel::Logger g_log("Reflectometry GUI"); * @param view :: [input] The view we are managing * @param messageHandler :: Interface to a class that displays messages to * the user + * @param fileHandler :: Interface to a class that loads/saves files + * @param encoder :: Interface for encoding a batch for saving to file + * @param decoder :: Interface for decoding a batch loaded from file * @param slitCalculator :: Interface to the Slit Calculator dialog * @param batchPresenterFactory :: [input] A factory to create the batches * we will manage */ MainWindowPresenter::MainWindowPresenter( IMainWindowView *view, IMessageHandler *messageHandler, + IFileHandler *fileHandler, std::unique_ptr<IEncoder> encoder, + std::unique_ptr<IDecoder> decoder, std::unique_ptr<ISlitCalculator> slitCalculator, std::unique_ptr<IBatchPresenterFactory> batchPresenterFactory) - : m_view(view), m_messageHandler(messageHandler), m_instrument(), + : m_view(view), m_messageHandler(messageHandler), + m_fileHandler(fileHandler), m_instrument(), m_encoder(std::move(encoder)), + m_decoder(std::move(decoder)), m_slitCalculator(std::move(slitCalculator)), m_batchPresenterFactory(std::move(batchPresenterFactory)) { view->subscribe(this); @@ -182,37 +187,28 @@ void MainWindowPresenter::showHelp() { } void MainWindowPresenter::notifySaveBatchRequested(int tabIndex) { - const QString jsonFilter = QString("JSON (*.json)"); - auto filename = - QFileDialog::getSaveFileName(nullptr, QString(), QString(), jsonFilter, - nullptr, QFileDialog::DontResolveSymlinks); + auto filename = m_messageHandler->askUserForSaveFileName("JSON (*.json)"); if (filename == "") return; - Encoder encoder; - IBatchPresenter *batchPresenter = m_batchPresenters[tabIndex].get(); - auto map = encoder.encodeBatch(batchPresenter, m_view, false); - MantidQt::API::saveJSONToFile(filename, map); + auto map = 
m_encoder->encodeBatch(m_view, tabIndex, false); + m_fileHandler->saveJSONToFile(filename, map); } void MainWindowPresenter::notifyLoadBatchRequested(int tabIndex) { - const QString jsonFilter = QString("JSON (*.json)"); - auto filename = - QFileDialog::getOpenFileName(nullptr, QString(), QString(), jsonFilter, - nullptr, QFileDialog::DontResolveSymlinks); + auto filename = m_messageHandler->askUserForLoadFileName("JSON (*.json)"); if (filename == "") return; QMap<QString, QVariant> map; try { - map = MantidQt::API::loadJSONFromFile(filename); + map = m_fileHandler->loadJSONFromFile(filename); } catch (const std::runtime_error) { m_messageHandler->giveUserCritical( "Unable to load requested file. Please load a file of " "appropriate format saved from the GUI.", "Error:"); + return; } - IBatchPresenter *batchPresenter = m_batchPresenters[tabIndex].get(); - Decoder decoder; - decoder.decodeBatch(batchPresenter, m_view, map); + m_decoder->decodeBatch(m_view, tabIndex, map); } void MainWindowPresenter::disableSaveAndLoadBatch() { diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/MainWindowPresenter.h b/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/MainWindowPresenter.h index c1070a554bab67ee74248e476c961568e0879d03..814bb0c4e72983d80d98097f82e09ccc842b25ec 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/MainWindowPresenter.h +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/MainWindowPresenter.h @@ -23,7 +23,10 @@ namespace ISISReflectometry { class IBatchPresenterFactory; class IMainWindowView; +class IFileHandler; class IMessageHandler; +class IEncoder; +class IDecoder; /** @class MainWindowPresenter @@ -37,6 +40,8 @@ public: /// Constructor MainWindowPresenter( IMainWindowView *view, IMessageHandler *messageHandler, + IFileHandler *fileHandler, std::unique_ptr<IEncoder> encoder, + std::unique_ptr<IDecoder> decoder, std::unique_ptr<MantidWidgets::ISlitCalculator> slitCalculator, 
std::unique_ptr<IBatchPresenterFactory> batchPresenterFactory); ~MainWindowPresenter(); @@ -70,10 +75,13 @@ public: protected: IMainWindowView *m_view; IMessageHandler *m_messageHandler; + IFileHandler *m_fileHandler; std::vector<std::unique_ptr<IBatchPresenter>> m_batchPresenters; Mantid::Geometry::Instrument_const_sptr m_instrument; private: + std::unique_ptr<IEncoder> m_encoder; + std::unique_ptr<IDecoder> m_decoder; std::unique_ptr<MantidWidgets::ISlitCalculator> m_slitCalculator; std::unique_ptr<IBatchPresenterFactory> m_batchPresenterFactory; diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/QtMainWindowView.cpp b/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/QtMainWindowView.cpp index e68d8dca0e186db16a430aaab510c656e977e93d..d4efba5b78bc3a03bd4a07a95bf594e5987b1f87 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/QtMainWindowView.cpp +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/QtMainWindowView.cpp @@ -12,7 +12,9 @@ #include "GUI/Common/Encoder.h" #include "GUI/Common/Plotter.h" #include "MantidKernel/UsageService.h" +#include "MantidQtWidgets/Common/QtJSONUtils.h" #include "MantidQtWidgets/Common/SlitCalculator.h" +#include <QFileDialog> #include <QMessageBox> #include <QToolButton> @@ -83,6 +85,7 @@ void QtMainWindowView::initLayout() { instruments, thetaTolerance, std::move(plotter)); auto messageHandler = this; + auto fileHandler = this; auto makeRunsPresenter = RunsPresenterFactory(std::move(makeRunsTablePresenter), thetaTolerance, instruments, messageHandler); @@ -100,7 +103,8 @@ void QtMainWindowView::initLayout() { // Create the presenter auto slitCalculator = std::make_unique<SlitCalculator>(this); m_presenter = std::make_unique<MainWindowPresenter>( - this, messageHandler, std::move(slitCalculator), + this, messageHandler, fileHandler, std::make_unique<Encoder>(), + std::make_unique<Decoder>(), std::move(slitCalculator), std::move(makeBatchPresenter)); 
m_notifyee->notifyNewBatchRequested(); @@ -211,6 +215,24 @@ bool QtMainWindowView::askUserYesNo(const std::string &prompt, return false; } +std::string +QtMainWindowView::askUserForLoadFileName(std::string const &filter) { + auto filterQString = QString::fromStdString(filter); + auto filename = + QFileDialog::getOpenFileName(nullptr, QString(), QString(), filterQString, + nullptr, QFileDialog::DontResolveSymlinks); + return filename.toStdString(); +} + +std::string +QtMainWindowView::askUserForSaveFileName(std::string const &filter) { + auto filterQString = QString::fromStdString(filter); + auto filename = + QFileDialog::getSaveFileName(nullptr, QString(), QString(), filterQString, + nullptr, QFileDialog::DontResolveSymlinks); + return filename.toStdString(); +} + void QtMainWindowView::disableSaveAndLoadBatch() { m_ui.saveBatch->setEnabled(false); m_ui.loadBatch->setEnabled(false); @@ -220,6 +242,17 @@ void QtMainWindowView::enableSaveAndLoadBatch() { m_ui.saveBatch->setEnabled(true); m_ui.loadBatch->setEnabled(true); } + +void QtMainWindowView::saveJSONToFile(std::string const &filename, + QMap<QString, QVariant> const &map) { + auto filenameQString = QString::fromStdString(filename); + MantidQt::API::saveJSONToFile(filenameQString, map); +} + +QMap<QString, QVariant> +QtMainWindowView::loadJSONFromFile(std::string const &filename) { + return MantidQt::API::loadJSONFromFile(QString::fromStdString(filename)); +} } // namespace ISISReflectometry } // namespace CustomInterfaces } // namespace MantidQt diff --git a/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/QtMainWindowView.h b/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/QtMainWindowView.h index 550abdcf3587b57735516e51b94461a68ff7da03..8bf66174c18cb5763197530044417a941b563396 100644 --- a/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/QtMainWindowView.h +++ b/qt/scientific_interfaces/ISISReflectometry/GUI/MainWindow/QtMainWindowView.h @@ -7,6 +7,7 @@ #ifndef 
MANTID_ISISREFLECTOMETRY_QTMAINWINDOWVIEW_H #define MANTID_ISISREFLECTOMETRY_QTMAINWINDOWVIEW_H +#include "GUI/Common/IFileHandler.h" #include "GUI/Common/IMessageHandler.h" #include "GUI/Common/IPythonRunner.h" #include "IMainWindowPresenter.h" @@ -31,6 +32,7 @@ class MANTIDQT_ISISREFLECTOMETRY_DLL QtMainWindowView : public MantidQt::API::UserSubWindow, public IMainWindowView, public IMessageHandler, + public IFileHandler, public IPythonRunner { Q_OBJECT public: @@ -38,6 +40,7 @@ public: void subscribe(MainWindowSubscriber *notifyee) override; + // cppcheck-suppress returnTempReference static std::string name() { return "ISIS Reflectometry"; } static QString categoryInfo() { return "Reflectometry"; } std::string runPythonAlgorithm(const std::string &pythonCode) override; @@ -55,10 +58,19 @@ public: const std::string &title) override; bool askUserYesNo(const std::string &prompt, const std::string &title) override; + std::string askUserForLoadFileName(std::string const &filter) override; + std::string askUserForSaveFileName(std::string const &filter) override; void disableSaveAndLoadBatch() override; void enableSaveAndLoadBatch() override; + // TODO Remove Qt types from this interface - conversion should be done in + // QtJSONUtils if possible + void saveJSONToFile(std::string const &filename, + QMap<QString, QVariant> const &map) override; + QMap<QString, QVariant> + loadJSONFromFile(std::string const &filename) override; + public slots: void helpPressed(); void onTabCloseRequested(int tabIndex); diff --git a/qt/scientific_interfaces/Indirect/IndirectFitPropertyBrowser.cpp b/qt/scientific_interfaces/Indirect/IndirectFitPropertyBrowser.cpp index 2fb58b0dd271c08d0646c6d04b4c07c9835c9ceb..2ca286848bbd56b114192aacab6b57483bf1a9dd 100644 --- a/qt/scientific_interfaces/Indirect/IndirectFitPropertyBrowser.cpp +++ b/qt/scientific_interfaces/Indirect/IndirectFitPropertyBrowser.cpp @@ -209,7 +209,7 @@ IndirectFitPropertyBrowser::getFittingFunction() const { } return 
multiDomainFunction; } - } catch (std::invalid_argument) { + } catch (const std::invalid_argument &) { return boost::make_shared<MultiDomainFunction>(); } } diff --git a/qt/scientific_interfaces/test/ISISReflectometry/Common/ClipboardTest.h b/qt/scientific_interfaces/test/ISISReflectometry/Common/ClipboardTest.h index 9201df531892565786265565dc1e1f3bf46f1e8f..c83b8f7cb07b0f24fa1d1fca8e67ee0c07655ca1 100644 --- a/qt/scientific_interfaces/test/ISISReflectometry/Common/ClipboardTest.h +++ b/qt/scientific_interfaces/test/ISISReflectometry/Common/ClipboardTest.h @@ -36,33 +36,35 @@ public: void testCheckingClipboardTypeThrowsForEmptyClipboard() { auto clipboard = Clipboard(); - TS_ASSERT_THROWS(clipboard.isGroupLocation(0), std::runtime_error); + TS_ASSERT_THROWS(clipboard.isGroupLocation(0), const std::runtime_error &); } void testCheckingGroupNameThrowsForEmptyClipboard() { auto clipboard = Clipboard(); - TS_ASSERT_THROWS(clipboard.groupName(0), std::runtime_error); + TS_ASSERT_THROWS(clipboard.groupName(0), const std::runtime_error &); } void testSettingGroupNameThrowsForEmptyClipboard() { auto clipboard = Clipboard(); TS_ASSERT_THROWS(clipboard.setGroupName(0, "test group"), - std::runtime_error); + const std::runtime_error &); } void testCreateGroupForRootThrowsForEmptyClipboard() { auto clipboard = Clipboard(); - TS_ASSERT_THROWS(clipboard.createGroupForRoot(0), std::runtime_error); + TS_ASSERT_THROWS(clipboard.createGroupForRoot(0), + const std::runtime_error &); } void testCreateRowsForAllRootsThrowsForEmptyClipboard() { auto clipboard = Clipboard(); - TS_ASSERT_THROWS(clipboard.createRowsForAllRoots(), std::runtime_error); + TS_ASSERT_THROWS(clipboard.createRowsForAllRoots(), + const std::runtime_error &); } void testContainsGroupsThrowsForEmptyClipboard() { auto clipboard = Clipboard(); - TS_ASSERT_THROWS(containsGroups(clipboard), std::runtime_error); + TS_ASSERT_THROWS(containsGroups(clipboard), const std::runtime_error &); } void 
testClipboardIsInitializedWithRow() { @@ -77,18 +79,19 @@ public: void testGettingGroupNameThrowsForRow() { auto clipboard = clipboardWithARow(); - TS_ASSERT_THROWS(clipboard.groupName(0), std::runtime_error); + TS_ASSERT_THROWS(clipboard.groupName(0), const std::runtime_error &); } void testSettingGroupNameThrowsForRow() { auto clipboard = clipboardWithARow(); TS_ASSERT_THROWS(clipboard.setGroupName(0, "test group"), - std::runtime_error); + const std::runtime_error &); } void testCreateGroupForRootThrowsForRow() { auto clipboard = clipboardWithARow(); - TS_ASSERT_THROWS(clipboard.createGroupForRoot(0), std::runtime_error); + TS_ASSERT_THROWS(clipboard.createGroupForRoot(0), + const std::runtime_error &); } void testCreateRowsForAllRootsSucceeds() { @@ -133,7 +136,8 @@ public: void testCreateRowsForAllRootsThrowsForGroup() { auto clipboard = clipboardWithAGroup(); - TS_ASSERT_THROWS(clipboard.createRowsForAllRoots(), std::runtime_error); + TS_ASSERT_THROWS(clipboard.createRowsForAllRoots(), + const std::runtime_error &); } void testContainsGroupsReturnsTrueIfGroupsExist() { @@ -173,7 +177,8 @@ public: void testCreateRowsForAllRootsThrowsForMultiGroupClipboard() { auto clipboard = clipboardWithTwoMultiRowGroups(); - TS_ASSERT_THROWS(clipboard.createRowsForAllRoots(), std::runtime_error); + TS_ASSERT_THROWS(clipboard.createRowsForAllRoots(), + const std::runtime_error &); } void testContainsGroupsReturnsTrueIfMultipleGroupsExist() { diff --git a/qt/scientific_interfaces/test/ISISReflectometry/Common/DecoderTest.h b/qt/scientific_interfaces/test/ISISReflectometry/Common/DecoderTest.h index 5ab5c5b21d2e70fef4fde90de71f623eef3e4549..a35b7d10c6a00a637f65e5a2df4ba90878defe7c 100644 --- a/qt/scientific_interfaces/test/ISISReflectometry/Common/DecoderTest.h +++ b/qt/scientific_interfaces/test/ISISReflectometry/Common/DecoderTest.h @@ -186,7 +186,7 @@ public: mwv.initLayout(); auto gui = dynamic_cast<QtBatchView *>(mwv.batches()[0]); Decoder decoder; - 
decoder.decodeBatch(gui, &mwv, map); + decoder.decodeBatch(&mwv, 0, map); tester.testBatch(gui, &mwv, map); } @@ -195,4 +195,4 @@ public: } // namespace CustomInterfaces } // namespace MantidQt -#endif /* ISISREFLECTOMETRY_TEST_DECODER_TEST_H_ */ \ No newline at end of file +#endif /* ISISREFLECTOMETRY_TEST_DECODER_TEST_H_ */ diff --git a/qt/scientific_interfaces/test/ISISReflectometry/Common/EncoderTest.h b/qt/scientific_interfaces/test/ISISReflectometry/Common/EncoderTest.h index d68183454da116826e0aea1f6a6e7d1554dffab5..7dde60c8ebdd19172024bf3d82d605faa4b60ad6 100644 --- a/qt/scientific_interfaces/test/ISISReflectometry/Common/EncoderTest.h +++ b/qt/scientific_interfaces/test/ISISReflectometry/Common/EncoderTest.h @@ -47,7 +47,7 @@ public: mwv.initLayout(); auto gui = dynamic_cast<QtBatchView *>(mwv.batches()[0]); Encoder encoder; - auto map = encoder.encodeBatch(gui, &mwv); + auto map = encoder.encodeBatch(&mwv, 0); tester.testBatch(gui, &mwv, map); } @@ -56,4 +56,4 @@ public: } // namespace CustomInterfaces } // namespace MantidQt -#endif /* ISISREFLECTOMETRY_TEST_ENCODER_TEST_H_ */ \ No newline at end of file +#endif /* ISISREFLECTOMETRY_TEST_ENCODER_TEST_H_ */ diff --git a/qt/scientific_interfaces/test/ISISReflectometry/MainWindow/MainWindowPresenterTest.h b/qt/scientific_interfaces/test/ISISReflectometry/MainWindow/MainWindowPresenterTest.h index 24098ee5febd9771b9560f28174a0c29e5568a81..aa60c44e13d1213189f82c00e9639d93511892ae 100644 --- a/qt/scientific_interfaces/test/ISISReflectometry/MainWindow/MainWindowPresenterTest.h +++ b/qt/scientific_interfaces/test/ISISReflectometry/MainWindow/MainWindowPresenterTest.h @@ -315,9 +315,44 @@ public: verifyAndClear(); } + void testSaveBatch() { + auto presenter = makePresenter(); + auto const filename = std::string("test.json"); + auto const map = QMap<QString, QVariant>(); + auto const batchIndex = 1; + EXPECT_CALL(m_messageHandler, askUserForSaveFileName("JSON (*.json)")) + .Times(1) + .WillOnce(Return(filename)); 
+ EXPECT_CALL(*m_encoder, encodeBatch(&m_view, batchIndex, false)) + .Times(1) + .WillOnce(Return(map)); + EXPECT_CALL(m_fileHandler, saveJSONToFile(filename, map)).Times(1); + presenter.notifySaveBatchRequested(batchIndex); + verifyAndClear(); + } + + void testLoadBatch() { + auto presenter = makePresenter(); + auto const filename = std::string("test.json"); + auto const map = QMap<QString, QVariant>(); + auto const batchIndex = 1; + EXPECT_CALL(m_messageHandler, askUserForLoadFileName("JSON (*.json)")) + .Times(1) + .WillOnce(Return(filename)); + EXPECT_CALL(m_fileHandler, loadJSONFromFile(filename)) + .Times(1) + .WillOnce(Return(map)); + EXPECT_CALL(*m_decoder, decodeBatch(&m_view, batchIndex, map)).Times(1); + presenter.notifyLoadBatchRequested(batchIndex); + verifyAndClear(); + } + private: NiceMock<MockMainWindowView> m_view; NiceMock<MockMessageHandler> m_messageHandler; + NiceMock<MockFileHandler> m_fileHandler; + NiceMock<MockEncoder> *m_encoder; + NiceMock<MockDecoder> *m_decoder; std::vector<IBatchView *> m_batchViews; std::vector<NiceMock<MockBatchPresenter> *> m_batchPresenters; NiceMock<MockBatchPresenterFactory> *m_makeBatchPresenter; @@ -329,13 +364,21 @@ private: public: MainWindowPresenterFriend( IMainWindowView *view, IMessageHandler *messageHandler, + IFileHandler *fileHandler, std::unique_ptr<IEncoder> encoder, + std::unique_ptr<IDecoder> decoder, std::unique_ptr<ISlitCalculator> slitCalculator, std::unique_ptr<IBatchPresenterFactory> makeBatchPresenter) - : MainWindowPresenter(view, messageHandler, std::move(slitCalculator), + : MainWindowPresenter(view, messageHandler, fileHandler, + std::move(encoder), std::move(decoder), + std::move(slitCalculator), std::move(makeBatchPresenter)) {} }; MainWindowPresenterFriend makePresenter() { + auto encoder = std::make_unique<NiceMock<MockEncoder>>(); + m_encoder = encoder.get(); + auto decoder = std::make_unique<NiceMock<MockDecoder>>(); + m_decoder = decoder.get(); auto slitCalculator = 
std::make_unique<NiceMock<MockSlitCalculator>>(); m_slitCalculator = slitCalculator.get(); auto makeBatchPresenter = @@ -350,15 +393,19 @@ private: .WillByDefault(Return(batchPresenter)); } // Make the presenter - auto presenter = MainWindowPresenterFriend(&m_view, &m_messageHandler, - std::move(slitCalculator), - std::move(makeBatchPresenter)); + auto presenter = MainWindowPresenterFriend( + &m_view, &m_messageHandler, &m_fileHandler, std::move(encoder), + std::move(decoder), std::move(slitCalculator), + std::move(makeBatchPresenter)); return presenter; } void verifyAndClear() { TS_ASSERT(Mock::VerifyAndClearExpectations(&m_view)); TS_ASSERT(Mock::VerifyAndClearExpectations(&m_messageHandler)); + TS_ASSERT(Mock::VerifyAndClearExpectations(&m_fileHandler)); + TS_ASSERT(Mock::VerifyAndClearExpectations(&m_encoder)); + TS_ASSERT(Mock::VerifyAndClearExpectations(&m_decoder)); for (auto batchPresenter : m_batchPresenters) TS_ASSERT(Mock::VerifyAndClearExpectations(batchPresenter)); m_batchPresenters.clear(); diff --git a/qt/scientific_interfaces/test/ISISReflectometry/ReflMockObjects.h b/qt/scientific_interfaces/test/ISISReflectometry/ReflMockObjects.h index 13f8321b09f3a92ad6159fd634ed1653e1852068..04d8248ceb90abc389b31bdff9fc43a968e3fe86 100644 --- a/qt/scientific_interfaces/test/ISISReflectometry/ReflMockObjects.h +++ b/qt/scientific_interfaces/test/ISISReflectometry/ReflMockObjects.h @@ -11,6 +11,9 @@ #include "GUI/Batch/IBatchJobRunner.h" #include "GUI/Batch/IBatchPresenter.h" #include "GUI/Batch/IBatchPresenterFactory.h" +#include "GUI/Common/IDecoder.h" +#include "GUI/Common/IEncoder.h" +#include "GUI/Common/IFileHandler.h" #include "GUI/Common/IMessageHandler.h" #include "GUI/Common/IPlotter.h" #include "GUI/Common/IPythonRunner.h" @@ -34,6 +37,9 @@ #include "MantidKernel/WarningSuppressions.h" #include "MantidQtWidgets/Common/BatchAlgorithmRunner.h" #include "MantidQtWidgets/Common/Hint.h" +#include <QMap> +#include <QString> +#include <QVariant> #include 
<boost/shared_ptr.hpp> #include <gmock/gmock.h> @@ -234,6 +240,27 @@ public: void(const std::string &, const std::string &)); MOCK_METHOD2(giveUserInfo, void(const std::string &, const std::string &)); MOCK_METHOD2(askUserYesNo, bool(const std::string &, const std::string &)); + MOCK_METHOD1(askUserForLoadFileName, std::string(const std::string &)); + MOCK_METHOD1(askUserForSaveFileName, std::string(const std::string &)); +}; + +class MockFileHandler : public IFileHandler { +public: + MOCK_METHOD2(saveJSONToFile, + void(std::string const &, QMap<QString, QVariant> const &)); + MOCK_METHOD1(loadJSONFromFile, QMap<QString, QVariant>(const std::string &)); +}; + +class MockEncoder : public IEncoder { +public: + MOCK_METHOD3(encodeBatch, + QMap<QString, QVariant>(const IMainWindowView *, int, bool)); +}; + +class MockDecoder : public IDecoder { +public: + MOCK_METHOD3(decodeBatch, void(const IMainWindowView *, int, + const QMap<QString, QVariant> &)); }; class MockPythonRunner : public IPythonRunner { diff --git a/qt/widgets/common/test/RepoModelTest.h b/qt/widgets/common/test/RepoModelTest.h index 0b70a8cd7825bc3c3f16fb33c8ccd70df32940ed..ec006fd20597b49e015a5ea2a3e43c12bf0620d7 100644 --- a/qt/widgets/common/test/RepoModelTest.h +++ b/qt/widgets/common/test/RepoModelTest.h @@ -29,6 +29,11 @@ public: static void destroySuite(RepoModelTest *suite) { delete suite; } void setUp() override { + if (Mantid::API::ScriptRepositoryFactory::Instance().exists( + "ScriptRepositoryImpl")) { + Mantid::API::ScriptRepositoryFactory::Instance().unsubscribe( + "ScriptRepositoryImpl"); + } Mantid::API::ScriptRepositoryFactory::Instance() .subscribe<testing::NiceMock<MockScriptRepositoryImpl>>( "ScriptRepositoryImpl"); diff --git a/scripts/Diffraction/isis_powder/hrpd_routines/hrpd_algs.py b/scripts/Diffraction/isis_powder/hrpd_routines/hrpd_algs.py index 94599dc3cc8d600ca1fe3cbb5ac96185093eae73..0128df01f64c13608d71962d7c84935d25ed132e 100644 --- 
a/scripts/Diffraction/isis_powder/hrpd_routines/hrpd_algs.py +++ b/scripts/Diffraction/isis_powder/hrpd_routines/hrpd_algs.py @@ -61,9 +61,19 @@ def calculate_slab_absorb_corrections(ws_to_correct, sample_details_obj): if previous_units != ws_units.wavelength: ws_to_correct = mantid.ConvertUnits(InputWorkspace=ws_to_correct, OutputWorkspace=ws_to_correct, Target=ws_units.wavelength) - + # set element size based on thickness + sample_thickness = sample_details_obj.thickness() + # half and convert cm to mm 5=(0.5*10) + element_size = 5.*sample_thickness + # limit number of wavelength points as for small samples the number of elements can be required to be quite large + if sample_thickness < 0.1: # 1mm + nlambda = 100 + else: + nlambda = None # use all points absorb_factors = mantid.HRPDSlabCanAbsorption(InputWorkspace=ws_to_correct, - Thickness=str(sample_details_obj.thickness())) + Thickness=sample_thickness, + ElementSize=element_size, + NumberOfWavelengthPoints=nlambda) ws_to_correct = mantid.Divide(LHSWorkspace=ws_to_correct, RHSWorkspace=absorb_factors, OutputWorkspace=ws_to_correct) mantid.DeleteWorkspace(Workspace=absorb_factors) diff --git a/scripts/Diffraction/isis_powder/polaris_routines/polaris_algs.py b/scripts/Diffraction/isis_powder/polaris_routines/polaris_algs.py index e10b37b959ee91e5ed706d0680114c1a93dde10d..8cf5287de594446ed236ea15dec0dd359d5139cb 100644 --- a/scripts/Diffraction/isis_powder/polaris_routines/polaris_algs.py +++ b/scripts/Diffraction/isis_powder/polaris_routines/polaris_algs.py @@ -83,7 +83,14 @@ def generate_ts_pdf(run_number, focus_file_path, merge_banks=False, q_lims=None, sample_details=None): focused_ws = _obtain_focused_run(run_number, focus_file_path) focused_ws = mantid.ConvertUnits(InputWorkspace=focused_ws, Target="MomentumTransfer", EMode='Elastic') - self_scattering_correction = _calculate_self_scattering_correction(run_number, cal_file_name, sample_details) + + raw_ws = 
mantid.Load(Filename='POLARIS'+str(run_number)+'.nxs') + sample_geometry = common.generate_sample_geometry(sample_details) + sample_material = common.generate_sample_material(sample_details) + self_scattering_correction = mantid.TotScatCalculateSelfScattering(InputWorkspace=raw_ws, + CalFileName=cal_file_name, + SampleGeometry=sample_geometry, + SampleMaterial=sample_material) ws_group_list = [] for i in range(self_scattering_correction.getNumberHistograms()): @@ -97,7 +104,9 @@ def generate_ts_pdf(run_number, focus_file_path, merge_banks=False, q_lims=None, focused_ws = mantid.Subtract(LHSWorkspace=focused_ws, RHSWorkspace=self_scattering_correction) if merge_banks: - merged_ws = _merge_workspace_with_limits(focused_ws, q_lims) + q_min, q_max = _load_qlims(q_lims) + merged_ws = mantid.MatchAndMergeWorkspaces(InputWorkspaces=focused_ws, XMin=q_min, XMax=q_max, + CalculateScale=False) pdf_output = mantid.PDFFourierTransform(Inputworkspace=merged_ws, InputSofQType="S(Q)-1", PDFType="G(r)", Filter=True) else: @@ -135,53 +144,6 @@ def _obtain_focused_run(run_number, focus_file_path): return focused_ws -def _merge_workspace_with_limits(focused_ws, q_lims): - min_x = np.inf - max_x = -np.inf - num_x = -np.inf - ws_max_range = 0 - largest_range_spectrum = 0 - for i in range(focused_ws.size()): - x_data = focused_ws[i].dataX(0) - min_x = min(np.min(x_data), min_x) - max_x = max(np.max(x_data), max_x) - num_x = max(x_data.size, num_x) - ws_range = np.max(x_data)-np.min(x_data) - if ws_range > ws_max_range: - largest_range_spectrum = i + 1 - ws_max_range = ws_range - if min_x == -np.inf or max_x == np.inf: - raise AttributeError("Workspace x range contains an infinite value.") - focused_ws = mantid.Rebin(InputWorkspace=focused_ws, Params=[min_x, (max_x-min_x)/num_x, max_x]) - while focused_ws.size() > 1: - mantid.ConjoinWorkspaces(InputWorkspace1=focused_ws[0], - InputWorkspace2=focused_ws[1]) - focused_ws_conjoined = focused_ws[0] - 
mantid.MatchSpectra(InputWorkspace=focused_ws_conjoined, OutputWorkspace=focused_ws_conjoined, - ReferenceSpectrum=largest_range_spectrum) - - q_min, q_max = _load_qlims(q_lims) - bin_width = np.inf - for i in range(q_min.size): - pdf_x_array = focused_ws_conjoined.readX(i) - tmp1 = np.where(pdf_x_array >= q_min[i]) - tmp2 = np.amin(tmp1) - q_min[i] = pdf_x_array[tmp2] - q_max[i] = pdf_x_array[np.amax(np.where(pdf_x_array <= q_max[i]))] - bin_width = min(pdf_x_array[1] - pdf_x_array[0], bin_width) - - if min_x == -np.inf or max_x == np.inf: - raise AttributeError("Q lims contains an infinite value.") - focused_data_combined = mantid.CropWorkspaceRagged(InputWorkspace=focused_ws_conjoined, XMin=q_min, XMax=q_max) - focused_data_combined = mantid.Rebin(InputWorkspace=focused_data_combined, - Params=[min(q_min), bin_width, max(q_max)]) - focused_data_combined = mantid.SumSpectra(InputWorkspace=focused_data_combined, - WeightedSum=True, - MultiplyBySpectra=False) - common.remove_intermediate_workspace(focused_ws_conjoined) - return focused_data_combined - - def _load_qlims(q_lims): if type(q_lims) == str or type(q_lims) == unicode: q_min = [] @@ -205,60 +167,6 @@ def _load_qlims(q_lims): return q_min, q_max -def _calculate_self_scattering_correction(run_number, cal_file_name, sample_details): - raw_ws = mantid.Load(Filename='POLARIS'+str(run_number)+'.nxs') - mantid.SetSample(InputWorkspace=raw_ws, - Geometry=common.generate_sample_geometry(sample_details), - Material=common.generate_sample_material(sample_details)) - # find the closest monitor to the sample for incident spectrum - raw_spec_info = raw_ws.spectrumInfo() - incident_index = None - for i in range(raw_spec_info.size()): - if raw_spec_info.isMonitor(i): - l2 = raw_spec_info.position(i)[2] - if not incident_index: - incident_index = i - else: - if raw_spec_info.position(incident_index)[2] < l2 < 0: - incident_index = i - monitor = mantid.ExtractSpectra(InputWorkspace=raw_ws, 
WorkspaceIndexList=[incident_index]) - monitor = mantid.ConvertUnits(InputWorkspace=monitor, Target="Wavelength") - x_data = monitor.dataX(0) - min_x = np.min(x_data) - max_x = np.max(x_data) - width_x = (max_x - min_x) / x_data.size - fit_spectra = mantid.FitIncidentSpectrum(InputWorkspace=monitor, - BinningForCalc=[min_x, 1 * width_x, max_x], - BinningForFit=[min_x, 10 * width_x, max_x], - FitSpectrumWith="CubicSpline") - self_scattering_correction = mantid.CalculatePlaczekSelfScattering(InputWorkspace=raw_ws, - IncidentSpecta=fit_spectra) - cal_workspace = mantid.LoadCalFile(InputWorkspace=self_scattering_correction, - CalFileName=cal_file_name, - Workspacename='cal_workspace', - MakeOffsetsWorkspace=False, - MakeMaskWorkspace=False) - self_scattering_correction = mantid.DiffractionFocussing(InputWorkspace=self_scattering_correction, - GroupingFilename=cal_file_name) - n_pixel = np.zeros(self_scattering_correction.getNumberHistograms()) - for i in range(cal_workspace.getNumberHistograms()): - grouping = cal_workspace.dataY(i) - if grouping[0] > 0: - n_pixel[int(grouping[0] - 1)] += 1 - correction_ws = mantid.CreateWorkspace(DataY=n_pixel, DataX=[0, 1], - NSpec=self_scattering_correction.getNumberHistograms()) - self_scattering_correction = mantid.Divide(LHSWorkspace=self_scattering_correction, RHSWorkspace=correction_ws) - mantid.ConvertToDistribution(Workspace=self_scattering_correction) - self_scattering_correction = mantid.ConvertUnits(InputWorkspace=self_scattering_correction, - Target="MomentumTransfer", EMode='Elastic') - common.remove_intermediate_workspace('cal_workspace_group') - common.remove_intermediate_workspace(correction_ws) - common.remove_intermediate_workspace(fit_spectra) - common.remove_intermediate_workspace(monitor) - common.remove_intermediate_workspace(raw_ws) - return self_scattering_correction - - def _determine_chopper_mode(ws): if ws.getRun().hasProperty('Frequency'): frequency = ws.getRun()['Frequency'].lastValue() diff --git 
a/scripts/Engineering/gui/engineering_diffraction/engineering_diffraction.py b/scripts/Engineering/gui/engineering_diffraction/engineering_diffraction.py index bfb03af5406423359bacc1aed41705c377a401fe..3a8dbe99906095ffb77409dd2f222c700fb35310 100644 --- a/scripts/Engineering/gui/engineering_diffraction/engineering_diffraction.py +++ b/scripts/Engineering/gui/engineering_diffraction/engineering_diffraction.py @@ -49,7 +49,7 @@ class EngineeringDiffractionGui(QtWidgets.QMainWindow, Ui_main_window): cal_view = CalibrationView(parent=self.tabs) self.calibration_presenter = CalibrationPresenter(cal_model, cal_view) self.set_on_instrument_changed(self.calibration_presenter.set_instrument_override) - self.set_on_rb_num_changed(self.calibration_presenter.set_rb_number) + self.set_on_rb_num_changed(self.calibration_presenter.set_rb_num) self.tabs.addTab(cal_view, "Calibration") def setup_focus(self): @@ -57,11 +57,12 @@ class EngineeringDiffractionGui(QtWidgets.QMainWindow, Ui_main_window): focus_view = FocusView() self.focus_presenter = FocusPresenter(focus_model, focus_view) self.set_on_instrument_changed(self.focus_presenter.set_instrument_override) - self.set_on_rb_num_changed(self.focus_presenter.set_rb_number) + self.set_on_rb_num_changed(self.focus_presenter.set_rb_num) self.tabs.addTab(focus_view, "Focus") def setup_calibration_notifier(self): - self.calibration_presenter.calibration_notifier.add_subscriber(self.focus_presenter.calibration_observer) + self.calibration_presenter.calibration_notifier.add_subscriber( + self.focus_presenter.calibration_observer) def set_on_help_clicked(self, slot): self.pushButton_help.clicked.connect(slot) diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/calibration_tab.ui b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/calibration_tab.ui index cddb60ae4bbd61e17a82630c9b04c10e7448b61d..6fc3ef07bb6b5f98925cbd3d12d8e0837c69fbba 100644 --- 
a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/calibration_tab.ui +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/calibration_tab.ui @@ -154,7 +154,7 @@ QGroupBox:title { </widget> </item> <item row="2" column="0" colspan="3"> - <widget class="FileFinder" name="finder_calib" native="true"> + <widget class="FileFinder" name="finder_sample" native="true"> <property name="sizePolicy"> <sizepolicy hsizetype="Preferred" vsizetype="Preferred"> <horstretch>0</horstretch> diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/model.py b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/model.py index a01131293b00cd8372cb0572872085add0967601..6616e2367296727778d173acfd411e730fc61424 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/model.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/model.py @@ -8,12 +8,13 @@ from __future__ import (absolute_import, division, print_function) from os import path, makedirs +from matplotlib import gridspec +import matplotlib.pyplot as plt from mantid.api import AnalysisDataService as Ads from mantid.kernel import logger from mantid.simpleapi import Load, EnggCalibrate, DeleteWorkspace, CloneWorkspace, \ CreateWorkspace, AppendSpectra, CreateEmptyTableWorkspace -from mantidqt.plotting.functions import plot from Engineering.EnggUtils import write_ENGINX_GSAS_iparam_file from Engineering.gui.engineering_diffraction.tabs.common import vanadium_corrections from Engineering.gui.engineering_diffraction.tabs.common import path_handling @@ -32,28 +33,29 @@ SOUTH_BANK_TEMPLATE_FILE = "template_ENGINX_241391_236516_South_bank.prm" class CalibrationModel(object): def create_new_calibration(self, vanadium_path, - ceria_path, + sample_path, plot_output, instrument, rb_num=None): """ - Create a new calibration from a vanadium run and ceria run + Create a new calibration from a vanadium run and sample run :param 
vanadium_path: Path to vanadium data file. - :param ceria_path: Path to ceria data file + :param sample_path: Path to sample (CeO2) data file :param plot_output: Whether the output should be plotted. :param instrument: The instrument the data relates to. :param rb_num: The RB number for file creation. """ van_integration, van_curves = vanadium_corrections.fetch_correction_workspaces( - vanadium_path, instrument) - ceria_workspace = self.load_ceria(ceria_path) - output = self.run_calibration(ceria_workspace, van_integration, van_curves) + vanadium_path, instrument, rb_num=rb_num) + sample_workspace = self.load_sample(sample_path) + output = self.run_calibration(sample_workspace, van_integration, van_curves) if plot_output: self._plot_vanadium_curves() for i in range(2): difc = [output[i].DIFC] tzero = [output[i].TZERO] - self._plot_difc_zero(difc, tzero) + self._generate_difc_tzero_workspace(difc, tzero, i + 1) + self._plot_difc_tzero() difc = [output[0].DIFC, output[1].DIFC] tzero = [output[0].TZERO, output[1].TZERO] @@ -62,12 +64,12 @@ class CalibrationModel(object): params_table.append([i, difc[i], 0.0, tzero[i]]) self.update_calibration_params_table(params_table) - self.create_output_files(CALIBRATION_DIR, difc, tzero, ceria_path, vanadium_path, + self.create_output_files(CALIBRATION_DIR, difc, tzero, sample_path, vanadium_path, instrument) if rb_num: user_calib_dir = path.join(path_handling.OUT_FILES_ROOT_DIR, "User", rb_num, "Calibration", "") - self.create_output_files(user_calib_dir, difc, tzero, ceria_path, vanadium_path, + self.create_output_files(user_calib_dir, difc, tzero, sample_path, vanadium_path, instrument) def load_existing_gsas_parameters(self, file_path): @@ -75,13 +77,13 @@ class CalibrationModel(object): logger.warning("Could not open GSAS calibration file: ", file_path) return try: - instrument, van_no, ceria_no, params_table = self.get_info_from_file(file_path) + instrument, van_no, sample_no, params_table = 
self.get_info_from_file(file_path) self.update_calibration_params_table(params_table) except RuntimeError: logger.error("Invalid file selected: ", file_path) return vanadium_corrections.fetch_correction_workspaces(van_no, instrument) - return instrument, van_no, ceria_no + return instrument, van_no, sample_no @staticmethod def update_calibration_params_table(params_table): @@ -110,74 +112,89 @@ class CalibrationModel(object): DeleteWorkspace(van_curve_twin_ws) CloneWorkspace(InputWorkspace="engggui_vanadium_curves", OutputWorkspace=van_curve_twin_ws) van_curves_ws = Ads.retrieve(van_curve_twin_ws) - for i in range(1, 3): - if i == 1: - curve_plot_bank_1 = plot([van_curves_ws], [0, 1, 2]) - curve_plot_bank_1.gca().set_title("Engg GUI Vanadium Curves Bank 1") - curve_plot_bank_1.gca().legend(["Data", "Calc", "Diff"]) - if i == 2: - curve_plot_bank_2 = plot([van_curves_ws], [3, 4, 5]) - curve_plot_bank_2.gca().set_title("Engg GUI Vanadium Curves Bank 2") - curve_plot_bank_2.gca().legend(["Data", "Calc", "Diff"]) + + fig = plt.figure() + gs = gridspec.GridSpec(1, 2) + curve_plot_bank_1 = fig.add_subplot(gs[0], projection="mantid") + curve_plot_bank_2 = fig.add_subplot(gs[1], projection="mantid") + + curve_plot_bank_1.plot(van_curves_ws, wkspIndex=0) + curve_plot_bank_1.plot(van_curves_ws, wkspIndex=1) + curve_plot_bank_1.plot(van_curves_ws, wkspIndex=2) + curve_plot_bank_1.set_title("Engg GUI Vanadium Curves Bank 1") + curve_plot_bank_1.legend(["Data", "Calc", "Diff"]) + + curve_plot_bank_2.plot(van_curves_ws, wkspIndex=3) + curve_plot_bank_2.plot(van_curves_ws, wkspIndex=4) + curve_plot_bank_2.plot(van_curves_ws, wkspIndex=5) + curve_plot_bank_2.set_title("Engg GUI Vanadium Curves Bank 2") + curve_plot_bank_2.legend(["Data", "Calc", "Diff"]) + + fig.show() + + @staticmethod + def _generate_difc_tzero_workspace(difc, tzero, bank): + bank_ws = Ads.retrieve(CalibrationModel._generate_table_workspace_name(bank - 1)) + + x_val = [] + y_val = [] + y2_val = [] + + 
difc_to_plot = difc[0] + tzero_to_plot = tzero[0] + + for irow in range(0, bank_ws.rowCount()): + x_val.append(bank_ws.cell(irow, 0)) + y_val.append(bank_ws.cell(irow, 5)) + y2_val.append(x_val[irow] * difc_to_plot + tzero_to_plot) + + ws1 = CreateWorkspace(DataX=x_val, + DataY=y_val, + UnitX="Expected Peaks Centre (dSpacing A)", + YUnitLabel="Fitted Peaks Centre(TOF, us)") + ws2 = CreateWorkspace(DataX=x_val, DataY=y2_val) + + output_ws = "engggui_difc_zero_peaks_bank_" + str(bank) + if Ads.doesExist(output_ws): + DeleteWorkspace(output_ws) + + AppendSpectra(ws1, ws2, OutputWorkspace=output_ws) + DeleteWorkspace(ws1) + DeleteWorkspace(ws2) @staticmethod - def _plot_difc_zero(difc, tzero): - for i in range(1, 3): - bank_ws = Ads.retrieve(CalibrationModel._generate_table_workspace_name(i - 1)) - - x_val = [] - y_val = [] - y2_val = [] - - difc_to_plot = difc[0] - tzero_to_plot = tzero[0] - - for irow in range(0, bank_ws.rowCount()): - x_val.append(bank_ws.cell(irow, 0)) - y_val.append(bank_ws.cell(irow, 5)) - y2_val.append(x_val[irow] * difc_to_plot + tzero_to_plot) - - ws1 = CreateWorkspace(DataX=x_val, - DataY=y_val, - UnitX="Expected Peaks Centre (dSpacing A)", - YUnitLabel="Fitted Peaks Centre(TOF, us)") - ws2 = CreateWorkspace(DataX=x_val, DataY=y2_val) - - output_ws = "engggui_difc_zero_peaks_bank_" + str(i) - if Ads.doesExist(output_ws): - DeleteWorkspace(output_ws) - - AppendSpectra(ws1, ws2, OutputWorkspace=output_ws) - DeleteWorkspace(ws1) - DeleteWorkspace(ws2) - - difc_zero_ws = Ads.retrieve(output_ws) - # Create plot - difc_zero_plot = plot([difc_zero_ws], [0, 1], - plot_kwargs={ - "linestyle": "--", - "marker": "o", - "markersize": "3" - }) - difc_zero_plot.gca().set_title("Engg Gui Difc Zero Peaks Bank " + str(i)) - difc_zero_plot.gca().legend(("Peaks Fitted", "DifC/TZero Fitted Straight Line")) - difc_zero_plot.gca().set_xlabel("Expected Peaks Centre(dSpacing, A)") + def _plot_difc_tzero(): + bank_1_ws = Ads.retrieve("engggui_difc_zero_peaks_bank_1") 
+ bank_2_ws = Ads.retrieve("engggui_difc_zero_peaks_bank_2") + # Create plot + fig = plt.figure() + gs = gridspec.GridSpec(1, 2) + plot_bank_1 = fig.add_subplot(gs[0], projection="mantid") + plot_bank_2 = fig.add_subplot(gs[1], projection="mantid") + + for ax, ws, bank in zip([plot_bank_1, plot_bank_2], [bank_1_ws, bank_2_ws], [1, 2]): + ax.plot(ws, wkspIndex=0, linestyle="--", marker="o", markersize="3") + ax.plot(ws, wkspIndex=1, linestyle="--", marker="o", markersize="3") + ax.set_title("Engg Gui Difc Zero Peaks Bank " + str(bank)) + ax.legend(("Peaks Fitted", "DifC/TZero Fitted Straight Line")) + ax.set_xlabel("Expected Peaks Centre(dSpacing, A)") + fig.show() @staticmethod - def load_ceria(ceria_run_no): + def load_sample(sample_run_no): try: - return Load(Filename=ceria_run_no, OutputWorkspace="engggui_calibration_sample_ws") + return Load(Filename=sample_run_no, OutputWorkspace="engggui_calibration_sample_ws") except Exception as e: logger.error("Error while loading calibration sample data. " "Could not run the algorithm Load successfully for the calibration sample " - "(run number: " + str(ceria_run_no) + "). Error description: " + str(e) + + "(run number: " + str(sample_run_no) + "). Error description: " + str(e) + " Please check also the previous log messages for details.") raise RuntimeError - def run_calibration(self, ceria_ws, van_integration, van_curves): + def run_calibration(self, sample_ws, van_integration, van_curves): """ Runs the main Engineering calibration algorithm. - :param ceria_ws: The workspace with the ceria data. + :param sample_ws: The workspace with the sample data. :param van_integration: The integration values from the vanadium corrections :param van_curves: The curves from the vanadium corrections. :return: The output of the algorithm. 
@@ -185,28 +202,28 @@ class CalibrationModel(object): output = [None] * 2 for i in range(2): table_name = self._generate_table_workspace_name(i) - output[i] = EnggCalibrate(InputWorkspace=ceria_ws, + output[i] = EnggCalibrate(InputWorkspace=sample_ws, VanIntegrationWorkspace=van_integration, VanCurvesWorkspace=van_curves, Bank=str(i + 1), FittedPeaks=table_name) return output - def create_output_files(self, calibration_dir, difc, tzero, ceria_path, vanadium_path, + def create_output_files(self, calibration_dir, difc, tzero, sample_path, vanadium_path, instrument): """ Create output files from the algorithms in the specified directory :param calibration_dir: The directory to save the files into. :param difc: DIFC values from the calibration algorithm. :param tzero: TZERO values from the calibration algorithm. - :param ceria_path: The path to the ceria data file. + :param sample_path: The path to the sample data file. :param vanadium_path: The path to the vanadium data file. :param instrument: The instrument (ENGINX or IMAT) """ if not path.exists(calibration_dir): makedirs(calibration_dir) filename = self._generate_output_file_name(vanadium_path, - ceria_path, + sample_path, instrument, bank="all") # Both Banks @@ -214,21 +231,21 @@ class CalibrationModel(object): write_ENGINX_GSAS_iparam_file(file_path, difc, tzero, - ceria_run=ceria_path, + ceria_run=sample_path, vanadium_run=vanadium_path) # North Bank file_path = calibration_dir + self._generate_output_file_name( - vanadium_path, ceria_path, instrument, bank="north") + vanadium_path, sample_path, instrument, bank="north") write_ENGINX_GSAS_iparam_file(file_path, [difc[0]], [tzero[0]], - ceria_run=ceria_path, + ceria_run=sample_path, vanadium_run=vanadium_path, template_file=NORTH_BANK_TEMPLATE_FILE, bank_names=["North"]) # South Bank file_path = calibration_dir + self._generate_output_file_name( - vanadium_path, ceria_path, instrument, bank="south") + vanadium_path, sample_path, instrument, bank="south") 
write_ENGINX_GSAS_iparam_file(file_path, [difc[1]], [tzero[1]], - ceria_run=ceria_path, + ceria_run=sample_path, vanadium_run=vanadium_path, template_file=SOUTH_BANK_TEMPLATE_FILE, bank_names=["South"]) @@ -258,27 +275,27 @@ class CalibrationModel(object): raise RuntimeError("Invalid file format.") words = run_numbers.split() - ceria_no = words[2] # Run numbers are stored as the 3rd and 4th word in this line. + sample_no = words[2] # Run numbers are stored as the 3rd and 4th word in this line. van_no = words[3] - return instrument, van_no, ceria_no, params_table + return instrument, van_no, sample_no, params_table @staticmethod def _generate_table_workspace_name(bank_num): return "engggui_calibration_bank_" + str(bank_num + 1) @staticmethod - def _generate_output_file_name(vanadium_path, ceria_path, instrument, bank): + def _generate_output_file_name(vanadium_path, sample_path, instrument, bank): """ - Generate an output filename in the form INSTRUMENT_VanadiumRunNo_CeriaRunCo_BANKS + Generate an output filename in the form INSTRUMENT_VanadiumRunNo_SampleRunNo_BANKS :param vanadium_path: Path to vanadium data file - :param ceria_path: Path to ceria data file + :param sample_path: Path to sample data file :param instrument: The instrument in use. :param bank: The bank being saved. - :return: The filename, the vanadium run number, and ceria run number. + :return: The filename, the vanadium run number, and sample run number. 
""" vanadium_no = path_handling.get_run_number_from_path(vanadium_path, instrument) - ceria_no = path_handling.get_run_number_from_path(ceria_path, instrument) - filename = instrument + "_" + vanadium_no + "_" + ceria_no + "_" + sample_no = path_handling.get_run_number_from_path(sample_path, instrument) + filename = instrument + "_" + vanadium_no + "_" + sample_no + "_" if bank == "all": filename = filename + "all_banks.prm" elif bank == "north": diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/presenter.py b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/presenter.py index eed8e9b5194d906749fd15576f192e8bdf49b275..d7b7cb19acf49c3991730b684a7dec859cfce762 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/presenter.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/presenter.py @@ -7,9 +7,10 @@ # pylint: disable=invalid-name from __future__ import (absolute_import, division, print_function) -from qtpy.QtWidgets import QMessageBox +from copy import deepcopy -from Engineering.gui.engineering_diffraction.tabs.common import INSTRUMENT_DICT +from Engineering.gui.engineering_diffraction.tabs.common import INSTRUMENT_DICT, create_error_message +from Engineering.gui.engineering_diffraction.tabs.common.calibration_info import CalibrationInfo from mantidqt.utils.asynchronous import AsyncTask from mantid.simpleapi import logger from mantidqt.utils.observer_pattern import Observable @@ -22,12 +23,13 @@ class CalibrationPresenter(object): self.worker = None self.calibration_notifier = self.CalibrationNotifier(self) - self.current_calibration = {"vanadium_path": None, "ceria_path": None, "instrument": None} - self.pending_calibration = {"vanadium_path": None, "ceria_path": None, "instrument": None} + self.current_calibration = CalibrationInfo() + self.pending_calibration = CalibrationInfo() # Connect view signals to local functions. 
self.view.set_on_calibrate_clicked(self.on_calibrate_clicked) self.view.set_enable_controls_connection(self.set_calibrate_controls_enabled) + self.view.set_update_fields_connection(self.set_field_values) self.view.set_on_radio_new_toggled(self.set_create_new_enabled) self.view.set_on_radio_existing_toggled(self.set_load_existing_enabled) @@ -37,76 +39,70 @@ class CalibrationPresenter(object): def on_calibrate_clicked(self): plot_output = self.view.get_plot_output() - if self.view.get_new_checked(): - # Do nothing if run numbers are invalid or view is searching. - if not self.validate_run_numbers(): - self._create_error_message("Check run numbers/path is valid.") - return - if self.view.is_searching(): - self._create_error_message("Mantid is searching for the file. Please wait.") - return - vanadium_no = self.view.get_vanadium_filename() - calib_no = self.view.get_calib_filename() - self.start_calibration_worker(vanadium_no, calib_no, plot_output, self.rb_num) + if self.view.get_new_checked() and self._validate(): + vanadium_file = self.view.get_vanadium_filename() + sample_file = self.view.get_sample_filename() + self.start_calibration_worker(vanadium_file, sample_file, plot_output, self.rb_num) elif self.view.get_load_checked(): if not self.validate_path(): return filename = self.view.get_path_filename() - instrument, vanadium_no, calib_no = self.model.load_existing_gsas_parameters(filename) - self.pending_calibration["vanadium_path"] = vanadium_no - self.pending_calibration["ceria_path"] = calib_no - self.pending_calibration["instrument"] = instrument + instrument, vanadium_file, sample_file = self.model.load_existing_gsas_parameters( + filename) + self.pending_calibration.set_calibration(vanadium_file, sample_file, instrument) self.set_current_calibration() - def start_calibration_worker(self, vanadium_path, calib_path, plot_output, rb_num): + def start_calibration_worker(self, vanadium_path, sample_path, plot_output, rb_num): """ Calibrate the data in a 
separate thread so as to not freeze the GUI. :param vanadium_path: Path to vanadium data file. - :param calib_path: Path to calibration data file. + :param sample_path: Path to sample data file. :param plot_output: Whether to plot the output. :param rb_num: The current RB number set in the GUI. """ - self.worker = AsyncTask(self.model.create_new_calibration, (vanadium_path, calib_path), { + self.worker = AsyncTask(self.model.create_new_calibration, (vanadium_path, sample_path), { "plot_output": plot_output, "instrument": self.instrument, "rb_num": rb_num }, error_cb=self._on_error, success_cb=self._on_success) - self.pending_calibration["vanadium_path"] = vanadium_path - self.pending_calibration["ceria_path"] = calib_path - self.pending_calibration["instrument"] = self.instrument + self.pending_calibration.set_calibration(vanadium_path, sample_path, self.instrument) self.set_calibrate_controls_enabled(False) self.worker.start() - def _create_error_message(self, message): - QMessageBox.warning(self.view, "Engineering Diffraction - Error", str(message)) - def set_current_calibration(self, success_info=None): if success_info: logger.information("Thread executed in " + str(success_info.elapsed_time) + " seconds.") - self.current_calibration = self.pending_calibration + self.current_calibration = deepcopy(self.pending_calibration) self.calibration_notifier.notify_subscribers(self.current_calibration) - self.set_field_values() - self.pending_calibration = {"vanadium_path": None, "ceria_path": None, "instrument": None} + self.emit_update_fields_signal() + self.pending_calibration.clear() def set_field_values(self): - self.view.set_calib_text(self.current_calibration["ceria_path"]) - self.view.set_vanadium_text(self.current_calibration["vanadium_path"]) + self.view.set_sample_text(self.current_calibration.get_sample()) + self.view.set_vanadium_text(self.current_calibration.get_vanadium()) def set_instrument_override(self, instrument): instrument = 
INSTRUMENT_DICT[instrument] self.view.set_instrument_override(instrument) self.instrument = instrument - def set_rb_number(self, rb_number): - self.rb_num = rb_number + def set_rb_num(self, rb_num): + self.rb_num = rb_num - def validate_run_numbers(self): - if self.view.get_calib_valid() and self.view.get_vanadium_valid(): - return True - else: + def _validate(self): + # Do nothing if run numbers are invalid or view is searching. + if self.view.is_searching(): + create_error_message(self.view, "Mantid is searching for data files. Please wait.") return False + if not self.validate_run_numbers(): + create_error_message(self.view, "Check run numbers/path is valid.") + return False + return True + + def validate_run_numbers(self): + return self.view.get_sample_valid() and self.view.get_vanadium_valid() def validate_path(self): return self.view.get_path_valid() @@ -114,6 +110,9 @@ class CalibrationPresenter(object): def emit_enable_button_signal(self): self.view.sig_enable_controls.emit(True) + def emit_update_fields_signal(self): + self.view.sig_update_fields.emit() + def set_calibrate_controls_enabled(self, enabled): self.view.set_calibrate_button_enabled(enabled) self.view.set_check_plot_output_enabled(enabled) @@ -128,7 +127,7 @@ class CalibrationPresenter(object): def set_create_new_enabled(self, enabled): self.view.set_vanadium_enabled(enabled) - self.view.set_calib_enabled(enabled) + self.view.set_sample_enabled(enabled) if enabled: self.set_calibrate_button_text("Calibrate") self.view.set_check_plot_output_enabled(True) @@ -144,7 +143,7 @@ class CalibrationPresenter(object): self.view.set_calibrate_button_text(text) def find_files(self): - self.view.find_calib_files() + self.view.find_sample_files() self.view.find_vanadium_files() # ----------------------- diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/test/test_calib_model.py b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/test/test_calib_model.py index 
51ea0044c9015f45be6135c61157a6a5ee9bb5cd..60e8d530feb89353e116cd204610f5244169e85b 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/test/test_calib_model.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/test/test_calib_model.py @@ -34,20 +34,20 @@ class CalibrationModelTest(unittest.TestCase): @patch(class_path + '.update_calibration_params_table') @patch(class_path + '.create_output_files') @patch(class_path + '.run_calibration') - @patch(class_path + '.load_ceria') + @patch(class_path + '.load_sample') @patch(file_path + '.vanadium_corrections.fetch_correction_workspaces') - def test_EnggVanadiumCorrections_algorithm_is_called(self, van, load_ceria, calib, output_files, - update_table): + def test_EnggVanadiumCorrections_algorithm_is_called(self, van, load_sample, calib, + output_files, update_table): van.return_value = ("A", "B") self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, False, "ENGINX") van.assert_called_once() @patch(class_path + '.update_calibration_params_table') @patch(class_path + '.create_output_files') - @patch(class_path + '.load_ceria') + @patch(class_path + '.load_sample') @patch(class_path + '.run_calibration') @patch(file_path + '.vanadium_corrections.fetch_correction_workspaces') - def test_fetch_vanadium_is_called(self, van_corr, calibrate_alg, load_ceria, output_files, + def test_fetch_vanadium_is_called(self, van_corr, calibrate_alg, load_sample, output_files, update_table): van_corr.return_value = ("mocked_integration", "mocked_curves") self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, False, "ENGINX") @@ -55,30 +55,34 @@ class CalibrationModelTest(unittest.TestCase): @patch(class_path + '.update_calibration_params_table') @patch(class_path + '.create_output_files') - @patch(class_path + '.load_ceria') + @patch(class_path + '.load_sample') @patch(file_path + '.vanadium_corrections.fetch_correction_workspaces') @patch(class_path + 
'._plot_vanadium_curves') - @patch(class_path + '._plot_difc_zero') + @patch(class_path + '._generate_difc_tzero_workspace') + @patch(class_path + '._plot_difc_tzero') @patch(class_path + '.run_calibration') - def test_plotting_check(self, calib, plot_difc_zero, plot_van, van, ceria, output_files, - update_table): + def test_plotting_check(self, calib, plot_difc_zero, gen_difc, plot_van, van, sample, + output_files, update_table): van.return_value = ("A", "B") self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, False, "ENGINX") plot_van.assert_not_called() plot_difc_zero.assert_not_called() + gen_difc.assert_not_called() self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, True, "ENGINX") plot_van.assert_called_once() - self.assertEqual(plot_difc_zero.call_count, 2) + self.assertEqual(gen_difc.call_count, 2) + self.assertEqual(plot_difc_zero.call_count, 1) @patch(class_path + '.update_calibration_params_table') @patch(class_path + '.create_output_files') - @patch(class_path + '.load_ceria') + @patch(class_path + '.load_sample') @patch(file_path + '.vanadium_corrections.fetch_correction_workspaces') @patch(class_path + '._plot_vanadium_curves') - @patch(class_path + '._plot_difc_zero') + @patch(class_path + '._plot_difc_tzero') @patch(class_path + '.run_calibration') def test_present_RB_number_results_in_user_output_files(self, calib, plot_difc_zero, plot_van, - van, ceria, output_files, update_table): + van, sample, output_files, + update_table): van.return_value = ("A", "B") self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, @@ -89,13 +93,13 @@ class CalibrationModelTest(unittest.TestCase): @patch(class_path + '.update_calibration_params_table') @patch(class_path + '.create_output_files') - @patch(class_path + '.load_ceria') + @patch(class_path + '.load_sample') @patch(file_path + '.vanadium_corrections.fetch_correction_workspaces') @patch(class_path + '._plot_vanadium_curves') - @patch(class_path + '._plot_difc_zero') + 
@patch(class_path + '._plot_difc_tzero') @patch(class_path + '.run_calibration') def test_absent_run_number_results_in_no_user_output_files(self, calib, plot_difc_zero, - plot_van, van, ceria, output_files, + plot_van, van, sample, output_files, update_table): van.return_value = ("A", "B") self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, False, "ENGINX") @@ -103,10 +107,10 @@ class CalibrationModelTest(unittest.TestCase): @patch(class_path + '.update_calibration_params_table') @patch(class_path + '.create_output_files') - @patch(class_path + '.load_ceria') + @patch(class_path + '.load_sample') @patch(file_path + '.vanadium_corrections.fetch_correction_workspaces') @patch(class_path + '.run_calibration') - def test_calibration_params_table_is_updated(self, calibrate_alg, vanadium_alg, load_ceria, + def test_calibration_params_table_is_updated(self, calibrate_alg, vanadium_alg, load_sample, output_files, update_table): vanadium_alg.return_value = ("A", "B") self.model.create_new_calibration(VANADIUM_NUMBER, CERIUM_NUMBER, False, "ENGINX") @@ -118,19 +122,19 @@ class CalibrationModelTest(unittest.TestCase): 'Engineering.gui.engineering_diffraction.tabs.calibration.model.write_ENGINX_GSAS_iparam_file' ) def test_create_output_files(self, write_file, make_dirs, output_name): - ceria_path = "test2/test3/ENGINX20.nxs" + sample_path = "test2/test3/ENGINX20.nxs" vanadium_path = "test4/ENGINX0010.nxs" filename = "output" output_name.return_value = filename - self.model.create_output_files("test/", [0, 0], [1, 1], ceria_path, vanadium_path, + self.model.create_output_files("test/", [0, 0], [1, 1], sample_path, vanadium_path, "ENGINX") self.assertEqual(make_dirs.call_count, 1) self.assertEqual(write_file.call_count, 3) write_file.assert_called_with("test/" + filename, [0], [1], bank_names=['South'], - ceria_run=ceria_path, + ceria_run=sample_path, template_file="template_ENGINX_241391_236516_South_bank.prm", vanadium_run=vanadium_path) @@ -138,30 +142,32 @@ 
class CalibrationModelTest(unittest.TestCase): self.assertEqual(self.model._generate_table_workspace_name(20), "engggui_calibration_bank_21") - def test_generate_output_file_name_for_valid_bank(self): - filename = self.model._generate_output_file_name( - "test/20.raw", "test/10.raw", "ENGINX", "north") + def test_generate_output_file_name_for_north_bank(self): + filename = self.model._generate_output_file_name("test/20.raw", "test/10.raw", "ENGINX", + "north") self.assertEqual(filename, "ENGINX_20_10_bank_North.prm") - filename = self.model._generate_output_file_name( - "test/20.raw", "test/10.raw", "ENGINX", "south") + def test_generate_output_file_name_for_south_bank(self): + filename = self.model._generate_output_file_name("test/20.raw", "test/10.raw", "ENGINX", + "south") self.assertEqual(filename, "ENGINX_20_10_bank_South.prm") - filename = self.model._generate_output_file_name( - "test/20.raw", "test/10.raw", "ENGINX", "all") + def test_generate_output_file_name_for_both_banks(self): + filename = self.model._generate_output_file_name("test/20.raw", "test/10.raw", "ENGINX", + "all") self.assertEqual(filename, "ENGINX_20_10_all_banks.prm") def test_generate_output_file_name_for_invalid_bank(self): self.assertRaises(ValueError, self.model._generate_output_file_name, "test/20.raw", "test/10.raw", "ENGINX", "INVALID") - def test_generate_output_file_name_for_unconventional_filename(self): - filename = self.model._generate_output_file_name( - "test/20", "test/10.raw", "ENGINX", "north") + def test_generate_output_file_name_with_no_ext_in_filename(self): + filename = self.model._generate_output_file_name("test/20", "test/10.raw", "ENGINX", + "north") self.assertEqual(filename, "ENGINX_20_10_bank_North.prm") - filename = self.model._generate_output_file_name( - "20", "test/10.raw", "ENGINX", "north") + def test_generate_output_file_name_with_no_path_in_filename(self): + filename = self.model._generate_output_file_name("20.raw", "test/10.raw", "ENGINX", "north") 
self.assertEqual(filename, "ENGINX_20_10_bank_North.prm") @patch("Engineering.gui.engineering_diffraction.tabs.calibration.model.Ads") diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/test/test_calib_presenter.py b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/test/test_calib_presenter.py index d9f137f975e969f0fea5474c56346b338bf7c2d1..3c9d9d67968e0b6230d61575d45a2243c6e5d775 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/test/test_calib_presenter.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/test/test_calib_presenter.py @@ -12,6 +12,7 @@ import unittest from mantid.py3compat.mock import patch, MagicMock from mantid.py3compat import mock from Engineering.gui.engineering_diffraction.tabs.calibration import model, view, presenter +from Engineering.gui.engineering_diffraction.tabs.common.calibration_info import CalibrationInfo tab_path = 'Engineering.gui.engineering_diffraction.tabs.calibration' @@ -25,33 +26,35 @@ class CalibrationPresenterTest(unittest.TestCase): @patch(tab_path + ".presenter.CalibrationPresenter.start_calibration_worker") def test_worker_started_with_right_params(self, worker_method): self.view.get_vanadium_filename.return_value = "307521" - self.view.get_calib_filename.return_value = "305738" + self.view.get_sample_filename.return_value = "305738" self.view.get_plot_output.return_value = True self.view.is_searching.return_value = False self.presenter.on_calibrate_clicked() worker_method.assert_called_with("307521", "305738", True, None) - @patch(tab_path + ".presenter.CalibrationPresenter._create_error_message") + @patch(tab_path + ".presenter.create_error_message") @patch(tab_path + ".presenter.CalibrationPresenter.start_calibration_worker") def test_worker_not_started_while_finder_is_searching(self, worker_method, err_msg): self.view.get_vanadium_filename.return_value = "307521" - self.view.get_calib_filename.return_value = "305738" + 
self.view.get_sample_filename.return_value = "305738" self.view.get_plot_output.return_value = True + self.view.get_load_checked.return_value = False self.view.is_searching.return_value = True self.presenter.on_calibrate_clicked() worker_method.assert_not_called() self.assertEqual(err_msg.call_count, 1) - @patch(tab_path + ".presenter.CalibrationPresenter._create_error_message") + @patch(tab_path + ".presenter.create_error_message") @patch(tab_path + ".presenter.CalibrationPresenter.validate_run_numbers") @patch(tab_path + ".presenter.CalibrationPresenter.start_calibration_worker") def test_worker_not_started_when_run_numbers_invalid(self, worker_method, validator, err_msg): self.view.get_vanadium_filename.return_value = "307521" - self.view.get_calib_filename.return_value = "305738" + self.view.get_sample_filename.return_value = "305738" self.view.get_plot_output.return_value = True self.view.is_searching.return_value = False + self.view.get_load_checked.return_value = False validator.return_value = False self.presenter.on_calibrate_clicked() @@ -81,22 +84,22 @@ class CalibrationPresenterTest(unittest.TestCase): self.assertEqual(emit.call_count, 1) def test_validation_of_run_numbers(self): - self.view.get_calib_valid.return_value = False + self.view.get_sample_valid.return_value = False self.view.get_vanadium_valid.return_value = False result = self.presenter.validate_run_numbers() self.assertFalse(result) - self.view.get_calib_valid.return_value = True + self.view.get_sample_valid.return_value = True self.view.get_vanadium_valid.return_value = False result = self.presenter.validate_run_numbers() self.assertFalse(result) - self.view.get_calib_valid.return_value = False + self.view.get_sample_valid.return_value = False self.view.get_vanadium_valid.return_value = True result = self.presenter.validate_run_numbers() self.assertFalse(result) - self.view.get_calib_valid.return_value = True + self.view.get_sample_valid.return_value = True 
self.view.get_vanadium_valid.return_value = True result = self.presenter.validate_run_numbers() self.assertTrue(result) @@ -115,58 +118,51 @@ class CalibrationPresenterTest(unittest.TestCase): self.view.set_instrument_override.assert_called_with("IMAT") self.assertEqual(self.presenter.instrument, "IMAT") - def test_set_current_calibration(self): + @patch(tab_path + ".presenter.CalibrationPresenter.emit_update_fields_signal") + def test_set_current_calibration(self, update_sig): self.presenter.calibration_notifier = MagicMock() - pending = { - "vanadium_path": "/test/set/path", - "ceria_path": "test/set/path/2", - "instrument": "TEST_INS" - } + pending = CalibrationInfo(vanadium_path="/test/set/path", + sample_path="test/set/path/2", + instrument="TEST_INS") + pendcpy = CalibrationInfo(vanadium_path="/test/set/path", + sample_path="test/set/path/2", + instrument="TEST_INS") self.presenter.pending_calibration = pending - current = { - "vanadium_path": "old/value", - "ceria_path": "old/cera", - "instrument": "ENGINX" - } - blank = { - "vanadium_path": None, - "ceria_path": None, - "instrument": None - } + current = CalibrationInfo(vanadium_path="old/value", + sample_path="old/cera", + instrument="ENGINX") + blank = CalibrationInfo() self.presenter.current_calibration = current - self.assertEqual(self.presenter.current_calibration, current) self.presenter.set_current_calibration() - self.assertEqual(self.presenter.current_calibration, pending) - self.assertEqual(self.presenter.pending_calibration, blank) + self.check_calibration_equal(self.presenter.current_calibration, pendcpy) + self.check_calibration_equal(self.presenter.pending_calibration, blank) self.assertEqual(self.presenter.calibration_notifier.notify_subscribers.call_count, 1) - self.assertEqual(self.view.set_vanadium_text.call_count, 1) - self.assertEqual(self.view.set_calib_text.call_count, 1) + self.assertEqual(update_sig.call_count, 1) + @patch(tab_path + 
".presenter.CalibrationPresenter.emit_update_fields_signal") @patch(tab_path + ".presenter.CalibrationPresenter.validate_path") - def test_calibrate_clicked_load_valid_path(self, path): + def test_calibrate_clicked_load_valid_path(self, path, update): self.presenter.calibration_notifier = MagicMock() self.view.get_new_checked.return_value = False self.view.get_load_checked.return_value = True + path.return_value = True instrument, van, cer = ("test_ins", "test_van", "test_cer") self.model.load_existing_gsas_parameters.return_value = instrument, van, cer - current = { - "vanadium_path": "old/value", - "ceria_path": "old/cera", - "instrument": "ENGINX" - } - new = { - "vanadium_path": van, - "ceria_path": cer, - "instrument": instrument - } + current = CalibrationInfo(vanadium_path="old/value", + sample_path="old/cera", + instrument="ENGINX") + new = CalibrationInfo(vanadium_path=van, sample_path=cer, instrument=instrument) self.presenter.current_calibration = current self.presenter.on_calibrate_clicked() - self.assertEqual(self.presenter.current_calibration, new) + self.assertEqual(update.call_count, 1) + self.assertEqual(self.presenter.current_calibration.get_vanadium(), new.get_vanadium()) + self.assertEqual(self.presenter.current_calibration.get_sample(), new.get_sample()) + self.assertEqual(self.presenter.current_calibration.get_instrument(), new.get_instrument()) self.assertEqual(self.presenter.calibration_notifier.notify_subscribers.call_count, 1) @patch(tab_path + ".presenter.CalibrationPresenter.validate_path") @@ -175,11 +171,9 @@ class CalibrationPresenterTest(unittest.TestCase): self.view.get_new_checked.return_value = False self.view.get_load_checked.return_value = True path.return_value = False - current = { - "vanadium_path": "old/value", - "ceria_path": "old/cera", - "instrument": "ENGINX" - } + current = CalibrationInfo(vanadium_path="old/value", + sample_path="old/cera", + instrument="ENGINX") self.presenter.current_calibration = current 
self.presenter.on_calibrate_clicked() @@ -192,22 +186,22 @@ class CalibrationPresenterTest(unittest.TestCase): self.assertEqual(self.view.set_vanadium_enabled.call_count, 1) self.view.set_vanadium_enabled.assert_called_with(True) - self.assertEqual(self.view.set_calib_enabled.call_count, 1) - self.view.set_calib_enabled.assert_called_with(True) + self.assertEqual(self.view.set_sample_enabled.call_count, 1) + self.view.set_sample_enabled.assert_called_with(True) self.view.set_calibrate_button_text.assert_called_with("Calibrate") self.view.set_check_plot_output_enabled.assert_called_with(True) - self.assertEqual(self.view.find_calib_files.call_count, 1) + self.assertEqual(self.view.find_sample_files.call_count, 1) def test_create_new_enabled_false(self): self.presenter.set_create_new_enabled(False) self.assertEqual(self.view.set_vanadium_enabled.call_count, 1) self.view.set_vanadium_enabled.assert_called_with(False) - self.assertEqual(self.view.set_calib_enabled.call_count, 1) - self.view.set_calib_enabled.assert_called_with(False) + self.assertEqual(self.view.set_sample_enabled.call_count, 1) + self.view.set_sample_enabled.assert_called_with(False) self.assertEqual(self.view.set_calibrate_button_text.call_count, 0) self.assertEqual(self.view.set_check_plot_output_enabled.call_count, 0) - self.assertEqual(self.view.find_calib_files.call_count, 0) + self.assertEqual(self.view.find_sample_files.call_count, 0) def test_load_existing_enabled_true(self): self.presenter.set_load_existing_enabled(True) @@ -228,22 +222,21 @@ class CalibrationPresenterTest(unittest.TestCase): @patch(tab_path + ".presenter.AsyncTask") def test_start_calibration_worker(self, task): instrument, van, cer = ("test_ins", "test_van", "test_cer") - old_pending = { - "vanadium_path": None, - "ceria_path": None, - "instrument": None - } + old_pending = CalibrationInfo(vanadium_path=None, sample_path=None, instrument=None) self.presenter.pending_calibration = old_pending - expected_pending = { - 
"vanadium_path": van, - "ceria_path": cer, - "instrument": instrument - } + expected_pending = CalibrationInfo(vanadium_path=van, + sample_path=cer, + instrument=instrument) self.presenter.instrument = instrument self.presenter.start_calibration_worker(van, cer, False, None) - self.assertEqual(self.presenter.pending_calibration, expected_pending) + self.check_calibration_equal(self.presenter.pending_calibration, expected_pending) + + def check_calibration_equal(self, a, b): + self.assertEqual(a.get_vanadium(), b.get_vanadium()) + self.assertEqual(a.get_sample(), b.get_sample()) + self.assertEqual(a.get_instrument(), b.get_instrument()) if __name__ == '__main__': diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/view.py b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/view.py index ae207f36e89323d8addf513b33c3bd83d28154ed..25a9ece34c6f24bc39b33dc6673b43c3b8565cc0 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/view.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/calibration/view.py @@ -15,13 +15,14 @@ Ui_calib, _ = load_ui(__file__, "calibration_tab.ui") class CalibrationView(QtWidgets.QWidget, Ui_calib): sig_enable_controls = QtCore.Signal(bool) + sig_update_fields = QtCore.Signal() def __init__(self, parent=None, instrument="ENGINX"): super(CalibrationView, self).__init__(parent) self.setupUi(self) self.setup_tabbing_order() - self.finder_calib.setLabelText("Calibration Sample #") - self.finder_calib.setInstrumentOverride(instrument) + self.finder_sample.setLabelText("Calibration Sample #") + self.finder_sample.setInstrumentOverride(instrument) self.finder_vanadium.setLabelText("Vanadium #") self.finder_vanadium.setInstrumentOverride(instrument) @@ -36,11 +37,11 @@ class CalibrationView(QtWidgets.QWidget, Ui_calib): def set_on_text_changed(self, slot): self.finder_vanadium.fileTextChanged.connect(slot) - self.finder_calib.fileTextChanged.connect(slot) + 
self.finder_sample.fileTextChanged.connect(slot) def set_on_finding_files_finished(self, slot): self.finder_vanadium.fileFindingFinished.connect(slot) - self.finder_calib.fileFindingFinished.connect(slot) + self.finder_sample.fileFindingFinished.connect(slot) def set_on_calibrate_clicked(self, slot): self.button_calibrate.clicked.connect(slot) @@ -54,6 +55,9 @@ class CalibrationView(QtWidgets.QWidget, Ui_calib): def set_enable_controls_connection(self, slot): self.sig_enable_controls.connect(slot) + def set_update_fields_connection(self, slot): + self.sig_update_fields.connect(slot) + # ================= # Component Setters # ================= @@ -66,13 +70,13 @@ class CalibrationView(QtWidgets.QWidget, Ui_calib): def set_instrument_override(self, instrument): self.finder_vanadium.setInstrumentOverride(instrument) - self.finder_calib.setInstrumentOverride(instrument) + self.finder_sample.setInstrumentOverride(instrument) def set_vanadium_enabled(self, set_to): self.finder_vanadium.setEnabled(set_to) - def set_calib_enabled(self, set_to): - self.finder_calib.setEnabled(set_to) + def set_sample_enabled(self, set_to): + self.finder_sample.setEnabled(set_to) def set_path_enabled(self, set_to): self.finder_path.setEnabled(set_to) @@ -80,8 +84,8 @@ class CalibrationView(QtWidgets.QWidget, Ui_calib): def set_vanadium_text(self, text): self.finder_vanadium.setText(text) - def set_calib_text(self, text): - self.finder_calib.setText(text) + def set_sample_text(self, text): + self.finder_sample.setText(text) def set_calibrate_button_text(self, text): self.button_calibrate.setText(text) @@ -96,17 +100,17 @@ class CalibrationView(QtWidgets.QWidget, Ui_calib): def get_vanadium_valid(self): return self.finder_vanadium.isValid() - def get_calib_filename(self): - return self.finder_calib.getFirstFilename() + def get_sample_filename(self): + return self.finder_sample.getFirstFilename() - def get_calib_valid(self): - return self.finder_calib.isValid() + def get_sample_valid(self): + 
return self.finder_sample.isValid() def get_path_filename(self): return self.finder_path.getFirstFilename() def get_path_valid(self): - return self.finder_path.isValid() + return self.finder_path.isValid() and self.finder_path.getText() def get_plot_output(self): return self.check_plotOutput.isChecked() @@ -122,14 +126,14 @@ class CalibrationView(QtWidgets.QWidget, Ui_calib): # ================= def is_searching(self): - return self.finder_calib.isSearching() or self.finder_calib.isSearching() + return self.finder_sample.isSearching() or self.finder_sample.isSearching() # ================= # Force Actions # ================= - def find_calib_files(self): - self.finder_calib.findFiles(True) + def find_sample_files(self): + self.finder_sample.findFiles(True) def find_vanadium_files(self): self.finder_vanadium.findFiles(True) @@ -140,6 +144,6 @@ class CalibrationView(QtWidgets.QWidget, Ui_calib): def setup_tabbing_order(self): self.setTabOrder(self.radio_newCalib, self.finder_vanadium) - self.setTabOrder(self.finder_vanadium, self.finder_calib) - self.setTabOrder(self.finder_calib, self.check_plotOutput) + self.setTabOrder(self.finder_vanadium, self.finder_sample) + self.setTabOrder(self.finder_sample, self.check_plotOutput) self.setTabOrder(self.check_plotOutput, self.button_calibrate) diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/common/__init__.py b/scripts/Engineering/gui/engineering_diffraction/tabs/common/__init__.py index 2dc972abc74d2c1e1d812fee5eb92e8f5213c5cd..de7c2f4de997b16bf8bc735621603badef51f4b5 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/common/__init__.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/common/__init__.py @@ -1,6 +1,11 @@ """ Holds some common constants across all tabs. """ +from qtpy.QtWidgets import QMessageBox # Dictionary of indexes for instruments. 
INSTRUMENT_DICT = {0: "ENGINX", 1: "IMAT"} + + +def create_error_message(parent, message): + QMessageBox.warning(parent, "Engineering Diffraction - Error", str(message)) diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/common/calibration_info.py b/scripts/Engineering/gui/engineering_diffraction/tabs/common/calibration_info.py new file mode 100644 index 0000000000000000000000000000000000000000..a803e2fdb487f5e0de468cd01943b673ed306567 --- /dev/null +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/common/calibration_info.py @@ -0,0 +1,46 @@ +# Mantid Repository : https://github.com/mantidproject/mantid +# +# Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, +# NScD Oak Ridge National Laboratory, European Spallation Source +# & Institut Laue - Langevin +# SPDX - License - Identifier: GPL - 3.0 + + +from __future__ import (absolute_import, division, print_function) + + +class CalibrationInfo(object): + """ + Keeps track of the parameters that went into a calibration created by the engineering diffraction GUI. + """ + def __init__(self, vanadium_path=None, sample_path=None, instrument=None): + self.vanadium_path = vanadium_path + self.sample_path = sample_path + self.instrument = instrument + + def set_calibration(self, vanadium_path, sample_path, instrument): + """ + Set the values of the calibration. requires a complete set of calibration info to be supplied. + :param vanadium_path: Path to the vanadium data file used in the calibration. + :param sample_path: Path to the sample data file used. + :param instrument: String defining the instrument the data came from. 
+ """ + self.vanadium_path = vanadium_path + self.sample_path = sample_path + self.instrument = instrument + + def get_vanadium(self): + return self.vanadium_path + + def get_sample(self): + return self.sample_path + + def get_instrument(self): + return self.instrument + + def clear(self): + self.vanadium_path = None + self.sample_path = None + self.instrument = None + + def is_valid(self): + return True if self.vanadium_path and self.sample_path and self.instrument else False diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/common/vanadium_corrections.py b/scripts/Engineering/gui/engineering_diffraction/tabs/common/vanadium_corrections.py index d0a2914e942a180c2020c3b0f96c0071e979232e..2f79c91c4990f585815bb7e6186540be7c01085a 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/common/vanadium_corrections.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/common/vanadium_corrections.py @@ -25,11 +25,12 @@ SAVED_FILE_CURVE_SUFFIX = "_precalculated_vanadium_run_bank_curves.nxs" SAVED_FILE_INTEG_SUFFIX = "_precalculated_vanadium_run_integration.nxs" -def fetch_correction_workspaces(vanadium_path, instrument): +def fetch_correction_workspaces(vanadium_path, instrument, rb_num=""): """ Fetch workspaces from the file system or create new ones. :param vanadium_path: The path to the requested vanadium run raw data. :param instrument: The instrument the data came from. + :param rb_num: A user identifier, usually an experiment number. :return: The resultant integration and curves workspaces. 
""" vanadium_number = path_handling.get_run_number_from_path(vanadium_path, instrument) @@ -38,6 +39,12 @@ def fetch_correction_workspaces(vanadium_path, instrument): try: integ_workspace = Load(Filename=integ_path, OutputWorkspace=INTEGRATED_WORKSPACE_NAME) curves_workspace = Load(Filename=curves_path, OutputWorkspace=CURVES_WORKSPACE_NAME) + if rb_num: + user_integ, user_curves = _generate_saved_workspace_file_paths(vanadium_number, + rb_num=rb_num) + if not path.exists(user_integ) and not path.exists(user_curves): + _save_correction_files(integ_workspace, user_integ, curves_workspace, + user_curves) return integ_workspace, curves_workspace except RuntimeError as e: logger.error( @@ -45,6 +52,10 @@ def fetch_correction_workspaces(vanadium_path, instrument): + str(e)) integ_workspace, curves_workspace = _calculate_vanadium_correction(vanadium_path) _save_correction_files(integ_workspace, integ_path, curves_workspace, curves_path) + if rb_num: + user_integ, user_curves = _generate_saved_workspace_file_paths(vanadium_number, + rb_num=rb_num) + _save_correction_files(integ_workspace, user_integ, curves_workspace, user_curves) return integ_workspace, curves_workspace @@ -92,15 +103,20 @@ def _save_correction_files(integration_workspace, integration_path, curves_works return -def _generate_saved_workspace_file_paths(vanadium_number): +def _generate_saved_workspace_file_paths(vanadium_number, rb_num=""): """ Generate file paths based on a vanadium run number. :param vanadium_number: The number of the vanadium run. + :param rb_num: User identifier, usually an experiment number. :return: The full path to the file. 
""" integrated_filename = vanadium_number + SAVED_FILE_INTEG_SUFFIX curves_filename = vanadium_number + SAVED_FILE_CURVE_SUFFIX - vanadium_dir = path.join(path_handling.OUT_FILES_ROOT_DIR, VANADIUM_DIRECTORY_NAME) + if rb_num: + vanadium_dir = path.join(path_handling.OUT_FILES_ROOT_DIR, "User", rb_num, + VANADIUM_DIRECTORY_NAME) + else: + vanadium_dir = path.join(path_handling.OUT_FILES_ROOT_DIR, VANADIUM_DIRECTORY_NAME) if not path.exists(vanadium_dir): makedirs(vanadium_dir) return path.join(vanadium_dir, integrated_filename), path.join(vanadium_dir, curves_filename) diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/model.py b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/model.py index f41d7f526c49caa343e28e55244d8f71298a8c47..9b67f260aba5d3b9110b2f22c71bc8c6509d1b0a 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/model.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/model.py @@ -8,24 +8,25 @@ from __future__ import (absolute_import, division, print_function) from os import path +from matplotlib import gridspec +import matplotlib.pyplot as plt from Engineering.gui.engineering_diffraction.tabs.common import vanadium_corrections, path_handling -from mantid.simpleapi import EnggFocus, Load, logger, AnalysisDataService as Ads, SaveNexus -from mantidqt.plotting.functions import plot +from mantid.simpleapi import EnggFocus, Load, logger, AnalysisDataService as Ads, SaveNexus, SaveGSS, SaveFocusedXYE SAMPLE_RUN_WORKSPACE_NAME = "engggui_focusing_input_ws" FOCUSED_OUTPUT_WORKSPACE_NAME = "engggui_focusing_output_ws_bank_" class FocusModel(object): - def focus_run(self, sample_path, banks, plot_output, instrument, rb_number): + def focus_run(self, sample_path, banks, plot_output, instrument, rb_num): """ Focus some data using the current calibration. :param sample_path: The path to the data to be focused. :param banks: The banks that should be focused. 
:param plot_output: True if the output should be plotted. :param instrument: The instrument that the data came from. - :param rb_number: The experiment number, used to create directories. Can be None + :param rb_num: The experiment number, used to create directories. Can be None """ if not Ads.doesExist(vanadium_corrections.INTEGRATED_WORKSPACE_NAME) and not Ads.doesExist( vanadium_corrections.CURVES_WORKSPACE_NAME): @@ -33,16 +34,17 @@ class FocusModel(object): integration_workspace = Ads.retrieve(vanadium_corrections.INTEGRATED_WORKSPACE_NAME) curves_workspace = Ads.retrieve(vanadium_corrections.CURVES_WORKSPACE_NAME) sample_workspace = self._load_focus_sample_run(sample_path) + output_workspaces = [] for name in banks: output_workspace_name = FOCUSED_OUTPUT_WORKSPACE_NAME + str(name) self._run_focus(sample_workspace, output_workspace_name, integration_workspace, curves_workspace, name) - # Plot the output - if plot_output: - self._plot_focused_workspace(output_workspace_name) + output_workspaces.append(output_workspace_name) # Save the output to the file system. 
- self._save_focused_output_files_as_nexus(instrument, sample_path, name, - output_workspace_name, rb_number) + self._save_output(instrument, sample_path, name, output_workspace_name, rb_num) + # Plot the output + if plot_output: + self._plot_focused_workspaces(output_workspaces) @staticmethod def _run_focus(input_workspace, output_workspace, vanadium_integration_ws, vanadium_curves_ws, @@ -70,31 +72,73 @@ class FocusModel(object): raise RuntimeError @staticmethod - def _plot_focused_workspace(focused): - focused_wsp = Ads.retrieve(focused) - plot([focused_wsp], wksp_indices=[0]) + def _plot_focused_workspaces(focused_workspaces): + fig = plt.figure() + gs = gridspec.GridSpec(1, len(focused_workspaces)) + plots = [ + fig.add_subplot(gs[i], projection="mantid") for i in range(len(focused_workspaces)) + ] - def _save_focused_output_files_as_nexus(self, instrument, sample_path, bank, sample_workspace, - rb_number): + for ax, ws_name in zip(plots, focused_workspaces): + ax.plot(Ads.retrieve(ws_name), wkspIndex=0) + ax.set_title(ws_name) + fig.show() + + def _save_output(self, instrument, sample_path, bank, sample_workspace, rb_num): """ - Save a focused workspace to the file system. Saves a separate copy to a User directory if an rb number has been - set. + Save a focused workspace to the file system. Saves separate copies to a User directory if an rb number has + been set. :param instrument: The instrument the data is from. :param sample_path: The path to the data file that was focused. :param bank: The name of the bank being saved. :param sample_workspace: The name of the workspace to be saved. - :param rb_number: Usually an experiment id, defines the name of the user directory. + :param rb_num: Usually an experiment id, defines the name of the user directory. 
""" + self._save_focused_output_files_as_nexus(instrument, sample_path, bank, sample_workspace, + rb_num) + self._save_focused_output_files_as_gss(instrument, sample_path, bank, sample_workspace, + rb_num) + self._save_focused_output_files_as_xye(instrument, sample_path, bank, sample_workspace, + rb_num) + + def _save_focused_output_files_as_gss(self, instrument, sample_path, bank, sample_workspace, + rb_num): + gss_output_path = path.join( + path_handling.OUT_FILES_ROOT_DIR, "Focus", + self._generate_output_file_name(instrument, sample_path, bank, ".gss")) + SaveGSS(InputWorkspace=sample_workspace, Filename=gss_output_path) + if rb_num is not None: + gss_output_path = path.join( + path_handling.OUT_FILES_ROOT_DIR, "User", rb_num, "Focus", + self._generate_output_file_name(instrument, sample_path, bank, ".gss")) + SaveGSS(InputWorkspace=sample_workspace, Filename=gss_output_path) + + def _save_focused_output_files_as_nexus(self, instrument, sample_path, bank, sample_workspace, + rb_num): nexus_output_path = path.join( path_handling.OUT_FILES_ROOT_DIR, "Focus", self._generate_output_file_name(instrument, sample_path, bank, ".nxs")) SaveNexus(InputWorkspace=sample_workspace, Filename=nexus_output_path) - if rb_number is not None: + if rb_num is not None: nexus_output_path = path.join( - path_handling.OUT_FILES_ROOT_DIR, "User", rb_number, "Focus", + path_handling.OUT_FILES_ROOT_DIR, "User", rb_num, "Focus", self._generate_output_file_name(instrument, sample_path, bank, ".nxs")) SaveNexus(InputWorkspace=sample_workspace, Filename=nexus_output_path) + def _save_focused_output_files_as_xye(self, instrument, sample_path, bank, sample_workspace, + rb_num): + xye_output_path = path.join( + path_handling.OUT_FILES_ROOT_DIR, "Focus", + self._generate_output_file_name(instrument, sample_path, bank, ".dat")) + SaveFocusedXYE(InputWorkspace=sample_workspace, Filename=xye_output_path, SplitFiles=False) + if rb_num is not None: + xye_output_path = path.join( + 
path_handling.OUT_FILES_ROOT_DIR, "User", rb_num, "Focus", + self._generate_output_file_name(instrument, sample_path, bank, ".dat")) + SaveFocusedXYE(InputWorkspace=sample_workspace, + Filename=xye_output_path, + SplitFiles=False) + @staticmethod def _generate_output_file_name(instrument, sample_path, bank, suffix): run_no = path_handling.get_run_number_from_path(sample_path, instrument) diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/presenter.py b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/presenter.py index 2305db608716b87bee694f3a50396344579a5142..6a6a3ae77f5ffb33e9171ca0183038b90c4655cc 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/presenter.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/presenter.py @@ -7,9 +7,8 @@ # pylint: disable=invalid-name from __future__ import (absolute_import, division, print_function) -from qtpy.QtWidgets import QMessageBox - -from Engineering.gui.engineering_diffraction.tabs.common import INSTRUMENT_DICT +from Engineering.gui.engineering_diffraction.tabs.common import INSTRUMENT_DICT, create_error_message +from Engineering.gui.engineering_diffraction.tabs.common.calibration_info import CalibrationInfo from Engineering.gui.engineering_diffraction.tabs.common.vanadium_corrections import check_workspaces_exist from mantidqt.utils.asynchronous import AsyncTask from mantidqt.utils.observer_pattern import Observer @@ -28,7 +27,7 @@ class FocusPresenter(object): self.view.set_enable_controls_connection(self.set_focus_controls_enabled) # Variables from other GUI tabs. - self.current_calibration = {"vanadium_path": None, "ceria_path": None} + self.current_calibration = CalibrationInfo() self.instrument = "ENGINX" self.rb_num = None @@ -45,7 +44,7 @@ class FocusPresenter(object): :param focus_path: The path to the file containing the data to focus. :param banks: A list of banks that are to be focused. 
:param plot_output: True if the output should be plotted. - :param rb_num: The rb_number from the main window (often an experiment id) + :param rb_num: The RB Number from the main window (often an experiment id) """ self.worker = AsyncTask(self.model.focus_run, (focus_path, banks, plot_output, self.instrument, rb_num), @@ -59,8 +58,8 @@ class FocusPresenter(object): self.view.set_instrument_override(instrument) self.instrument = instrument - def set_rb_number(self, rb_number): - self.rb_num = rb_number + def set_rb_num(self, rb_num): + self.rb_num = rb_num def _validate(self, banks): """ @@ -68,22 +67,28 @@ class FocusPresenter(object): :param banks: A list of banks to focus. :return: True if the worker can be started safely. """ + if self.view.is_searching(): + create_error_message(self.view, "Mantid is searching for data files. Please wait.") + return False if not self.view.get_focus_valid(): + create_error_message(self.view, "Check run numbers/path is valid.") return False - if self.current_calibration["vanadium_path"] is None or not check_workspaces_exist(): - self._create_error_message( - "Load a calibration from the Calibration tab before focusing.") + if not check_workspaces_exist() or not self.current_calibration.is_valid(): + create_error_message( + self.view, "Create or Load a calibration via the Calibration tab before focusing.") return False - if self.view.is_searching(): + if self.current_calibration.get_instrument() != self.instrument: + create_error_message( + self.view, + "Please make sure the selected instrument matches instrument for the current calibration.\n" + "The instrument for the current calibration is: " + + self.current_calibration.get_instrument()) return False if len(banks) == 0: - self._create_error_message("Please select at least one bank.") + create_error_message(self.view, "Please select at least one bank.") return False return True - def _create_error_message(self, message): - QMessageBox.warning(self.view, "Engineering Diffraction 
- Error", str(message)) - def _on_worker_error(self, failure_info): logger.warning(str(failure_info)) self.emit_enable_button_signal() diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/test/test_focus_model.py b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/test/test_focus_model.py index ff406bb710deb2b7aed7f1098925ab93b4ec09dd..53b8456052a5e0b8d271ec15a95631b49e75e377 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/test/test_focus_model.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/test/test_focus_model.py @@ -13,6 +13,7 @@ from os import path from mantid.py3compat.mock import patch from Engineering.gui.engineering_diffraction.tabs.focus import model from Engineering.gui.engineering_diffraction.tabs.common import path_handling +from Engineering.gui.engineering_diffraction.tabs.common.calibration_info import CalibrationInfo file_path = "Engineering.gui.engineering_diffraction.tabs.focus.model" @@ -20,10 +21,9 @@ file_path = "Engineering.gui.engineering_diffraction.tabs.focus.model" class FocusModelTest(unittest.TestCase): def setUp(self): self.model = model.FocusModel() - self.current_calibration = { - "vanadium_path": "/mocked/out/anyway", - "ceria_path": "this_is_mocked_out_too" - } + self.current_calibration = CalibrationInfo(vanadium_path="/mocked/out/anyway", + sample_path="this_is_mocked_out_too", + instrument="ENGINX") @patch(file_path + ".FocusModel._load_focus_sample_run") @patch(file_path + ".vanadium_corrections.Ads.doesExist") @@ -33,7 +33,7 @@ class FocusModelTest(unittest.TestCase): self.assertEqual(load.call_count, 0) @patch(file_path + ".Ads") - @patch(file_path + ".FocusModel._save_focused_output_files_as_nexus") + @patch(file_path + ".FocusModel._save_output") @patch(file_path + ".FocusModel._run_focus") @patch(file_path + ".FocusModel._load_focus_sample_run") def test_focus_run_for_each_bank(self, load_focus, run_focus, output, ads): @@ -44,27 +44,28 @@ class 
FocusModelTest(unittest.TestCase): self.model.focus_run("305761", banks, False, "ENGINX", "0") self.assertEqual(len(banks), run_focus.call_count) run_focus.assert_called_with("mocked_sample", - model.FOCUSED_OUTPUT_WORKSPACE_NAME + banks[-1], - "test_wsp", "test_wsp", banks[-1]) + model.FOCUSED_OUTPUT_WORKSPACE_NAME + banks[-1], "test_wsp", + "test_wsp", banks[-1]) @patch(file_path + ".Ads") - @patch(file_path + ".FocusModel._save_focused_output_files_as_nexus") - @patch(file_path + ".FocusModel._plot_focused_workspace") + @patch(file_path + ".FocusModel._save_output") + @patch(file_path + ".FocusModel._plot_focused_workspaces") @patch(file_path + ".FocusModel._run_focus") @patch(file_path + ".FocusModel._load_focus_sample_run") @patch(file_path + ".vanadium_corrections.fetch_correction_workspaces") - def test_focus_plotted_when_checked(self, fetch_van, load_focus, run_focus, plot_focus, output, ads): + def test_focus_plotted_when_checked(self, fetch_van, load_focus, run_focus, plot_focus, output, + ads): ads.doesExist.return_value = True fetch_van.return_value = ("mocked_integ", "mocked_curves") banks = ["1", "2"] load_focus.return_value = "mocked_sample" self.model.focus_run("305761", banks, True, "ENGINX", "0") - self.assertEqual(len(banks), plot_focus.call_count) + self.assertEqual(1, plot_focus.call_count) @patch(file_path + ".Ads") - @patch(file_path + ".FocusModel._save_focused_output_files_as_nexus") - @patch(file_path + ".FocusModel._plot_focused_workspace") + @patch(file_path + ".FocusModel._save_output") + @patch(file_path + ".FocusModel._plot_focused_workspaces") @patch(file_path + ".FocusModel._run_focus") @patch(file_path + ".FocusModel._load_focus_sample_run") @patch(file_path + ".vanadium_corrections.fetch_correction_workspaces") @@ -77,23 +78,31 @@ class FocusModelTest(unittest.TestCase): self.model.focus_run("305761", banks, False, "ENGINX", "0") self.assertEqual(0, plot_focus.call_count) + @patch(file_path + ".SaveFocusedXYE") + @patch(file_path 
+ ".SaveGSS") @patch(file_path + ".SaveNexus") - def test_save_output_files_nexus_with_no_RB_number(self, save): + def test_save_output_files_with_no_RB_number(self, nexus, gss, xye): mocked_workspace = "mocked-workspace" output_file = path.join(path_handling.OUT_FILES_ROOT_DIR, "Focus", "ENGINX_123_bank_North.nxs") - self.model._save_focused_output_files_as_nexus("ENGINX", "Path/To/ENGINX000123.whatever", - "North", mocked_workspace, None) - self.assertEqual(1, save.call_count) - save.assert_called_with(Filename=output_file, InputWorkspace=mocked_workspace) + self.model._save_output("ENGINX", "Path/To/ENGINX000123.whatever", "North", + mocked_workspace, None) + self.assertEqual(1, nexus.call_count) + self.assertEqual(1, gss.call_count) + self.assertEqual(1, xye.call_count) + nexus.assert_called_with(Filename=output_file, InputWorkspace=mocked_workspace) + + @patch(file_path + ".SaveFocusedXYE") + @patch(file_path + ".SaveGSS") @patch(file_path + ".SaveNexus") - def test_save_output_files_nexus_with_RB_number(self, save): - self.model._save_focused_output_files_as_nexus("ENGINX", "Path/To/ENGINX000123.whatever", - "North", "mocked-workspace", - "An Experiment Number") - self.assertEqual(2, save.call_count) + def test_save_output_files_with_RB_number(self, nexus, gss, xye): + self.model._save_output("ENGINX", "Path/To/ENGINX000123.whatever", "North", + "mocked-workspace", "An Experiment Number") + self.assertEqual(nexus.call_count, 2) + self.assertEqual(gss.call_count, 2) + self.assertEqual(xye.call_count, 2) if __name__ == '__main__': diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/test/test_focus_presenter.py b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/test/test_focus_presenter.py index cd083931c24511260f93fd3322f288c6f7230905..f6e64c7bfb5e6eb22baee97784dd3387d5d72203 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/test/test_focus_presenter.py +++ 
b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/test/test_focus_presenter.py @@ -12,6 +12,7 @@ import unittest from mantid.py3compat import mock from mantid.py3compat.mock import patch from Engineering.gui.engineering_diffraction.tabs.focus import model, view, presenter +from Engineering.gui.engineering_diffraction.tabs.common.calibration_info import CalibrationInfo tab_path = "Engineering.gui.engineering_diffraction.tabs.focus" @@ -25,10 +26,9 @@ class FocusPresenterTest(unittest.TestCase): @patch(tab_path + ".presenter.check_workspaces_exist") @patch(tab_path + ".presenter.FocusPresenter.start_focus_worker") def test_worker_started_with_correct_params(self, worker, wsp_exists): - self.presenter.current_calibration = { - "vanadium_path": "Fake/Path", - "ceria_path": "Fake/Path" - } + self.presenter.current_calibration = CalibrationInfo(vanadium_path="Fake/Path", + sample_path="Fake/Path", + instrument="ENGINX") self.view.get_focus_filename.return_value = "305738" self.view.get_north_bank.return_value = False self.view.get_south_bank.return_value = True @@ -93,48 +93,52 @@ class FocusPresenterTest(unittest.TestCase): self.assertEqual([], self.presenter._get_banks()) - def test_validate_with_invalid_focus_path(self): + @patch(tab_path + ".presenter.create_error_message") + def test_validate_with_invalid_focus_path(self, error_message): self.view.get_focus_valid.return_value = False banks = ["North", "South"] self.assertFalse(self.presenter._validate(banks)) + self.assertEqual(error_message.call_count, 1) - @patch(tab_path + ".presenter.FocusPresenter._create_error_message") + @patch(tab_path + ".presenter.create_error_message") def test_validate_with_invalid_calibration(self, create_error): - self.presenter.current_calibration = {"vanadium_path": None, "ceria_path": None} + self.presenter.current_calibration = CalibrationInfo(vanadium_path=None, + sample_path=None, + instrument=None) + self.view.is_searching.return_value = False banks = ["North", 
"South"] self.presenter._validate(banks) create_error.assert_called_with( - "Load a calibration from the Calibration tab before focusing.") + self.presenter.view, + "Create or Load a calibration via the Calibration tab before focusing.") @patch(tab_path + ".presenter.check_workspaces_exist") - @patch(tab_path + ".presenter.FocusPresenter._create_error_message") + @patch(tab_path + ".presenter.create_error_message") def test_validate_while_searching(self, create_error, wsp_check): - self.presenter.current_calibration = { - "vanadium_path": "Fake/File/Path", - "ceria_path": "Fake/Path" - } + self.presenter.current_calibration = CalibrationInfo(vanadium_path="Fake/File/Path", + sample_path="Fake/Path", + instrument="ENGINX") self.view.is_searching.return_value = True wsp_check.return_value = True banks = ["North", "South"] self.assertEqual(False, self.presenter._validate(banks)) - self.assertEqual(0, create_error.call_count) + self.assertEqual(1, create_error.call_count) @patch(tab_path + ".presenter.check_workspaces_exist") - @patch(tab_path + ".presenter.FocusPresenter._create_error_message") + @patch(tab_path + ".presenter.create_error_message") def test_validate_with_no_banks_selected(self, create_error, wsp_check): - self.presenter.current_calibration = { - "vanadium_path": "Fake/Path", - "ceria_path": "Fake/Path" - } + self.presenter.current_calibration = CalibrationInfo(vanadium_path="Fake/Path", + sample_path="Fake/Path", + instrument="ENGINX") self.view.is_searching.return_value = False banks = [] wsp_check.return_value = True self.presenter._validate(banks) - create_error.assert_called_with("Please select at least one bank.") + create_error.assert_called_with(self.presenter.view, "Please select at least one bank.") if __name__ == '__main__': diff --git a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/view.py b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/view.py index 
4cb38f864b43bfa2b52c79d200edca8d1ad6a9c6..9d6d74b8914107776306e96667d8f0043cd9a807 100644 --- a/scripts/Engineering/gui/engineering_diffraction/tabs/focus/view.py +++ b/scripts/Engineering/gui/engineering_diffraction/tabs/focus/view.py @@ -23,12 +23,20 @@ class FocusView(QtWidgets.QWidget, Ui_focus): self.finder_focus.setLabelText("Sample Run #") self.finder_focus.setInstrumentOverride(instrument) + # ================= + # Slot Connectors + # ================= + def set_on_focus_clicked(self, slot): self.button_focus.clicked.connect(slot) def set_enable_controls_connection(self, slot): self.sig_enable_controls.connect(slot) + # ================= + # Component Setters + # ================= + def set_instrument_override(self, instrument): self.finder_focus.setInstrumentOverride(instrument) @@ -38,15 +46,16 @@ class FocusView(QtWidgets.QWidget, Ui_focus): def set_plot_output_enabled(self, enabled): self.check_plotOutput.setEnabled(enabled) + # ================= + # Component Getters + # ================= + def get_focus_filename(self): return self.finder_focus.getFirstFilename() def get_focus_valid(self): return self.finder_focus.isValid() - def is_searching(self): - return self.finder_focus.isSearching() - def get_north_bank(self): return self.check_northBank.isChecked() @@ -55,3 +64,10 @@ class FocusView(QtWidgets.QWidget, Ui_focus): def get_plot_output(self): return self.check_plotOutput.isChecked() + + # ================= + # State Getters + # ================= + + def is_searching(self): + return self.finder_focus.isSearching() diff --git a/scripts/Interface/ui/sans_isis/SANSSaveOtherWindow.py b/scripts/Interface/ui/sans_isis/SANSSaveOtherWindow.py index dabd7d061ea1dfc655c2044eed5c3a6c3dca06c7..07791eb906394dc6d858df9859a2ddb272650338 100644 --- a/scripts/Interface/ui/sans_isis/SANSSaveOtherWindow.py +++ b/scripts/Interface/ui/sans_isis/SANSSaveOtherWindow.py @@ -91,9 +91,9 @@ class SANSSaveOtherDialog(QtWidgets.QDialog, Ui_SaveOtherDialog): if 
self.RKH_checkBox.isChecked(): save_types.append(SaveType.RKH) if self.nxcansas_checkBox.isChecked(): - save_types.append(SaveType.NXcanSAS) + save_types.append(SaveType.NX_CAN_SAS) if self.CanSAS_checkBox.isChecked(): - save_types.append(SaveType.CanSAS) + save_types.append(SaveType.CAN_SAS) return save_types def launch_file_browser(self, current_directory): diff --git a/scripts/Interface/ui/sans_isis/sans_data_processor_gui.py b/scripts/Interface/ui/sans_isis/sans_data_processor_gui.py index 7fd7aecba1fd1a93de29a81a2d26d217a6233bee..fac3ec6373c14b5980eb48e9cad0b02ac6744002 100644 --- a/scripts/Interface/ui/sans_isis/sans_data_processor_gui.py +++ b/scripts/Interface/ui/sans_isis/sans_data_processor_gui.py @@ -11,7 +11,6 @@ from __future__ import (absolute_import, division, print_function) from abc import ABCMeta, abstractmethod -from inspect import isclass from qtpy import PYQT4 from qtpy.QtCore import QRegExp from qtpy.QtGui import (QDoubleValidator, QIntValidator, QRegExpValidator) @@ -24,9 +23,10 @@ from mantidqt.widgets import jobtreeview, manageuserdirectories from six import with_metaclass from mantid.kernel import (Logger, UsageService, FeatureType) +from mantid.py3compat import Enum from reduction_gui.reduction.scripter import execute_script -from sans.common.enums import (BinningType, ReductionDimensionality, OutputMode, SaveType, SANSInstrument, - RangeStepType, ReductionMode, FitType) +from sans.common.enums import (BinningType, ReductionDimensionality, OutputMode, SaveType, RangeStepType, ReductionMode, + FitType, SANSInstrument) from sans.common.file_information import SANSFileInformationFactory from sans.gui_logic.gui_common import (get_reduction_mode_from_gui_selection, get_reduction_mode_strings_for_gui, @@ -47,6 +47,7 @@ if PYQT4: IN_MANTIDPLOT = False try: from pymantidplot import proxies + IN_MANTIDPLOT = True except ImportError: # We are not in MantidPlot e.g. 
testing @@ -76,9 +77,9 @@ class RunSelectorPresenterFactory(object): def _make_run_summation_settings_presenter(summation_settings_view, parent_view, instrument): if instrument != "LOQ": - binning_type = BinningType.SaveAsEventData + binning_type = BinningType.SAVE_AS_EVENT_DATA else: - binning_type = BinningType.Custom + binning_type = BinningType.CUSTOM summation_settings = SummationSettings(binning_type) summation_settings.bin_settings = DEFAULT_BIN_SETTINGS return SummationSettingsPresenter(summation_settings, @@ -217,13 +218,12 @@ class SANSDataProcessorGui(QMainWindow, self._has_monitor_5 = False # Instrument - SANSDataProcessorGui.INSTRUMENTS = ",".join([SANSInstrument.to_string(item) - for item in [SANSInstrument.SANS2D, - SANSInstrument.LOQ, - SANSInstrument.LARMOR, - SANSInstrument.ZOOM]]) + SANSDataProcessorGui.INSTRUMENTS = ",".join(item.value for item in [SANSInstrument.SANS2D, + SANSInstrument.LOQ, + SANSInstrument.LARMOR, + SANSInstrument.ZOOM]) - self.instrument = SANSInstrument.NoInstrument + self.instrument = SANSInstrument.NO_INSTRUMENT self.paste_button.setIcon(icons.get_icon("mdi.content-paste")) self.copy_button.setIcon(icons.get_icon("mdi.content-copy")) @@ -409,16 +409,16 @@ class SANSDataProcessorGui(QMainWindow, return True def _on_wavelength_step_type_changed(self): - if self.wavelength_step_type == RangeStepType.RangeLin: + if self.wavelength_step_type == RangeStepType.RANGE_LIN: self.wavelength_stacked_widget.setCurrentIndex(1) self.wavelength_step_label.setText(u'Step [\u00c5^-1]') - elif self.wavelength_step_type == RangeStepType.RangeLog: + elif self.wavelength_step_type == RangeStepType.RANGE_LOG: self.wavelength_stacked_widget.setCurrentIndex(1) self.wavelength_step_label.setText(u'Step [d\u03BB/\u03BB]') - elif self.wavelength_step_type == RangeStepType.Log: + elif self.wavelength_step_type == RangeStepType.LOG: self.wavelength_stacked_widget.setCurrentIndex(0) self.wavelength_step_label.setText(u'Step [d\u03BB/\u03BB]') - elif 
self.wavelength_step_type == RangeStepType.Lin: + elif self.wavelength_step_type == RangeStepType.LIN: self.wavelength_stacked_widget.setCurrentIndex(0) self.wavelength_step_label.setText(u'Step [\u00c5^-1]') @@ -491,22 +491,22 @@ class SANSDataProcessorGui(QMainWindow, """ Process runs """ - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Process Selected"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Process Selected"], False) self._call_settings_listeners(lambda listener: listener.on_process_selected_clicked()) def _process_all_clicked(self): """ Process All button clicked """ - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Process All"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Process All"], False) self._call_settings_listeners(lambda listener: listener.on_process_all_clicked()) def _load_clicked(self): - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Load"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Load"], False) self._call_settings_listeners(lambda listener: listener.on_load_clicked()) def _export_table_clicked(self): - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Export Table"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Export Table"], False) self._call_settings_listeners(lambda listener: listener.on_export_table_clicked()) def _processing_finished(self): @@ -543,23 +543,23 @@ class SANSDataProcessorGui(QMainWindow, def _remove_rows_requested_from_button(self): rows = self.get_selected_rows() - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Rows removed button"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Rows removed button"], False) self._call_settings_listeners(lambda listener: listener.on_rows_removed(rows)) def _copy_rows_requested(self): - 
UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Copy rows button"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Copy rows button"], False) self._call_settings_listeners(lambda listener: listener.on_copy_rows_requested()) def _erase_rows(self): - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Erase rows button"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Erase rows button"], False) self._call_settings_listeners(lambda listener: listener.on_erase_rows()) def _cut_rows(self): - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Cut rows button"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Cut rows button"], False) self._call_settings_listeners(lambda listener: listener.on_cut_rows()) def _paste_rows_requested(self): - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Paste rows button"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Paste rows button"], False) self._call_settings_listeners(lambda listener: listener.on_paste_rows_requested()) def _instrument_changed(self): @@ -596,7 +596,7 @@ class SANSDataProcessorGui(QMainWindow, def _on_save_can_clicked(self, value): self.save_can_checkBox.setChecked(value) - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Save Can Toggled"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Save Can Toggled"], False) set_setting(self.__generic_settings, self.__save_can_key, value) def _on_reduction_dimensionality_changed(self, is_1d): @@ -632,8 +632,8 @@ class SANSDataProcessorGui(QMainWindow, def set_out_default_output_mode(self): try: - default_output_mode = OutputMode.from_string(load_property(self.__generic_settings, self.__output_mode_key)) - except RuntimeError: + default_output_mode = OutputMode(load_property(self.__generic_settings, self.__output_mode_key)) 
+ except ValueError: pass else: self._check_output_mode(default_output_mode) @@ -646,11 +646,11 @@ class SANSDataProcessorGui(QMainWindow, 2. Via the presenter, from state :param value: An OutputMode (SANS enum) object """ - if value is OutputMode.PublishToADS: + if value is OutputMode.PUBLISH_TO_ADS: self.output_mode_memory_radio_button.setChecked(True) - elif value is OutputMode.SaveToFile: + elif value is OutputMode.SAVE_TO_FILE: self.output_mode_file_radio_button.setChecked(True) - elif value is OutputMode.Both: + elif value is OutputMode.BOTH: self.output_mode_both_radio_button.setChecked(True) # Notify the presenter @@ -669,7 +669,7 @@ class SANSDataProcessorGui(QMainWindow, Load the batch file """ - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Loaded Batch File"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Loaded Batch File"], False) load_file(self.batch_line_edit, "*.*", self.__generic_settings, self.__batch_file_key, self.get_batch_file_path) self._call_settings_listeners(lambda listener: listener.on_batch_file_load()) @@ -775,7 +775,11 @@ class SANSDataProcessorGui(QMainWindow, def _on_reduction_mode_selection_has_changed(self): selection = self.reduction_mode_combo_box.currentText() - is_merged = selection == ReductionMode.to_string(ReductionMode.Merged) + try: + is_merged = ReductionMode(selection) is ReductionMode.MERGED + except ValueError: + is_merged = False + self.merged_settings.setEnabled(is_merged) self._call_settings_listeners(lambda listener: listener.on_reduction_mode_selection_has_changed(selection)) @@ -864,15 +868,15 @@ class SANSDataProcessorGui(QMainWindow, def _on_transmission_fit_type_has_changed(self): # Check the sample settings - fit_type_sample_as_string = self.fit_sample_fit_type_combo_box.currentText().encode('utf-8') - fit_type_sample = FitType.from_string(fit_type_sample_as_string) - is_sample_polynomial = fit_type_sample is FitType.Polynomial + 
fit_type_sample_as_string = self.fit_sample_fit_type_combo_box.currentText() + fit_type_sample = FitType(fit_type_sample_as_string) + is_sample_polynomial = fit_type_sample is FitType.POLYNOMIAL self.fit_sample_polynomial_order_spin_box.setEnabled(is_sample_polynomial) # Check the can settings - fit_type_can_as_string = self.fit_can_fit_type_combo_box.currentText().encode('utf-8') - fit_type_can = FitType.from_string(fit_type_can_as_string) - is_can_polynomial = fit_type_can is FitType.Polynomial + fit_type_can_as_string = self.fit_can_fit_type_combo_box.currentText() + fit_type_can = FitType(fit_type_can_as_string) + is_can_polynomial = fit_type_can is FitType.POLYNOMIAL self.fit_can_polynomial_order_spin_box.setEnabled(is_can_polynomial) def _on_transmission_target_has_changed(self): @@ -932,12 +936,12 @@ class SANSDataProcessorGui(QMainWindow, self._call_settings_listeners(lambda listener: listener.on_mask_file_add()) def _on_multi_period_selection(self): - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Multiple Period Toggled"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Multiple Period Toggled"], False) self._call_settings_listeners( lambda listener: listener.on_multi_period_selection(self.is_multi_period_view())) def _on_sample_geometry_selection(self): - UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS","Sample Geometry Toggled"], False) + UsageService.registerFeatureUsage(FeatureType.Feature, ["ISIS SANS", "Sample Geometry Toggled"], False) self._call_settings_listeners(lambda listener: listener.on_sample_geometry_selection(self.is_sample_geometry())) def _on_manage_directories(self): @@ -960,11 +964,9 @@ class SANSDataProcessorGui(QMainWindow, if isinstance(value, list): gui_element.clear() for element in value: - self._add_list_element_to_combo_box(gui_element=gui_element, element=element, - expected_type=expected_type) - elif expected_type.has_member(value): - 
self._set_enum_as_element_in_combo_box(gui_element=gui_element, element=value, - expected_type=expected_type) + self._add_list_element_to_combo_box(gui_element=gui_element, element=element) + elif isinstance(value, Enum): + self._set_enum_as_element_in_combo_box(gui_element=gui_element, element=value) elif isinstance(value, str): index = gui_element.findText(value) if index != -1: @@ -972,22 +974,21 @@ class SANSDataProcessorGui(QMainWindow, else: raise RuntimeError("Expected an input of type {}, but got {}".format(expected_type, type(value))) - def _add_list_element_to_combo_box(self, gui_element, element, expected_type=None): - if expected_type is not None and isclass(element) and issubclass(element, expected_type): - self._add_enum_as_element_in_combo_box(gui_element=gui_element, element=element, - expected_type=expected_type) + def _add_list_element_to_combo_box(self, gui_element, element): + if isinstance(element, Enum): + self._add_enum_as_element_in_combo_box(gui_element=gui_element, element=element) else: gui_element.addItem(element) @staticmethod - def _set_enum_as_element_in_combo_box(gui_element, element, expected_type): - value_as_string = expected_type.to_string(element) + def _set_enum_as_element_in_combo_box(gui_element, element): + value_as_string = element.value index = gui_element.findText(value_as_string) if index != -1: gui_element.setCurrentIndex(index) - def _add_enum_as_element_in_combo_box(self, gui_element, element, expected_type): - value_as_string = expected_type.to_string(element) + def _add_enum_as_element_in_combo_box(self, gui_element, element): + value_as_string = element.value gui_element.addItem(value_as_string) def get_simple_line_edit_field(self, expected_type, line_edit): @@ -1031,22 +1032,22 @@ class SANSDataProcessorGui(QMainWindow, def save_types(self): checked_save_types = [] if self.can_sas_checkbox.isChecked(): - checked_save_types.append(SaveType.CanSAS) + checked_save_types.append(SaveType.CAN_SAS) if 
self.nx_can_sas_checkbox.isChecked(): - checked_save_types.append(SaveType.NXcanSAS) + checked_save_types.append(SaveType.NX_CAN_SAS) if self.rkh_checkbox.isChecked(): checked_save_types.append(SaveType.RKH) # If empty, provide the NoType type if not checked_save_types: - checked_save_types = [SaveType.NoType] + checked_save_types = [SaveType.NO_TYPE] return checked_save_types @save_types.setter def save_types(self, values): for value in values: - if value is SaveType.CanSAS: + if value is SaveType.CAN_SAS: self.can_sas_checkbox.setChecked(True) - elif value is SaveType.NXcanSAS: + elif value is SaveType.NX_CAN_SAS: self.nx_can_sas_checkbox.setChecked(True) elif value is SaveType.RKH: self.rkh_checkbox.setChecked(True) @@ -1113,21 +1114,21 @@ class SANSDataProcessorGui(QMainWindow, @property def output_mode(self): if self.output_mode_memory_radio_button.isChecked(): - return OutputMode.PublishToADS + return OutputMode.PUBLISH_TO_ADS elif self.output_mode_file_radio_button.isChecked(): - return OutputMode.SaveToFile + return OutputMode.SAVE_TO_FILE elif self.output_mode_both_radio_button.isChecked(): - return OutputMode.Both + return OutputMode.BOTH else: self.gui_logger.warning( "The output format was not specified. 
Defaulting to saving to memory only.") - return OutputMode.PublishToADS + return OutputMode.PUBLISH_TO_ADS @output_mode.setter def output_mode(self, value): self._check_output_mode(value) try: - set_setting(self.__generic_settings, self.__output_mode_key, OutputMode.to_string(value)) + set_setting(self.__generic_settings, self.__output_mode_key, value.value) except RuntimeError: pass @@ -1153,7 +1154,8 @@ class SANSDataProcessorGui(QMainWindow, @instrument.setter def instrument(self, value): - instrument_string = SANSInstrument.to_string(value) + assert (isinstance(value, Enum)) + instrument_string = value.value self.instrument_type.setText("{}".format(instrument_string)) self._instrument_changed() @@ -1168,12 +1170,12 @@ class SANSDataProcessorGui(QMainWindow, # ------------------------------------------------------------------------------------------------------------------ @property def reduction_dimensionality(self): - return ReductionDimensionality.OneDim if self.reduction_dimensionality_1D.isChecked() \ - else ReductionDimensionality.TwoDim + return ReductionDimensionality.ONE_DIM if self.reduction_dimensionality_1D.isChecked() \ + else ReductionDimensionality.TWO_DIM @reduction_dimensionality.setter def reduction_dimensionality(self, value): - is_1d = value is ReductionDimensionality.OneDim + is_1d = value is ReductionDimensionality.ONE_DIM self.reduction_dimensionality_1D.setChecked(is_1d) self.reduction_dimensionality_2D.setChecked(not is_1d) @@ -1319,8 +1321,8 @@ class SANSDataProcessorGui(QMainWindow, # ------------------------------------------------------------------------------------------------------------------ @property def wavelength_step_type(self): - step_type_as_string = self.wavelength_step_type_combo_box.currentText().encode('utf-8') - return RangeStepType.from_string(step_type_as_string) + step_type_as_string = self.wavelength_step_type_combo_box.currentText() + return RangeStepType(step_type_as_string) @wavelength_step_type.setter def 
wavelength_step_type(self, value): @@ -1505,8 +1507,8 @@ class SANSDataProcessorGui(QMainWindow, @property def transmission_sample_fit_type(self): - fit_type_as_string = self.fit_sample_fit_type_combo_box.currentText().encode('utf-8') - return FitType.from_string(fit_type_as_string) + fit_type_as_string = self.fit_sample_fit_type_combo_box.currentText() + return FitType(fit_type_as_string) @transmission_sample_fit_type.setter def transmission_sample_fit_type(self, value): @@ -1518,8 +1520,8 @@ class SANSDataProcessorGui(QMainWindow, @property def transmission_can_fit_type(self): - fit_type_as_string = self.fit_can_fit_type_combo_box.currentText().encode('utf-8') - return FitType.from_string(fit_type_as_string) + fit_type_as_string = self.fit_can_fit_type_combo_box.currentText() + return FitType(fit_type_as_string) @transmission_can_fit_type.setter def transmission_can_fit_type(self, value): @@ -1671,7 +1673,7 @@ class SANSDataProcessorGui(QMainWindow, try: value = float(value_as_string) except ValueError: - value = value_as_string.encode('utf-8') + value = value_as_string return value @q_1d_min_or_rebin_string.setter @@ -1696,10 +1698,10 @@ class SANSDataProcessorGui(QMainWindow, @property def q_1d_step_type(self): - q_1d_step_type_as_string = self.q_1d_step_type_combo_box.currentText().encode('utf-8') + q_1d_step_type_as_string = self.q_1d_step_type_combo_box.currentText() # Hedge for trying to read out try: - return RangeStepType.from_string(q_1d_step_type_as_string) + return RangeStepType(q_1d_step_type_as_string) except RuntimeError: return None @@ -1717,8 +1719,7 @@ class SANSDataProcessorGui(QMainWindow, gui_element.clear() value.append(self.VARIABLE) for element in value: - self._add_list_element_to_combo_box(gui_element=gui_element, element=element, - expected_type=RangeStepType) + self._add_list_element_to_combo_box(gui_element=gui_element, element=element) @property def q_xy_max(self): @@ -1738,9 +1739,9 @@ class SANSDataProcessorGui(QMainWindow, 
@property def q_xy_step_type(self): - q_xy_step_type_as_string = self.q_xy_step_type_combo_box.currentText().encode('utf-8') + q_xy_step_type_as_string = self.q_xy_step_type_combo_box.currentText() try: - return RangeStepType.from_string(q_xy_step_type_as_string) + return RangeStepType(q_xy_step_type_as_string) except RuntimeError: return None @@ -1757,8 +1758,7 @@ class SANSDataProcessorGui(QMainWindow, gui_element = self.q_xy_step_type_combo_box gui_element.clear() for element in value: - self._add_list_element_to_combo_box(gui_element=gui_element, element=element, - expected_type=RangeStepType) + self._add_list_element_to_combo_box(gui_element=gui_element, element=element) # ------------------------------------------------------------------------------------------------------------------ # Gravity diff --git a/scripts/Interface/ui/sans_isis/settings_diagnostic_tab.py b/scripts/Interface/ui/sans_isis/settings_diagnostic_tab.py index 32a2880d0e879773e143df4e0cdc0ffabba70f89..0a88c566e8c6ba9126618c054c0d63b0d357b771 100644 --- a/scripts/Interface/ui/sans_isis/settings_diagnostic_tab.py +++ b/scripts/Interface/ui/sans_isis/settings_diagnostic_tab.py @@ -23,6 +23,7 @@ from mantidqt.utils.qt import load_ui from mantid import UsageService from mantid.kernel import FeatureType from sans.gui_logic.gui_common import (GENERIC_SETTINGS, JSON_SUFFIX, load_file) +from sans.state.state_base import ENUM_TYPE_TAG if PY3: unicode = str @@ -65,7 +66,6 @@ class SettingsDiagnosticTab(QtWidgets.QWidget, Ui_SettingsDiagnosticTab): # Excluded settings entries self.excluded = ["state_module", "state_name"] - self.class_type_id = "ClassTypeParameter" # Q Settings self.__generic_settings = GENERIC_SETTINGS @@ -155,7 +155,8 @@ class SettingsDiagnosticTab(QtWidgets.QWidget, Ui_SettingsDiagnosticTab): item.addChild(child) def clean_class_type(self, value): - if isinstance(value, str) and self.class_type_id in value: + # TODO the UI should not be doing logic like this + if isinstance(value, 
str) and ENUM_TYPE_TAG in value: # Only the last element is of interest split_values = value.split("#") return split_values[-1] diff --git a/scripts/Interface/ui/sans_isis/summation_settings_widget.py b/scripts/Interface/ui/sans_isis/summation_settings_widget.py index 573c5f92eba45e47e88c9f6fa32655c7f95615bd..35a174614f3db9fa2e1fa1469980fc910a68d8b3 100644 --- a/scripts/Interface/ui/sans_isis/summation_settings_widget.py +++ b/scripts/Interface/ui/sans_isis/summation_settings_widget.py @@ -53,11 +53,11 @@ class SummationSettingsWidget(QtWidgets.QWidget, Ui_SummationSettingsWidget): @staticmethod def _binning_type_to_index(bin_type): - if bin_type == BinningType.Custom: + if bin_type == BinningType.CUSTOM: return 0 - elif bin_type == BinningType.FromMonitors: + elif bin_type == BinningType.FROM_MONITORS: return 1 - elif bin_type == BinningType.SaveAsEventData: + elif bin_type == BinningType.SAVE_AS_EVENT_DATA: return 2 def _handle_binning_type_changed(self, index): diff --git a/scripts/SANS/sans/algorithm_detail/CreateSANSWavelengthPixelAdjustment.py b/scripts/SANS/sans/algorithm_detail/CreateSANSWavelengthPixelAdjustment.py index e39b762b295a64357aff6ebc3fd290da90b9e0df..98ef13f1568f1dcf31e202396494e9a00a032530 100644 --- a/scripts/SANS/sans/algorithm_detail/CreateSANSWavelengthPixelAdjustment.py +++ b/scripts/SANS/sans/algorithm_detail/CreateSANSWavelengthPixelAdjustment.py @@ -160,7 +160,7 @@ class CreateSANSWavelengthPixelAdjustment(object): # Crop to the required detector crop_name = "CropToComponent" - component_to_crop = DetectorType.from_string(component) + component_to_crop = DetectorType(component) component_to_crop = get_component_name(output_workspace, component_to_crop) crop_options = {"InputWorkspace": output_workspace, "OutputWorkspace": EMPTY_NAME, @@ -179,7 +179,7 @@ class CreateSANSWavelengthPixelAdjustment(object): wavelength_high = wavelength_and_pixel_adjustment_state.wavelength_high[0] wavelength_step = 
wavelength_and_pixel_adjustment_state.wavelength_step wavelength_step_type = -1.0 if wavelength_and_pixel_adjustment_state.wavelength_step_type \ - is RangeStepType.Log else 1.0 # noqa + is RangeStepType.LOG else 1.0 # noqa # Create a rebin string from the wavelength information wavelength_step *= wavelength_step_type diff --git a/scripts/SANS/sans/algorithm_detail/batch_execution.py b/scripts/SANS/sans/algorithm_detail/batch_execution.py index 2647e94808d8952bfa2fe260ec07f2f0f06a7d70..12fd96e3dfc1cb9b7897ae9bafc41b4033fb2376 100644 --- a/scripts/SANS/sans/algorithm_detail/batch_execution.py +++ b/scripts/SANS/sans/algorithm_detail/batch_execution.py @@ -12,7 +12,7 @@ from mantid.api import AnalysisDataService, WorkspaceGroup from sans.common.general_functions import (add_to_sample_log, create_managed_non_child_algorithm, create_unmanaged_algorithm, get_output_name, get_base_name_from_multi_period_name, get_transmission_output_name) -from sans.common.enums import (SANSDataType, SaveType, OutputMode, ISISReductionMode, DataType) +from sans.common.enums import (SANSDataType, SaveType, OutputMode, ReductionMode, DataType) from sans.common.constants import (TRANS_SUFFIX, SANS_SUFFIX, ALL_PERIODS, LAB_CAN_SUFFIX, LAB_CAN_COUNT_SUFFIX, LAB_CAN_NORM_SUFFIX, HAB_CAN_SUFFIX, HAB_CAN_COUNT_SUFFIX, HAB_CAN_NORM_SUFFIX, @@ -75,15 +75,15 @@ def single_reduction_for_batch(state, use_optimizations, output_mode, plot_resul # ------------------------------------------------------------------------------------------------------------------ # Load the data # ------------------------------------------------------------------------------------------------------------------ - workspace_to_name = {SANSDataType.SampleScatter: "SampleScatterWorkspace", - SANSDataType.SampleTransmission: "SampleTransmissionWorkspace", - SANSDataType.SampleDirect: "SampleDirectWorkspace", - SANSDataType.CanScatter: "CanScatterWorkspace", - SANSDataType.CanTransmission: "CanTransmissionWorkspace", - 
SANSDataType.CanDirect: "CanDirectWorkspace"} + workspace_to_name = {SANSDataType.SAMPLE_SCATTER: "SampleScatterWorkspace", + SANSDataType.SAMPLE_TRANSMISSION: "SampleTransmissionWorkspace", + SANSDataType.SAMPLE_DIRECT: "SampleDirectWorkspace", + SANSDataType.CAN_SCATTER: "CanScatterWorkspace", + SANSDataType.CAN_TRANSMISSION: "CanTransmissionWorkspace", + SANSDataType.CAN_DIRECT: "CanDirectWorkspace"} - workspace_to_monitor = {SANSDataType.SampleScatter: "SampleScatterMonitorWorkspace", - SANSDataType.CanScatter: "CanScatterMonitorWorkspace"} + workspace_to_monitor = {SANSDataType.SAMPLE_SCATTER: "SampleScatterMonitorWorkspace", + SANSDataType.CAN_SCATTER: "CanScatterMonitorWorkspace"} workspaces, monitors = provide_loaded_data(state, use_optimizations, workspace_to_name, workspace_to_monitor) @@ -191,10 +191,10 @@ def single_reduction_for_batch(state, use_optimizations, output_mode, plot_resul # 3. Both: # * This means that we need to save out the reduced data # * The data is already on the ADS, so do nothing - if output_mode is OutputMode.SaveToFile: + if output_mode is OutputMode.SAVE_TO_FILE: save_to_file(reduction_packages, save_can, additional_run_numbers, event_slice_optimisation=event_slice_optimisation) delete_reduced_workspaces(reduction_packages) - elif output_mode is OutputMode.Both: + elif output_mode is OutputMode.BOTH: save_to_file(reduction_packages, save_can, additional_run_numbers, event_slice_optimisation=event_slice_optimisation) # ----------------------------------------------------------------------- @@ -213,15 +213,15 @@ def single_reduction_for_batch(state, use_optimizations, output_mode, plot_resul def load_workspaces_from_states(state): - workspace_to_name = {SANSDataType.SampleScatter: "SampleScatterWorkspace", - SANSDataType.SampleTransmission: "SampleTransmissionWorkspace", - SANSDataType.SampleDirect: "SampleDirectWorkspace", - SANSDataType.CanScatter: "CanScatterWorkspace", - SANSDataType.CanTransmission: "CanTransmissionWorkspace", 
- SANSDataType.CanDirect: "CanDirectWorkspace"} + workspace_to_name = {SANSDataType.SAMPLE_SCATTER: "SampleScatterWorkspace", + SANSDataType.SAMPLE_TRANSMISSION: "SampleTransmissionWorkspace", + SANSDataType.SAMPLE_DIRECT: "SampleDirectWorkspace", + SANSDataType.CAN_SCATTER: "CanScatterWorkspace", + SANSDataType.CAN_TRANSMISSION: "CanTransmissionWorkspace", + SANSDataType.CAN_DIRECT: "CanDirectWorkspace"} - workspace_to_monitor = {SANSDataType.SampleScatter: "SampleScatterMonitorWorkspace", - SANSDataType.CanScatter: "CanScatterMonitorWorkspace"} + workspace_to_monitor = {SANSDataType.SAMPLE_SCATTER: "SampleScatterMonitorWorkspace", + SANSDataType.CAN_SCATTER: "CanScatterMonitorWorkspace"} workspaces, monitors = provide_loaded_data(state, True, workspace_to_name, workspace_to_monitor) @@ -254,19 +254,19 @@ def plot_workspace_mantidplot(reduction_package, output_graph, plotting_module): """ plotSpectrum, graph = plotting_module.plotSpectrum, plotting_module.graph - if reduction_package.reduction_mode == ISISReductionMode.All: + if reduction_package.reduction_mode == ReductionMode.ALL: graph_handle = plotSpectrum([reduction_package.reduced_hab, reduction_package.reduced_lab], 0, window=graph(output_graph), clearWindow=True) graph_handle.activeLayer().logLogAxes() - elif reduction_package.reduction_mode == ISISReductionMode.HAB: + elif reduction_package.reduction_mode == ReductionMode.HAB: graph_handle = plotSpectrum(reduction_package.reduced_hab, 0, window=graph(output_graph), clearWindow=True) graph_handle.activeLayer().logLogAxes() - elif reduction_package.reduction_mode == ISISReductionMode.LAB: + elif reduction_package.reduction_mode == ReductionMode.LAB: graph_handle = plotSpectrum(reduction_package.reduced_lab, 0, window=graph(output_graph), clearWindow=True) graph_handle.activeLayer().logLogAxes() - elif reduction_package.reduction_mode == ISISReductionMode.Merged: + elif reduction_package.reduction_mode == ReductionMode.MERGED: graph_handle = 
plotSpectrum([reduction_package.reduced_merged, - reduction_package.reduced_hab, reduction_package.reduced_lab], 0, + reduction_package.reduced_hab, reduction_package.reduced_lab], 0, window=graph(output_graph), clearWindow=True) graph_handle.activeLayer().logLogAxes() @@ -283,16 +283,16 @@ def plot_workspace_mantidqt(reduction_package, output_graph, plotting_module): plot_kwargs = {"scalex": True, "scaley": True} - if reduction_package.reduction_mode == ISISReductionMode.All: + if reduction_package.reduction_mode == ReductionMode.ALL: plot([reduction_package.reduced_hab, reduction_package.reduced_lab], wksp_indices=[0], overplot=True, fig=output_graph, plot_kwargs=plot_kwargs) - elif reduction_package.reduction_mode == ISISReductionMode.HAB: + elif reduction_package.reduction_mode == ReductionMode.HAB: plot([reduction_package.reduced_hab], wksp_indices=[0], overplot=True, fig=output_graph, plot_kwargs=plot_kwargs) - elif reduction_package.reduction_mode == ISISReductionMode.LAB: + elif reduction_package.reduction_mode == ReductionMode.LAB: plot([reduction_package.reduced_lab], wksp_indices=[0], overplot=True, fig=output_graph, plot_kwargs=plot_kwargs) - elif reduction_package.reduction_mode == ISISReductionMode.Merged: + elif reduction_package.reduction_mode == ReductionMode.MERGED: plot([reduction_package.reduced_merged, reduction_package.reduced_hab, reduction_package.reduced_lab], wksp_indices=[0], overplot=True, fig=output_graph, plot_kwargs=plot_kwargs) @@ -450,18 +450,19 @@ def provide_loaded_data(state, use_optimizations, workspace_to_name, workspace_t "UseCached": use_optimizations} # Set the output workspaces + set_output_workspaces_on_load_algorithm(load_options, state) load_alg = create_managed_non_child_algorithm(load_name, **load_options) load_alg.execute() # Retrieve the data - workspace_to_count = {SANSDataType.SampleScatter: "NumberOfSampleScatterWorkspaces", - SANSDataType.SampleTransmission: "NumberOfSampleTransmissionWorkspaces", - 
SANSDataType.SampleDirect: "NumberOfSampleDirectWorkspaces", - SANSDataType.CanScatter: "NumberOfCanScatterWorkspaces", - SANSDataType.CanTransmission: "NumberOfCanTransmissionWorkspaces", - SANSDataType.CanDirect: "NumberOfCanDirectWorkspaces"} + workspace_to_count = {SANSDataType.SAMPLE_SCATTER: "NumberOfSampleScatterWorkspaces", + SANSDataType.SAMPLE_TRANSMISSION: "NumberOfSampleTransmissionWorkspaces", + SANSDataType.SAMPLE_DIRECT: "NumberOfSampleDirectWorkspaces", + SANSDataType.CAN_SCATTER: "NumberOfCanScatterWorkspaces", + SANSDataType.CAN_TRANSMISSION: "NumberOfCanTransmissionWorkspaces", + SANSDataType.CAN_DIRECT: "NumberOfCanDirectWorkspaces"} workspaces = get_workspaces_from_load_algorithm(load_alg, workspace_to_count, workspace_to_name) monitors = get_workspaces_from_load_algorithm(load_alg, workspace_to_count, workspace_to_monitor) @@ -716,12 +717,12 @@ def create_initial_reduction_packages(state, workspaces, monitors): data_info = state.data sample_scatter_period = data_info.sample_scatter_period - requires_new_period_selection = len(workspaces[SANSDataType.SampleScatter]) > 1 \ + requires_new_period_selection = len(workspaces[SANSDataType.SAMPLE_SCATTER]) > 1 \ and sample_scatter_period == ALL_PERIODS # noqa - is_multi_period = len(workspaces[SANSDataType.SampleScatter]) > 1 + is_multi_period = len(workspaces[SANSDataType.SAMPLE_SCATTER]) > 1 - for index in range(0, len(workspaces[SANSDataType.SampleScatter])): + for index in range(0, len(workspaces[SANSDataType.SAMPLE_SCATTER])): workspaces_for_package = {} # For each workspace type, i.e sample scatter, can transmission, etc. 
find the correct workspace for workspace_type, workspace_list in list(workspaces.items()): @@ -805,41 +806,41 @@ def set_properties_for_reduction_algorithm(reduction_alg, reduction_package, wor setattr(reduction_package, package_attribute_name_base, workspace_name_base) def _set_lab(_reduction_alg, _reduction_package, _is_group): - _set_output_name(_reduction_alg, _reduction_package, _is_group, ISISReductionMode.LAB, + _set_output_name(_reduction_alg, _reduction_package, _is_group, ReductionMode.LAB, "OutputWorkspaceLABCan", "reduced_lab_can_name", "reduced_lab_can_base_name", multi_reduction_type, LAB_CAN_SUFFIX) # Lab Can Count workspace - this is a partial workspace - _set_output_name(_reduction_alg, _reduction_package, _is_group, ISISReductionMode.LAB, + _set_output_name(_reduction_alg, _reduction_package, _is_group, ReductionMode.LAB, "OutputWorkspaceLABCanCount", "reduced_lab_can_count_name", "reduced_lab_can_count_base_name", multi_reduction_type, LAB_CAN_COUNT_SUFFIX) # Lab Can Norm workspace - this is a partial workspace - _set_output_name(_reduction_alg, _reduction_package, _is_group, ISISReductionMode.LAB, + _set_output_name(_reduction_alg, _reduction_package, _is_group, ReductionMode.LAB, "OutputWorkspaceLABCanNorm", "reduced_lab_can_norm_name", "reduced_lab_can_norm_base_name", multi_reduction_type, LAB_CAN_NORM_SUFFIX) - _set_output_name(_reduction_alg, _reduction_package, _is_group, ISISReductionMode.LAB, + _set_output_name(_reduction_alg, _reduction_package, _is_group, ReductionMode.LAB, "OutputWorkspaceLABSample", "reduced_lab_sample_name", "reduced_lab_sample_base_name", multi_reduction_type, LAB_SAMPLE_SUFFIX) def _set_hab(_reduction_alg, _reduction_package, _is_group): # Hab Can Workspace - _set_output_name(_reduction_alg, _reduction_package, _is_group, ISISReductionMode.HAB, + _set_output_name(_reduction_alg, _reduction_package, _is_group, ReductionMode.HAB, "OutputWorkspaceHABCan", "reduced_hab_can_name", "reduced_hab_can_base_name", 
multi_reduction_type, HAB_CAN_SUFFIX) # Hab Can Count workspace - this is a partial workspace - _set_output_name(_reduction_alg, _reduction_package, _is_group, ISISReductionMode.HAB, + _set_output_name(_reduction_alg, _reduction_package, _is_group, ReductionMode.HAB, "OutputWorkspaceHABCanCount", "reduced_hab_can_count_name", "reduced_hab_can_count_base_name", multi_reduction_type, HAB_CAN_COUNT_SUFFIX) # Hab Can Norm workspace - this is a partial workspace - _set_output_name(_reduction_alg, _reduction_package, _is_group, ISISReductionMode.HAB, + _set_output_name(_reduction_alg, _reduction_package, _is_group, ReductionMode.HAB, "OutputWorkspaceHABCanNorm", "reduced_hab_can_norm_name", "reduced_hab_can_norm_base_name", multi_reduction_type, HAB_CAN_NORM_SUFFIX) - _set_output_name(_reduction_alg, _reduction_package, _is_group, ISISReductionMode.HAB, + _set_output_name(_reduction_alg, _reduction_package, _is_group, ReductionMode.HAB, "OutputWorkspaceHABSample", "reduced_hab_sample_name", "reduced_hab_sample_base_name", multi_reduction_type, HAB_SAMPLE_SUFFIX) @@ -883,24 +884,24 @@ def set_properties_for_reduction_algorithm(reduction_alg, reduction_package, wor reduction_alg.setProperty("WavelengthRange", is_part_of_wavelength_range_reduction) reduction_mode = reduction_package.reduction_mode - if reduction_mode is ISISReductionMode.Merged: - _set_output_name(reduction_alg, reduction_package, is_group, ISISReductionMode.Merged, + if reduction_mode is ReductionMode.MERGED: + _set_output_name(reduction_alg, reduction_package, is_group, ReductionMode.MERGED, "OutputWorkspaceMerged", "reduced_merged_name", "reduced_merged_base_name", multi_reduction_type) - _set_output_name(reduction_alg, reduction_package, is_group, ISISReductionMode.LAB, + _set_output_name(reduction_alg, reduction_package, is_group, ReductionMode.LAB, "OutputWorkspaceLAB", "reduced_lab_name", "reduced_lab_base_name", multi_reduction_type) - _set_output_name(reduction_alg, reduction_package, is_group, 
ISISReductionMode.HAB, + _set_output_name(reduction_alg, reduction_package, is_group, ReductionMode.HAB, "OutputWorkspaceHAB", "reduced_hab_name", "reduced_hab_base_name", multi_reduction_type) - elif reduction_mode is ISISReductionMode.LAB: - _set_output_name(reduction_alg, reduction_package, is_group, ISISReductionMode.LAB, + elif reduction_mode is ReductionMode.LAB: + _set_output_name(reduction_alg, reduction_package, is_group, ReductionMode.LAB, "OutputWorkspaceLAB", "reduced_lab_name", "reduced_lab_base_name", multi_reduction_type) - elif reduction_mode is ISISReductionMode.HAB: - _set_output_name(reduction_alg, reduction_package, is_group, ISISReductionMode.HAB, + elif reduction_mode is ReductionMode.HAB: + _set_output_name(reduction_alg, reduction_package, is_group, ReductionMode.HAB, "OutputWorkspaceHAB", "reduced_hab_name", "reduced_hab_base_name", multi_reduction_type) - elif reduction_mode is ISISReductionMode.All: - _set_output_name(reduction_alg, reduction_package, is_group, ISISReductionMode.LAB, + elif reduction_mode is ReductionMode.ALL: + _set_output_name(reduction_alg, reduction_package, is_group, ReductionMode.LAB, "OutputWorkspaceLAB", "reduced_lab_name", "reduced_lab_base_name", multi_reduction_type) - _set_output_name(reduction_alg, reduction_package, is_group, ISISReductionMode.HAB, + _set_output_name(reduction_alg, reduction_package, is_group, ReductionMode.HAB, "OutputWorkspaceHAB", "reduced_hab_name", "reduced_hab_base_name", multi_reduction_type) else: raise RuntimeError("The reduction mode {0} is not known".format(reduction_mode)) @@ -911,14 +912,14 @@ def set_properties_for_reduction_algorithm(reduction_alg, reduction_package, wor # Set the output workspaces for the can reductions -- note that these will only be set if optimizations # are enabled # Lab Can Workspace - if reduction_mode is ISISReductionMode.Merged: + if reduction_mode is ReductionMode.MERGED: _set_lab(reduction_alg, reduction_package, is_group) _set_hab(reduction_alg, 
reduction_package, is_group) - elif reduction_mode is ISISReductionMode.LAB: + elif reduction_mode is ReductionMode.LAB: _set_lab(reduction_alg, reduction_package, is_group) - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: _set_hab(reduction_alg, reduction_package, is_group) - elif reduction_mode is ISISReductionMode.All: + elif reduction_mode is ReductionMode.ALL: _set_lab(reduction_alg, reduction_package, is_group) _set_hab(reduction_alg, reduction_package, is_group) else: @@ -928,16 +929,16 @@ def set_properties_for_reduction_algorithm(reduction_alg, reduction_package, wor # Set the output workspaces for the calculated and unfitted transmission # ------------------------------------------------------------------------------------------------------------------ sample_calculated_transmission, \ - sample_calculated_transmission_base = get_transmission_output_name(reduction_package.state, DataType.Sample, + sample_calculated_transmission_base = get_transmission_output_name(reduction_package.state, DataType.SAMPLE, multi_reduction_type, True) can_calculated_transmission, \ - can_calculated_transmission_base = get_transmission_output_name(reduction_package.state, DataType.Can, + can_calculated_transmission_base = get_transmission_output_name(reduction_package.state, DataType.CAN, multi_reduction_type, True) sample_unfitted_transmission, \ - sample_unfitted_transmission_base = get_transmission_output_name(reduction_package.state, DataType.Sample, + sample_unfitted_transmission_base = get_transmission_output_name(reduction_package.state, DataType.SAMPLE, multi_reduction_type, False) can_unfitted_transmission, \ - can_unfitted_transmission_base = get_transmission_output_name(reduction_package.state, DataType.Can, + can_unfitted_transmission_base = get_transmission_output_name(reduction_package.state, DataType.CAN, multi_reduction_type, False) _set_output_name_from_string(reduction_alg, reduction_package, 
"OutputWorkspaceCalculatedTransmission", @@ -1052,7 +1053,7 @@ def group_workspaces_if_required(reduction_package, output_mode, save_can, event # Y | N | NO # N | N | YES - if save_can and output_mode is not OutputMode.SaveToFile: + if save_can and output_mode is not OutputMode.SAVE_TO_FILE: CAN_WORKSPACE_GROUP = CAN_AND_SAMPLE_WORKSPACE else: CAN_WORKSPACE_GROUP = CAN_COUNT_AND_NORM_FOR_OPTIMIZATION @@ -1438,13 +1439,13 @@ def save_workspace_to_file(workspace_name, file_formats, file_name, additional_r "TransmissionCan": transmission_can_name}) save_options.update(additional_run_numbers) - if SaveType.Nexus in file_formats: + if SaveType.NEXUS in file_formats: save_options.update({"Nexus": True}) - if SaveType.CanSAS in file_formats: + if SaveType.CAN_SAS in file_formats: save_options.update({"CanSAS": True}) - if SaveType.NXcanSAS in file_formats: + if SaveType.NX_CAN_SAS in file_formats: save_options.update({"NXcanSAS": True}) - if SaveType.NistQxy in file_formats: + if SaveType.NIST_QXY in file_formats: save_options.update({"NistQxy": True}) if SaveType.RKH in file_formats: save_options.update({"RKH": True}) @@ -1481,6 +1482,7 @@ class ReductionPackage(object): 7. The reduced can and the reduced partial can workspaces (non have to exist, this is only for optimizations) 8. 
The unfitted transmission workspaces """ + def __init__(self, state, workspaces, monitors, is_part_of_multi_period_reduction=False, is_part_of_event_slice_reduction=False, is_part_of_wavelength_range_reduction=False): super(ReductionPackage, self).__init__() diff --git a/scripts/SANS/sans/algorithm_detail/calculate_sans_transmission.py b/scripts/SANS/sans/algorithm_detail/calculate_sans_transmission.py index 49cfdb0978424b5c1ea7685fb7fa10a268cf2d5f..b602f55919957c070c157e3032749325f0c1ba60 100644 --- a/scripts/SANS/sans/algorithm_detail/calculate_sans_transmission.py +++ b/scripts/SANS/sans/algorithm_detail/calculate_sans_transmission.py @@ -12,7 +12,7 @@ from sans.algorithm_detail.calculate_transmission_helper import (get_detector_id apply_flat_background_correction_to_detectors, get_region_of_interest) from sans.common.constants import EMPTY_NAME -from sans.common.enums import (RangeStepType, RebinType, FitType, DataType) +from sans.common.enums import (RangeStepType, FitType, DataType) from sans.common.general_functions import create_unmanaged_algorithm @@ -53,12 +53,11 @@ def calculate_transmission(transmission_ws, direct_ws, state_adjustment_calculat # 2. Clean transmission data + data_type = DataType(data_type_str) transmission_ws = _get_corrected_wavelength_workspace(transmission_ws, all_detector_ids, - calculate_transmission_state, data_type_str) + calculate_transmission_state, data_type=data_type) direct_ws = _get_corrected_wavelength_workspace(direct_ws, all_detector_ids, - calculate_transmission_state, data_type_str) - - data_type = DataType.from_string(data_type_str) + calculate_transmission_state, data_type=data_type) # 3. 
Fit output_workspace, unfitted_transmission_workspace = \ @@ -91,7 +90,7 @@ def _perform_fit(transmission_workspace, direct_workspace, wavelength_high = calculate_transmission_state.wavelength_high[0] wavelength_step = calculate_transmission_state.wavelength_step wavelength_step_type = calculate_transmission_state.wavelength_step_type - prefix = 1.0 if wavelength_step_type is RangeStepType.Lin else -1.0 + prefix = 1.0 if wavelength_step_type is RangeStepType.LIN else -1.0 wavelength_step *= prefix rebin_params = str(wavelength_low) + "," + str(wavelength_step) + "," + str(wavelength_high) @@ -115,17 +114,17 @@ def _perform_fit(transmission_workspace, direct_workspace, raise RuntimeError("No transmission monitor has been provided.") # Get the fit setting for the correct data type, ie either for the Sample of the Can - fit_type = calculate_transmission_state.fit[DataType.to_string(data_type)].fit_type - if fit_type is FitType.Logarithmic: + fit_type = calculate_transmission_state.fit[data_type].fit_type + if fit_type is FitType.LOGARITHMIC: fit_string = "Log" - elif fit_type is FitType.Polynomial: + elif fit_type is FitType.POLYNOMIAL: fit_string = "Polynomial" else: fit_string = "Linear" trans_options.update({"FitMethod": fit_string}) - if fit_type is FitType.Polynomial: - polynomial_order = calculate_transmission_state.fit[DataType.to_string(data_type)].polynomial_order + if fit_type is FitType.POLYNOMIAL: + polynomial_order = calculate_transmission_state.fit[data_type].polynomial_order trans_options.update({"PolynomialOrder": polynomial_order}) trans_alg = create_unmanaged_algorithm(trans_name, **trans_options) @@ -144,7 +143,7 @@ def _perform_fit(transmission_workspace, direct_workspace, if unfitted_transmission_workspace: unfitted_transmission_workspace.setYUnitLabel(y_unit_label_transmission_ratio) - if fit_type is FitType.NoFit: + if fit_type is FitType.NO_FIT: output_workspace = unfitted_transmission_workspace else: output_workspace = 
fitted_transmission_workspace @@ -183,14 +182,14 @@ def _get_detector_ids_for_transmission_calculation(transmission_workspace, calcu return detector_ids_roi, detector_id_transmission_monitor, detector_id_default_transmission_monitor -def _get_corrected_wavelength_workspace(workspace, detector_ids, calculate_transmission_state, data_type_str): +def _get_corrected_wavelength_workspace(workspace, detector_ids, calculate_transmission_state, data_type): """ Performs a prompt peak correction, a background correction, converts to wavelength and rebins. :param workspace: the workspace which is being corrected. :param detector_ids: a list of relevant detector ids :param calculate_transmission_state: a SANSStateCalculateTransmission state - :param data_type_str The component of the instrument which is to be reduced + :param data_type The component of the instrument which is to be reduced :return: a corrected workspace. """ # Extract the relevant spectra. These include @@ -257,7 +256,7 @@ def _get_corrected_wavelength_workspace(workspace, detector_ids, calculate_trans wavelength_low = calculate_transmission_state.wavelength_full_range_low wavelength_high = calculate_transmission_state.wavelength_full_range_high else: - fit_state = calculate_transmission_state.fit[data_type_str] + fit_state = calculate_transmission_state.fit[data_type] wavelength_low = fit_state.wavelength_low if fit_state.wavelength_low \ else calculate_transmission_state.wavelength_low[0] wavelength_high = fit_state.wavelength_high if fit_state.wavelength_high \ @@ -272,8 +271,8 @@ def _get_corrected_wavelength_workspace(workspace, detector_ids, calculate_trans "WavelengthLow": wavelength_low, "WavelengthHigh": wavelength_high, "WavelengthStep": wavelength_step, - "WavelengthStepType": RangeStepType.to_string(wavelength_step_type), - "RebinMode": RebinType.to_string(rebin_type)} + "WavelengthStepType": wavelength_step_type.value, + "RebinMode": rebin_type.value} convert_alg = 
create_unmanaged_algorithm(convert_name, **convert_options) convert_alg.setPropertyValue("OutputWorkspace", EMPTY_NAME) convert_alg.setProperty("OutputWorkspace", workspace) diff --git a/scripts/SANS/sans/algorithm_detail/calibration.py b/scripts/SANS/sans/algorithm_detail/calibration.py index d34e56381388ed3d9436cc886589470e99cdc8ea..f13e3ded50f40f8ba91634b6a20ed37ae5a7836f 100644 --- a/scripts/SANS/sans/algorithm_detail/calibration.py +++ b/scripts/SANS/sans/algorithm_detail/calibration.py @@ -37,20 +37,20 @@ def apply_calibration(calibration_file_name, workspaces, monitor_workspaces, use # Check for the sample scatter and the can scatter workspaces workspaces_to_calibrate = {} - if SANSDataType.SampleScatter in workspaces: - workspaces_to_calibrate.update({SANSDataType.SampleScatter: workspaces[SANSDataType.SampleScatter]}) - if SANSDataType.CanScatter in workspaces: - workspaces_to_calibrate.update({SANSDataType.CanScatter: workspaces[SANSDataType.CanScatter]}) + if SANSDataType.SAMPLE_SCATTER in workspaces: + workspaces_to_calibrate.update({SANSDataType.SAMPLE_SCATTER: workspaces[SANSDataType.SAMPLE_SCATTER]}) + if SANSDataType.CAN_SCATTER in workspaces: + workspaces_to_calibrate.update({SANSDataType.CAN_SCATTER: workspaces[SANSDataType.CAN_SCATTER]}) do_apply_calibration(full_file_path, workspaces_to_calibrate, use_loaded, publish_to_ads, parent_alg) # Check for the sample scatter and the can scatter workspaces monitors workspace_monitors_to_calibrate = {} - if SANSDataType.SampleScatter in monitor_workspaces: - workspace_monitors_to_calibrate.update({SANSDataType.SampleScatter: - monitor_workspaces[SANSDataType.SampleScatter]}) - if SANSDataType.CanScatter in monitor_workspaces: - workspace_monitors_to_calibrate.update({SANSDataType.CanScatter: - monitor_workspaces[SANSDataType.CanScatter]}) + if SANSDataType.SAMPLE_SCATTER in monitor_workspaces: + workspace_monitors_to_calibrate.update({SANSDataType.SAMPLE_SCATTER: + 
monitor_workspaces[SANSDataType.SAMPLE_SCATTER]}) + if SANSDataType.CAN_SCATTER in monitor_workspaces: + workspace_monitors_to_calibrate.update({SANSDataType.CAN_SCATTER: + monitor_workspaces[SANSDataType.CAN_SCATTER]}) do_apply_calibration(full_file_path, workspace_monitors_to_calibrate, use_loaded, publish_to_ads, parent_alg) diff --git a/scripts/SANS/sans/algorithm_detail/centre_finder_new.py b/scripts/SANS/sans/algorithm_detail/centre_finder_new.py index b8b4cf670baf95a7102f3eb2424be8b2a415b591..0dc9d962beb0ffba6bf32aa8649428335418fc01 100644 --- a/scripts/SANS/sans/algorithm_detail/centre_finder_new.py +++ b/scripts/SANS/sans/algorithm_detail/centre_finder_new.py @@ -16,7 +16,7 @@ from mantid.simpleapi import CreateEmptyTableWorkspace # Functions for the execution of a single batch iteration # ---------------------------------------------------------------------------------------------------------------------- def centre_finder_new(state, r_min = 0.06, r_max = 0.26, iterations = 10, position_1_start = 0.0, position_2_start = 0.0 - , tolerance = 0.0001251, find_direction = FindDirectionEnum.All, verbose=False, component=DetectorType.LAB): + , tolerance = 0.0001251, find_direction = FindDirectionEnum.ALL, verbose=False, component=DetectorType.LAB): """ Finds the beam centre from a good initial guess. 
@@ -35,15 +35,15 @@ def centre_finder_new(state, r_min = 0.06, r_max = 0.26, iterations = 10, positi # ------------------------------------------------------------------------------------------------------------------ # Load the data # ------------------------------------------------------------------------------------------------------------------ - workspace_to_name = {SANSDataType.SampleScatter: "SampleScatterWorkspace", - SANSDataType.SampleTransmission: "SampleTransmissionWorkspace", - SANSDataType.SampleDirect: "SampleDirectWorkspace", - SANSDataType.CanScatter: "CanScatterWorkspace", - SANSDataType.CanTransmission: "CanTransmissionWorkspace", - SANSDataType.CanDirect: "CanDirectWorkspace"} + workspace_to_name = {SANSDataType.SAMPLE_SCATTER: "SampleScatterWorkspace", + SANSDataType.SAMPLE_TRANSMISSION: "SampleTransmissionWorkspace", + SANSDataType.SAMPLE_DIRECT: "SampleDirectWorkspace", + SANSDataType.CAN_SCATTER: "CanScatterWorkspace", + SANSDataType.CAN_TRANSMISSION: "CanTransmissionWorkspace", + SANSDataType.CAN_DIRECT: "CanDirectWorkspace"} - workspace_to_monitor = {SANSDataType.SampleScatter: "SampleScatterMonitorWorkSpace", - SANSDataType.CanScatter: "CanScatterMonitorWorkspace"} + workspace_to_monitor = {SANSDataType.SAMPLE_SCATTER: "SampleScatterMonitorWorkSpace", + SANSDataType.CAN_SCATTER: "CanScatterMonitorWorkspace"} workspaces, monitors = provide_loaded_data(state, False, workspace_to_name, workspace_to_monitor) @@ -61,8 +61,8 @@ def centre_finder_new(state, r_min = 0.06, r_max = 0.26, iterations = 10, positi beam_centre_finder = "SANSBeamCentreFinder" beam_centre_finder_options = {"Iterations": iterations, "RMin": r_min/1000, "RMax": r_max/1000, "Position1Start": position_1_start, "Position2Start": position_2_start, - "Tolerance": tolerance, "Direction" : FindDirectionEnum.to_string(find_direction), - "Verbose": verbose, "Component": DetectorType.to_string(component)} + "Tolerance": tolerance, "Direction" : find_direction.value, + "Verbose": 
verbose, "Component": component.value} beam_centre_alg = create_managed_non_child_algorithm(beam_centre_finder, **beam_centre_finder_options) beam_centre_alg.setChild(False) set_properties_for_beam_centre_algorithm(beam_centre_alg, reduction_package, @@ -98,9 +98,9 @@ def centre_finder_mass(state, r_min = 0.06, max_iter=10, position_1_start = 0.0, # ------------------------------------------------------------------------------------------------------------------ # Load the data # ------------------------------------------------------------------------------------------------------------------ - workspace_to_name = {SANSDataType.SampleScatter: "SampleScatterWorkspace"} + workspace_to_name = {SANSDataType.SAMPLE_SCATTER: "SampleScatterWorkspace"} - workspace_to_monitor = {SANSDataType.SampleScatter: "SampleScatterMonitorWorkSpace"} + workspace_to_monitor = {SANSDataType.SAMPLE_SCATTER: "SampleScatterMonitorWorkSpace"} workspaces, monitors = provide_loaded_data(state, False, workspace_to_name, workspace_to_monitor) @@ -116,7 +116,7 @@ def centre_finder_mass(state, r_min = 0.06, max_iter=10, position_1_start = 0.0, # ------------------------------------------------------------------------------------------------------------------ beam_centre_finder = "SANSBeamCentreFinderMassMethod" beam_centre_finder_options = {"RMin": r_min/1000, "Centre1": position_1_start, - "Centre2": position_2_start, "Tolerance": tolerance, "Component": DetectorType.to_string(component)} + "Centre2": position_2_start, "Tolerance": tolerance, "Component": component.value} beam_centre_alg = create_managed_non_child_algorithm(beam_centre_finder, **beam_centre_finder_options) beam_centre_alg.setChild(False) diff --git a/scripts/SANS/sans/algorithm_detail/convert_to_q.py b/scripts/SANS/sans/algorithm_detail/convert_to_q.py index 6218f239c270ff28b9f5f260b790c699840be7ed..1167607415bf0bfdb492fd3e5b6f4bb0cc9185b5 100644 --- a/scripts/SANS/sans/algorithm_detail/convert_to_q.py +++ 
b/scripts/SANS/sans/algorithm_detail/convert_to_q.py @@ -32,7 +32,7 @@ def convert_workspace(workspace, state_convert_to_q, output_summed_parts=False, # Perform either a 1D reduction or a 2D reduction reduction_dimensionality = state_convert_to_q.reduction_dimensionality - if reduction_dimensionality is ReductionDimensionality.OneDim: + if reduction_dimensionality is ReductionDimensionality.ONE_DIM: output_workspace, sum_of_counts_workspace, sum_of_norms_workspace = \ _run_q_1d(workspace, output_summed_parts, conv_to_q_state=state_convert_to_q, pixel_adj_ws=pixel_adj_workspace, wavelength_adj_ws=wavelength_adj_workspace, @@ -112,7 +112,7 @@ def _run_q_2d(workspace, output_summed_parts, state_convert_to_q, # Extract relevant settings max_q_xy = state_convert_to_q.q_xy_max - log_binning = True if state_convert_to_q.q_xy_step_type is RangeStepType.Log else False + log_binning = True if state_convert_to_q.q_xy_step_type is RangeStepType.LOG else False delta_q = state_convert_to_q.q_xy_step radius_cutoff = state_convert_to_q.radius_cutoff / 1000. 
# Qxy expects the radius cutoff to be in mm wavelength_cutoff = state_convert_to_q.wavelength_cutoff diff --git a/scripts/SANS/sans/algorithm_detail/crop_helper.py b/scripts/SANS/sans/algorithm_detail/crop_helper.py index 91690f5afe5bf0c80004d1390be213899b6f22df..660d0a64099aaa3d306cd487ff9d4078f4a78b2a 100644 --- a/scripts/SANS/sans/algorithm_detail/crop_helper.py +++ b/scripts/SANS/sans/algorithm_detail/crop_helper.py @@ -18,5 +18,5 @@ def get_component_name(workspace, detector_type): instrument = workspace.getInstrument() instrument_name = instrument.getName().strip() instrument_name = instrument_name.upper() - instrument = SANSInstrument.from_string(instrument_name) + instrument = SANSInstrument[instrument_name] return convert_instrument_and_detector_type_to_bank_name(instrument, detector_type) diff --git a/scripts/SANS/sans/algorithm_detail/load_data.py b/scripts/SANS/sans/algorithm_detail/load_data.py index 7d20ac7983eef246cad36342bddfe315ea09ee8a..205473d3d943bf4e87d1e0db5018822fcf090409 100644 --- a/scripts/SANS/sans/algorithm_detail/load_data.py +++ b/scripts/SANS/sans/algorithm_detail/load_data.py @@ -59,7 +59,7 @@ from sans.common.file_information import (SANSFileInformationFactory, FileType, from sans.common.constants import (EMPTY_NAME, SANS_SUFFIX, TRANS_SUFFIX, MONITOR_SUFFIX, CALIBRATION_WORKSPACE_TAG, SANS_FILE_TAG, OUTPUT_WORKSPACE_GROUP, OUTPUT_MONITOR_WORKSPACE, OUTPUT_MONITOR_WORKSPACE_GROUP) -from sans.common.enums import (SANSFacility, SANSInstrument, SANSDataType) +from sans.common.enums import (SANSFacility, SANSDataType, SANSInstrument) from sans.common.general_functions import (create_child_algorithm) from sans.common.log_tagger import (set_tag, has_tag, get_tag) from sans.state.data import (StateData) @@ -86,28 +86,28 @@ def get_file_and_period_information_from_data(data): period_information = dict() if data.sample_scatter: update_file_information(file_information, file_information_factory, - SANSDataType.SampleScatter, 
data.sample_scatter) - period_information.update({SANSDataType.SampleScatter: data.sample_scatter_period}) + SANSDataType.SAMPLE_SCATTER, data.sample_scatter) + period_information.update({SANSDataType.SAMPLE_SCATTER: data.sample_scatter_period}) if data.sample_transmission: update_file_information(file_information, file_information_factory, - SANSDataType.SampleTransmission, data.sample_transmission) - period_information.update({SANSDataType.SampleTransmission: data.sample_transmission_period}) + SANSDataType.SAMPLE_TRANSMISSION, data.sample_transmission) + period_information.update({SANSDataType.SAMPLE_TRANSMISSION: data.sample_transmission_period}) if data.sample_direct: update_file_information(file_information, file_information_factory, - SANSDataType.SampleDirect, data.sample_direct) - period_information.update({SANSDataType.SampleDirect: data.sample_direct_period}) + SANSDataType.SAMPLE_DIRECT, data.sample_direct) + period_information.update({SANSDataType.SAMPLE_DIRECT: data.sample_direct_period}) if data.can_scatter: update_file_information(file_information, file_information_factory, - SANSDataType.CanScatter, data.can_scatter) - period_information.update({SANSDataType.CanScatter: data.can_scatter_period}) + SANSDataType.CAN_SCATTER, data.can_scatter) + period_information.update({SANSDataType.CAN_SCATTER: data.can_scatter_period}) if data.can_transmission: update_file_information(file_information, file_information_factory, - SANSDataType.CanTransmission, data.can_transmission) - period_information.update({SANSDataType.CanTransmission: data.can_transmission_period}) + SANSDataType.CAN_TRANSMISSION, data.can_transmission) + period_information.update({SANSDataType.CAN_TRANSMISSION: data.can_transmission_period}) if data.can_direct: update_file_information(file_information, file_information_factory, - SANSDataType.CanDirect, data.can_direct) - period_information.update({SANSDataType.CanDirect: data.can_direct_period}) + SANSDataType.CAN_DIRECT, data.can_direct) + 
period_information.update({SANSDataType.CAN_DIRECT: data.can_direct_period}) return file_information, period_information @@ -119,8 +119,8 @@ def is_transmission_type(to_check): :param to_check: A SANSDataType object. :return: true if the SANSDataType object is a transmission object (transmission or direct) else false. """ - return ((to_check is SANSDataType.SampleTransmission) or (to_check is SANSDataType.SampleDirect) or - (to_check is SANSDataType.CanTransmission) or (to_check is SANSDataType.CanDirect)) + return ((to_check is SANSDataType.SAMPLE_TRANSMISSION) or (to_check is SANSDataType.SAMPLE_DIRECT) or + (to_check is SANSDataType.CAN_TRANSMISSION) or (to_check is SANSDataType.CAN_DIRECT)) def get_expected_file_tags(file_information, is_transmission, period): @@ -637,11 +637,11 @@ def get_loader_strategy(file_information): :param file_information: a SANSFileInformation object. :return: a handle to the correct loading function/strategy. """ - if file_information.get_type() == FileType.ISISNexus: + if file_information.get_type() == FileType.ISIS_NEXUS: loader = loader_for_isis_nexus - elif file_information.get_type() == FileType.ISISRaw: + elif file_information.get_type() == FileType.ISIS_RAW: loader = loader_for_raw - elif file_information.get_type() == FileType.ISISNexusAdded: + elif file_information.get_type() == FileType.ISIS_NEXUS_ADDED: loader = loader_for_added_isis_nexus else: raise RuntimeError("SANSLoad: Cannot load SANS file of type {0}".format(str(file_information.get_type()))) @@ -709,6 +709,7 @@ class SANSLoadData(with_metaclass(ABCMeta, object)): class SANSLoadDataISIS(SANSLoadData): """Load implementation of SANSLoad for ISIS data""" + def do_execute(self, data_info, use_cached, publish_to_ads, progress, parent_alg): # Get all entries from the state file file_infos, period_infos = get_file_and_period_information_from_data(data_info) @@ -728,7 +729,7 @@ class SANSLoadDataISIS(SANSLoadData): for key, value in list(file_infos.items()): # Loading - 
report_message = "Loading {0}".format(SANSDataType.to_string(key)) + report_message = "Loading {0}".format(key.value) progress.report(report_message) workspace_pack, workspace_monitors_pack = load_isis(key, value, period_infos[key], @@ -755,6 +756,7 @@ class SANSLoadDataISIS(SANSLoadData): class SANSLoadDataFactory(object): """ A factory for SANSLoadData.""" + def __init__(self): super(SANSLoadDataFactory, self).__init__() @@ -764,7 +766,7 @@ class SANSLoadDataFactory(object): # Get the correct loader based on the sample scatter file from the data sub state data.validate() file_info, _ = get_file_and_period_information_from_data(data) - sample_scatter_info = file_info[SANSDataType.SampleScatter] + sample_scatter_info = file_info[SANSDataType.SAMPLE_SCATTER] return sample_scatter_info.get_facility() @staticmethod @@ -836,6 +838,7 @@ class LOQTransmissionCorrection(TransmissionCorrection): def get_transmission_correction(data_info): instrument_type = data_info.instrument + if instrument_type is SANSInstrument.LOQ: return LOQTransmissionCorrection() else: diff --git a/scripts/SANS/sans/algorithm_detail/mask_functions.py b/scripts/SANS/sans/algorithm_detail/mask_functions.py index 8daa48d842ba7f2045ea12ac778de76df89c4367..5fe8973f094bf902ce2cad0988062513ee547447 100644 --- a/scripts/SANS/sans/algorithm_detail/mask_functions.py +++ b/scripts/SANS/sans/algorithm_detail/mask_functions.py @@ -6,7 +6,9 @@ # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) from collections import (namedtuple, Sequence) -from sans.common.enums import (SANSInstrument, DetectorOrientation, DetectorType) + +from mantid.py3compat import Enum +from sans.common.enums import (DetectorType, SANSInstrument) from sans.common.xml_parsing import get_named_elements_from_ipf_file @@ -15,6 +17,15 @@ detector_shape_bundle = namedtuple("detector_shape_bundle", 'rectangular_shape, geometry_bundle = namedtuple("geometry_bundle", 'shape, 
first_low_angle_spec_number') +class DetectorOrientation(Enum): + """ + Defines the detector orientation. + """ + HORIZONTAL = "Horizontal" + ROTATED = "Rotated" + VERTICAL = "Vertical" + + def get_geometry_information(ipf_path, detector_type): """ This function extracts geometry information for the detector benches. @@ -118,22 +129,22 @@ class SpectraBlock(object): if self._instrument is SANSInstrument.SANS2D: if self._run_number < 568: self._first_spectrum_number = 1 - self._detector_orientation = DetectorOrientation.Vertical + self._detector_orientation = DetectorOrientation.VERTICAL elif 568 <= self._run_number < 684: self._first_spectrum_number = 9 - self._detector_orientation = DetectorOrientation.Rotated + self._detector_orientation = DetectorOrientation.ROTATED else: self._first_spectrum_number = 9 - self._detector_orientation = DetectorOrientation.Horizontal + self._detector_orientation = DetectorOrientation.HORIZONTAL elif self._instrument is SANSInstrument.LARMOR: self._first_spectrum_number = 10 - self._detector_orientation = DetectorOrientation.Horizontal + self._detector_orientation = DetectorOrientation.HORIZONTAL elif self._instrument is SANSInstrument.LOQ: self._first_spectrum_number = 3 - self._detector_orientation = DetectorOrientation.Horizontal + self._detector_orientation = DetectorOrientation.HORIZONTAL elif self._instrument is SANSInstrument.ZOOM: self._first_spectrum_number = 9 - self._detector_orientation = DetectorOrientation.Horizontal + self._detector_orientation = DetectorOrientation.HORIZONTAL else: raise RuntimeError("MaskParser: Cannot handle masking request for " "instrument {0}".format(str(self._instrument))) @@ -160,16 +171,16 @@ class SpectraBlock(object): base_spectrum_number = self._first_spectrum_number output = [] - if self._detector_orientation == DetectorOrientation.Horizontal: + if self._detector_orientation == DetectorOrientation.HORIZONTAL: start_spectrum = base_spectrum_number + y_lower * detector_dimension + x_lower 
for y in range(0, y_dim): output.extend((start_spectrum + (y * detector_dimension) + x for x in range(0, x_dim))) - elif self._detector_orientation == DetectorOrientation.Vertical: + elif self._detector_orientation == DetectorOrientation.VERTICAL: start_spectrum = base_spectrum_number + x_lower * detector_dimension + y_lower for x in range(detector_dimension - 1, detector_dimension - x_dim - 1, -1): output.extend((start_spectrum + ((detector_dimension - x - 1) * detector_dimension) + y for y in range(0, y_dim))) - elif self._detector_orientation == DetectorOrientation.Rotated: + elif self._detector_orientation == DetectorOrientation.ROTATED: # This is the horizontal one rotated so need to map the x_low and y_low to their rotated versions start_spectrum = base_spectrum_number + y_lower * detector_dimension + x_lower max_spectrum = detector_dimension * detector_dimension + base_spectrum_number - 1 diff --git a/scripts/SANS/sans/algorithm_detail/mask_sans_workspace.py b/scripts/SANS/sans/algorithm_detail/mask_sans_workspace.py index 298966c49403ff05c7d500a37401bd7386b19ea9..0fc309f9291219ec292d56c40ac01caeea6b6ea1 100644 --- a/scripts/SANS/sans/algorithm_detail/mask_sans_workspace.py +++ b/scripts/SANS/sans/algorithm_detail/mask_sans_workspace.py @@ -14,7 +14,7 @@ from sans.common.general_functions import append_to_sans_file_tag def mask_workspace(state, component_as_string, workspace): assert (state is not dict) - component = DetectorType.from_string(component_as_string) + component = DetectorType(component_as_string) # Get the correct SANS masking strategy from create_masker masker = create_masker(state, component) diff --git a/scripts/SANS/sans/algorithm_detail/mask_workspace.py b/scripts/SANS/sans/algorithm_detail/mask_workspace.py index 6491644d8d92d4cb84164f7bde9ddad4c02e028c..39c521b3dcaf2a16da6d45c75bbe9474625b6a59 100644 --- a/scripts/SANS/sans/algorithm_detail/mask_workspace.py +++ b/scripts/SANS/sans/algorithm_detail/mask_workspace.py @@ -5,14 +5,17 @@ # & 
Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) -from six import with_metaclass + from abc import (ABCMeta, abstractmethod) + +from six import with_metaclass + +from sans.algorithm_detail.mask_functions import SpectraBlock +from sans.algorithm_detail.xml_shapes import (add_cylinder, add_outside_cylinder, create_phi_mask, create_line_mask) from sans.common.constants import EMPTY_NAME -from sans.common.enums import (SANSInstrument, DetectorType) -from sans.common.general_functions import create_unmanaged_algorithm +from sans.common.enums import SANSInstrument from sans.common.file_information import (find_full_file_path, get_instrument_paths_for_sans_file) -from sans.algorithm_detail.xml_shapes import (add_cylinder, add_outside_cylinder, create_phi_mask, create_line_mask) -from sans.algorithm_detail.mask_functions import (SpectraBlock) +from sans.common.general_functions import create_unmanaged_algorithm # ------------------------------------------------------------------ @@ -35,8 +38,8 @@ def mask_bins(mask_info, workspace, detector_type): bin_mask_general_start = mask_info.bin_mask_general_start bin_mask_general_stop = mask_info.bin_mask_general_stop # Mask the bins with the detector-specific setting - bin_mask_start = mask_info.detectors[DetectorType.to_string(detector_type)].bin_mask_start - bin_mask_stop = mask_info.detectors[DetectorType.to_string(detector_type)].bin_mask_stop + bin_mask_start = mask_info.detectors[detector_type.value].bin_mask_start + bin_mask_stop = mask_info.detectors[detector_type.value].bin_mask_stop # Combine the settings and run the binning start_mask = [] @@ -178,7 +181,7 @@ def mask_spectra(mask_info, workspace, spectra_block, detector_type): total_spectra = [] # All masks are detector-specific, hence we pull out only the relevant part - detector = mask_info.detectors[DetectorType.to_string(detector_type)] + detector = 
mask_info.detectors[detector_type.value] # ---------------------- # Single spectra @@ -400,6 +403,9 @@ def create_masker(state, detector_type): """ data_info = state.data instrument = data_info.instrument + + # TODO remove this shim + detector_names = state.reduction.detector_names if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.LOQ or\ instrument is SANSInstrument.SANS2D or instrument is SANSInstrument.ZOOM: # noqa diff --git a/scripts/SANS/sans/algorithm_detail/merge_reductions.py b/scripts/SANS/sans/algorithm_detail/merge_reductions.py index 8d0c87c4ef0434c8fbb7ba88ee65c033e647990c..6cf07b75591be9765fdb6f1fc96965f035549dcc 100644 --- a/scripts/SANS/sans/algorithm_detail/merge_reductions.py +++ b/scripts/SANS/sans/algorithm_detail/merge_reductions.py @@ -7,12 +7,15 @@ """ Merges two reduction types to single reduction""" from __future__ import (absolute_import, division, print_function) + from abc import (ABCMeta, abstractmethod) + from six import with_metaclass -from sans.common.general_functions import create_child_algorithm -from sans.common.enums import (SANSFacility, DataType, FitModeForMerge) -from sans.algorithm_detail.bundles import MergeBundle + import mantid.simpleapi as mantid_api +from sans.algorithm_detail.bundles import MergeBundle +from sans.common.enums import (SANSFacility, DataType) +from sans.common.general_functions import create_child_algorithm class Merger(with_metaclass(ABCMeta, object)): @@ -27,6 +30,7 @@ class ISIS1DMerger(Merger): """ Class which handles ISIS-style merges. """ + def __init__(self): super(ISIS1DMerger, self).__init__() @@ -74,7 +78,7 @@ class ISIS1DMerger(Merger): shift_factor, scale_factor, fit_mode, fit_min, fit_max, merge_mask, merge_min, merge_max = \ get_shift_and_scale_parameter(reduction_mode_vs_output_bundles) - fit_mode_as_string = FitModeForMerge.to_string(fit_mode) + fit_mode_as_string = fit_mode.value # We need to convert NoFit to None. 
if fit_mode_as_string == "NoFit": @@ -177,8 +181,8 @@ def get_partial_workspaces(primary_detector, secondary_detector, reduction_mode_ """ Get the partial workspaces for the primary and secondary detectors. - :param primary_detector: the primary detector (now normally ISISReductionMode.LAB) - :param secondary_detector: the secondary detector (now normally ISISReductionMode.HAB) + :param primary_detector: the primary detector (now normally ReductionMode.LAB) + :param secondary_detector: the secondary detector (now normally ReductionMode.HAB) :param reduction_mode_vs_output_bundles: a ReductionMode vs OutputBundles map :param is_data_type: the data type, i.e. if can or sample :return: the primary count workspace, the primary normalization workspace, the secondary count workspace and the @@ -223,8 +227,8 @@ def get_shift_and_scale_parameter(reduction_mode_vs_output_bundles): def is_sample(x): - return x.data_type is DataType.Sample + return x.data_type is DataType.SAMPLE def is_can(x): - return x.data_type is DataType.Can + return x.data_type is DataType.CAN diff --git a/scripts/SANS/sans/algorithm_detail/move_sans_instrument_component.py b/scripts/SANS/sans/algorithm_detail/move_sans_instrument_component.py index 571331689c3ceb4fde689e0eca6534c6c7c7b4a5..0fb34adb3fa50bead32c0a7111d9ec8253d20046 100644 --- a/scripts/SANS/sans/algorithm_detail/move_sans_instrument_component.py +++ b/scripts/SANS/sans/algorithm_detail/move_sans_instrument_component.py @@ -66,7 +66,7 @@ def _get_coordinates(move_info, component_name): # If the detector is unknown take the position from the LAB if selected_detector is None: - selected_detector = detectors[DetectorType.to_string(DetectorType.LAB)] + selected_detector = detectors[DetectorType.LAB.value] pos1 = selected_detector.sample_centre_pos1 pos2 = selected_detector.sample_centre_pos2 coordinates = [pos1, pos2] @@ -99,9 +99,9 @@ def _get_detector_for_component(move_info, component): detectors = move_info.detectors selected_detector = 
None if component == "HAB": - selected_detector = detectors[DetectorType.to_string(DetectorType.HAB)] + selected_detector = detectors[DetectorType.HAB.value] elif component == "LAB": - selected_detector = detectors[DetectorType.to_string(DetectorType.LAB)] + selected_detector = detectors[DetectorType.LAB.value] else: # Check if the component is part of the detector names for _, detector in list(detectors.items()): diff --git a/scripts/SANS/sans/algorithm_detail/move_workspaces.py b/scripts/SANS/sans/algorithm_detail/move_workspaces.py index d1d0de7d7cddaa191674c9be54adcc4d84251205..f8ee3bc0cb21b98ad293f011bbfb319fcc580468 100644 --- a/scripts/SANS/sans/algorithm_detail/move_workspaces.py +++ b/scripts/SANS/sans/algorithm_detail/move_workspaces.py @@ -12,7 +12,7 @@ from mantid.api import MatrixWorkspace from six import with_metaclass from abc import (ABCMeta, abstractmethod) from sans.state.move import StateMove -from sans.common.enums import (SANSInstrument, CanonicalCoordinates, DetectorType) +from sans.common.enums import CanonicalCoordinates, DetectorType, SANSInstrument from sans.common.general_functions import (create_unmanaged_algorithm, get_single_valued_logs_from_workspace, quaternion_to_angle_and_axis, sanitise_instrument_name) @@ -119,7 +119,7 @@ def move_backstop_monitor(ws, move_info, monitor_offset, monitor_spectrum_number monitor_position = comp_info.position(monitor_n) # The location is relative to the rear-detector, get this position - lab_detector = move_info.detectors[DetectorType.to_string(DetectorType.LAB)] + lab_detector = move_info.detectors[DetectorType.LAB.value] detector_name = lab_detector.detector_name lab_detector_index = comp_info.indexOfAny(detector_name) detector_position = comp_info.position(lab_detector_index) @@ -273,10 +273,10 @@ def set_components_to_original_for_isis(move_info, workspace, component): if not component: component_names = list(move_info.monitor_names.values()) - hab_key = DetectorType.to_string(DetectorType.HAB) + 
hab_key = DetectorType.HAB.value _reset_detector(hab_key, move_info, component_names) - lab_key = DetectorType.to_string(DetectorType.LAB) + lab_key = DetectorType.LAB.value _reset_detector(lab_key, move_info, component_names) component_names.append("some-sample-holder") @@ -321,7 +321,7 @@ def move_low_angle_bank_for_SANS2D_and_ZOOM(move_info, workspace, coordinates, u lab_detector_z = 0. # Perform x and y tilt - lab_detector = move_info.detectors[DetectorType.to_string(DetectorType.LAB)] + lab_detector = move_info.detectors[DetectorType.LAB.value] SANSMoveSANS2D.perform_x_and_y_tilts(workspace, lab_detector) lab_detector_default_sd_m = move_info.lab_detector_default_sd_m @@ -475,7 +475,7 @@ class SANSMoveSANS2D(SANSMove): hab_detector_default_sd_m = move_info.hab_detector_default_sd_m # Detector and name - hab_detector = move_info.detectors[DetectorType.to_string(DetectorType.HAB)] + hab_detector = move_info.detectors[DetectorType.HAB.value] detector_name = hab_detector.detector_name # Perform x and y tilt @@ -492,7 +492,7 @@ class SANSMoveSANS2D(SANSMove): # Add translational corrections x = coordinates[0] y = coordinates[1] - lab_detector = move_info.detectors[DetectorType.to_string(DetectorType.LAB)] + lab_detector = move_info.detectors[DetectorType.LAB.value] rotation_in_radians = math.pi * (hab_detector_rotation + hab_detector.rotation_correction)/180. 
x_shift = ((lab_detector_x + lab_detector.x_translation_correction - @@ -575,7 +575,7 @@ class SANSMoveLOQ(SANSMove): x_shift = center_position - x y_shift = center_position - y - detectors = [DetectorType.to_string(DetectorType.LAB), DetectorType.to_string(DetectorType.HAB)] + detectors = [DetectorType.LAB.value, DetectorType.HAB.value] for detector in detectors: # Get the detector name component_name = move_info.detectors[detector].detector_name @@ -622,13 +622,13 @@ class SANSMoveLARMOROldStyle(SANSMove): y_shift = -coordinates[1] coordinates_for_only_y = [0.0, y_shift] apply_standard_displacement(move_info, workspace, coordinates_for_only_y, - DetectorType.to_string(DetectorType.LAB)) + DetectorType.LAB.value) # Shift the low-angle bank detector in the x direction x_shift = -coordinates[0] coordinates_for_only_x = [x_shift, 0.0] apply_standard_displacement(move_info, workspace, coordinates_for_only_x, - DetectorType.to_string(DetectorType.LAB)) + DetectorType.LAB.value) def do_move_with_elementary_displacement(self, move_info, workspace, coordinates, component): # For LOQ we only have to coordinates @@ -684,7 +684,7 @@ class SANSMoveLARMORNewStyle(SANSMove): y_shift = -coordinates[1] coordinates_for_only_y = [0.0, y_shift] apply_standard_displacement(move_info, workspace, coordinates_for_only_y, - DetectorType.to_string(DetectorType.LAB)) + DetectorType.LAB.value) # Shift the low-angle bank detector in the x direction angle = coordinates[0] @@ -697,7 +697,7 @@ class SANSMoveLARMORNewStyle(SANSMove): if log_values[bench_rot_tag] is None else log_values[bench_rot_tag] self._rotate_around_y_axis(move_info, workspace, angle, - DetectorType.to_string(DetectorType.LAB), bench_rotation) + DetectorType.LAB.value, bench_rotation) def do_move_with_elementary_displacement(self, move_info, workspace, coordinates, component): # For LOQ we only have to coordinates @@ -782,7 +782,7 @@ def create_mover(workspace): instrument = workspace.getInstrument() instrument_name = 
instrument.getName() instrument_name = sanitise_instrument_name(instrument_name) - instrument_type = SANSInstrument.from_string(instrument_name) + instrument_type = SANSInstrument[instrument_name] if SANSMoveLOQ.is_correct(instrument_type, run_number): mover = SANSMoveLOQ() elif SANSMoveSANS2D.is_correct(instrument_type, run_number): diff --git a/scripts/SANS/sans/algorithm_detail/normalize_to_sans_monitor.py b/scripts/SANS/sans/algorithm_detail/normalize_to_sans_monitor.py index 893d5dc87779f76dadb043da5d001d0e55f68cc1..852715a304b432bb709b26147c98ba106787dd98 100644 --- a/scripts/SANS/sans/algorithm_detail/normalize_to_sans_monitor.py +++ b/scripts/SANS/sans/algorithm_detail/normalize_to_sans_monitor.py @@ -4,14 +4,12 @@ # NScD Oak Ridge National Laboratory, European Spallation Source # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + -# pylint: disable=invalid-name """ SANSNormalizeToMonitor algorithm calculates the normalization to the monitor.""" from __future__ import (absolute_import, division, print_function) from sans.common.constants import EMPTY_NAME -from sans.common.enums import RebinType, RangeStepType from sans.common.general_functions import create_unmanaged_algorithm @@ -167,8 +165,8 @@ def _convert_to_wavelength(workspace, normalize_to_monitor_state): "WavelengthLow": wavelength_low, "WavelengthHigh": wavelength_high, "WavelengthStep": wavelength_step, - "WavelengthStepType": RangeStepType.to_string(wavelength_step_type), - "RebinMode": RebinType.to_string(wavelength_rebin_mode)} + "WavelengthStepType": wavelength_step_type.value, + "RebinMode": wavelength_rebin_mode.value} convert_alg = create_unmanaged_algorithm(convert_name, **convert_options) convert_alg.setPropertyValue("OutputWorkspace", EMPTY_NAME) diff --git a/scripts/SANS/sans/algorithm_detail/save_workspace.py b/scripts/SANS/sans/algorithm_detail/save_workspace.py index ad71fc810789879a4a47ec612a4a634533d0d667..f7fc1c682508383053220830d08c5c3aa5116cab 100644 --- 
a/scripts/SANS/sans/algorithm_detail/save_workspace.py +++ b/scripts/SANS/sans/algorithm_detail/save_workspace.py @@ -48,20 +48,20 @@ def get_save_strategy(file_format_bundle, file_name, save_options, transmission_ :return: a handle to a save algorithm """ file_format = file_format_bundle.file_format - if file_format is SaveType.Nexus: + if file_format is SaveType.NEXUS: file_name = get_file_name(file_format_bundle, file_name, "", ".nxs") save_name = "SaveNexusProcessed" - elif file_format is SaveType.CanSAS: + elif file_format is SaveType.CAN_SAS: file_name = get_file_name(file_format_bundle, file_name, "", ".xml") save_name = "SaveCanSAS1D" save_options.update(transmission_workspaces) save_options.update(additional_run_numbers) - elif file_format is SaveType.NXcanSAS: + elif file_format is SaveType.NX_CAN_SAS: file_name = get_file_name(file_format_bundle, file_name, "_nxcansas", ".h5") save_name = "SaveNXcanSAS" save_options.update(transmission_workspaces) save_options.update(additional_run_numbers) - elif file_format is SaveType.NistQxy: + elif file_format is SaveType.NIST_QXY: file_name = get_file_name(file_format_bundle, file_name, "_nistqxy", ".dat") save_name = "SaveNISTDAT" elif file_format is SaveType.RKH: diff --git a/scripts/SANS/sans/algorithm_detail/scale_sans_workspace.py b/scripts/SANS/sans/algorithm_detail/scale_sans_workspace.py index 457b1d25c0b9e773b898b80c4efdf2983cce63bb..67e74ad019c7bd52bb6f87659761c46a65900240 100644 --- a/scripts/SANS/sans/algorithm_detail/scale_sans_workspace.py +++ b/scripts/SANS/sans/algorithm_detail/scale_sans_workspace.py @@ -12,7 +12,7 @@ from __future__ import (absolute_import, division, print_function) import math from sans.common.constants import EMPTY_NAME -from sans.common.enums import SANSInstrument, SampleShape +from sans.common.enums import SampleShape, SANSInstrument from sans.common.general_functions import append_to_sans_file_tag, create_unmanaged_algorithm @@ -79,15 +79,15 @@ def _get_volume(scale_info): 
shape = scale_info.shape if scale_info.shape is not None else scale_info.shape_from_file # Now we calculate the volume - if shape is SampleShape.Cylinder: + if shape is SampleShape.CYLINDER: # Volume = circle area * height volume = height * math.pi radius = width / 2.0 volume *= math.pow(radius, 2) - elif shape is SampleShape.FlatPlate: + elif shape is SampleShape.FLAT_PLATE: # Flat plate sample volume = width * height * thickness - elif shape is SampleShape.Disc: + elif shape is SampleShape.DISC: # Factor of four comes from radius = width/2 # Disc - where height is not used volume = thickness * math.pi diff --git a/scripts/SANS/sans/algorithm_detail/single_execution.py b/scripts/SANS/sans/algorithm_detail/single_execution.py index 4958a416d02992be4df44d4dd571e6ed16772d81..6b389b67b75e5fc055db85f18da3511a29b34e8f 100644 --- a/scripts/SANS/sans/algorithm_detail/single_execution.py +++ b/scripts/SANS/sans/algorithm_detail/single_execution.py @@ -14,7 +14,7 @@ from sans.algorithm_detail.bundles import (EventSliceSettingBundle, OutputBundle from sans.algorithm_detail.merge_reductions import (MergeFactory, is_sample, is_can) from sans.algorithm_detail.strip_end_nans_and_infs import strip_end_nans from sans.common.constants import EMPTY_NAME -from sans.common.enums import (DataType, DetectorType, ISISReductionMode, OutputParts, TransmissionType) +from sans.common.enums import (DetectorType, ReductionMode, OutputParts, TransmissionType) from sans.common.general_functions import (create_child_algorithm, get_reduced_can_workspace_from_ads, get_transmission_workspaces_from_ads, write_hash_into_reduced_can_workspace) @@ -37,7 +37,7 @@ def run_initial_event_slice_reduction(reduction_alg, reduction_setting_bundle): reduction_alg.setProperty("Component", component) reduction_alg.setProperty("ScatterWorkspace", reduction_setting_bundle.scatter_workspace) reduction_alg.setProperty("ScatterMonitorWorkspace", reduction_setting_bundle.scatter_monitor_workspace) - 
reduction_alg.setProperty("DataType", DataType.to_string(reduction_setting_bundle.data_type)) + reduction_alg.setProperty("DataType", reduction_setting_bundle.data_type.value) reduction_alg.setProperty("OutputWorkspace", EMPTY_NAME) reduction_alg.setProperty("OutputMonitorWorkspace", EMPTY_NAME) @@ -83,7 +83,7 @@ def run_core_event_slice_reduction(reduction_alg, reduction_setting_bundle): reduction_alg.setProperty("DummyMaskWorkspace", reduction_setting_bundle.dummy_mask_workspace) reduction_alg.setProperty("ScatterMonitorWorkspace", reduction_setting_bundle.scatter_monitor_workspace) - reduction_alg.setProperty("DataType", DataType.to_string(reduction_setting_bundle.data_type)) + reduction_alg.setProperty("DataType", reduction_setting_bundle.data_type.value) reduction_alg.setProperty("OutputWorkspace", EMPTY_NAME) reduction_alg.setProperty("SumOfCounts", EMPTY_NAME) @@ -138,7 +138,7 @@ def run_core_reduction(reduction_alg, reduction_setting_bundle): reduction_alg.setProperty("Component", component) reduction_alg.setProperty("ScatterWorkspace", reduction_setting_bundle.scatter_workspace) reduction_alg.setProperty("ScatterMonitorWorkspace", reduction_setting_bundle.scatter_monitor_workspace) - reduction_alg.setProperty("DataType", DataType.to_string(reduction_setting_bundle.data_type)) + reduction_alg.setProperty("DataType", reduction_setting_bundle.data_type.value) if reduction_setting_bundle.transmission_workspace is not None: reduction_alg.setProperty("TransmissionWorkspace", reduction_setting_bundle.transmission_workspace) @@ -302,10 +302,10 @@ def get_component_to_reduce(reduction_setting_bundle): # Get the reduction mode reduction_mode = reduction_setting_bundle.reduction_mode - if reduction_mode is ISISReductionMode.LAB: - reduction_mode_setting = DetectorType.to_string(DetectorType.LAB) - elif reduction_mode is ISISReductionMode.HAB: - reduction_mode_setting = DetectorType.to_string(DetectorType.HAB) + if reduction_mode is ReductionMode.LAB: + 
reduction_mode_setting = DetectorType.LAB.value + elif reduction_mode is ReductionMode.HAB: + reduction_mode_setting = DetectorType.HAB.value else: raise RuntimeError("SingleExecution: An unknown reduction mode was selected: {}. " "Currently only HAB and LAB are supported.".format(reduction_mode)) @@ -384,22 +384,22 @@ def run_optimized_for_can(reduction_alg, reduction_setting_bundle, event_slice_o output_transmission_bundle.unfitted_transmission_workspace is not None: write_hash_into_reduced_can_workspace(state=output_transmission_bundle.state, workspace=output_transmission_bundle.calculated_transmission_workspace, - partial_type=TransmissionType.Calculated, + partial_type=TransmissionType.CALCULATED, reduction_mode=reduction_mode) write_hash_into_reduced_can_workspace(state=output_transmission_bundle.state, workspace=output_transmission_bundle.unfitted_transmission_workspace, - partial_type=TransmissionType.Unfitted, + partial_type=TransmissionType.UNFITTED, reduction_mode=reduction_mode) if (output_parts_bundle.output_workspace_count is not None and output_parts_bundle.output_workspace_norm is not None): write_hash_into_reduced_can_workspace(state=output_parts_bundle.state, workspace=output_parts_bundle.output_workspace_count, - partial_type=OutputParts.Count, + partial_type=OutputParts.COUNT, reduction_mode=reduction_mode) write_hash_into_reduced_can_workspace(state=output_parts_bundle.state, workspace=output_parts_bundle.output_workspace_norm, - partial_type=OutputParts.Norm, + partial_type=OutputParts.NORM, reduction_mode=reduction_mode) return output_bundle, output_parts_bundle, output_transmission_bundle diff --git a/scripts/SANS/sans/algorithm_detail/slice_sans_event.py b/scripts/SANS/sans/algorithm_detail/slice_sans_event.py index 26f4373976c633d0c2f913ac78debe618f73f37c..bfcf87ee56b025ce1318e714ff5d1fc8854dae7a 100644 --- a/scripts/SANS/sans/algorithm_detail/slice_sans_event.py +++ b/scripts/SANS/sans/algorithm_detail/slice_sans_event.py @@ -30,7 +30,7 
@@ def slice_sans_event(state_slice, input_ws, input_ws_monitor, data_type_str="Sam 'OutputWorkspaceMonitor' : The output monitor workspace which has the correct slice factor applied to it. """ - data_type = DataType.from_string(data_type_str) + data_type = DataType(data_type_str) # This should be removed in the future when cycle 19/1 data is unlikely to be processed by users # This prevents time slicing falling over, since we wrap around and get -0 @@ -77,7 +77,7 @@ def _create_slice(workspace, slice_info, data_type): total_charge, total_time = get_charge_and_time(workspace) # If we are dealing with a Can reduction then the slice times are -1 - if data_type is DataType.Can: + if data_type is DataType.CAN: start_time = -1. end_time = -1. diff --git a/scripts/SANS/sans/algorithm_detail/xml_shapes.py b/scripts/SANS/sans/algorithm_detail/xml_shapes.py index 9386d9add2b72ef599bb47d4453786159d4516fe..db0e30ec17054fe2e6b46cad5acbb34aefb6a0f4 100644 --- a/scripts/SANS/sans/algorithm_detail/xml_shapes.py +++ b/scripts/SANS/sans/algorithm_detail/xml_shapes.py @@ -144,16 +144,16 @@ def quadrant_xml(centre,rmin,rmax,quadrant): xmlstring+= infinite_cylinder(centre, rmax, [0,0,1], cout_id) plane1Axis=None plane2Axis=None - if quadrant is MaskingQuadrant.Left: + if quadrant is MaskingQuadrant.LEFT: plane1Axis = [-1,1,0] plane2Axis = [-1,-1,0] - elif quadrant is MaskingQuadrant.Right: + elif quadrant is MaskingQuadrant.RIGHT: plane1Axis = [1,-1,0] plane2Axis = [1,1,0] - elif quadrant is MaskingQuadrant.Top: + elif quadrant is MaskingQuadrant.TOP: plane1Axis = [1,1,0] plane2Axis = [-1,1,0] - elif quadrant is MaskingQuadrant.Bottom: + elif quadrant is MaskingQuadrant.BOTTOM: plane1Axis = [-1,-1,0] plane2Axis = [1,-1,0] else: diff --git a/scripts/SANS/sans/command_interface/ISISCommandInterface.py b/scripts/SANS/sans/command_interface/ISISCommandInterface.py index c0c8bc7131b749fe42e5084e98c71298be61dd73..a5a41163238b5d8dfa9462cf7baeac644e867c34 100644 --- 
a/scripts/SANS/sans/command_interface/ISISCommandInterface.py +++ b/scripts/SANS/sans/command_interface/ISISCommandInterface.py @@ -20,7 +20,7 @@ from sans.command_interface.batch_csv_file_parser import BatchCsvParser from sans.common.constants import ALL_PERIODS from sans.common.file_information import (find_sans_file, find_full_file_path) from sans.common.enums import (DetectorType, FitType, RangeStepType, ReductionDimensionality, - ISISReductionMode, SANSFacility, SaveType, BatchReductionEntry, OutputMode, FindDirectionEnum) + ReductionMode, SANSFacility, SaveType, BatchReductionEntry, OutputMode, FindDirectionEnum) from sans.common.general_functions import (convert_bank_name_to_detector_type_isis, get_output_name, is_part_of_reduced_output_workspace_group) @@ -172,7 +172,7 @@ def AssignSample(sample_run, reload=True, period=ALL_PERIODS): file_name = find_sans_file(sample_run) # Set the command - data_command = DataCommand(command_id=DataCommandId.sample_scatter, file_name=file_name, period=period) + data_command = DataCommand(command_id=DataCommandId.SAMPLE_SCATTER, file_name=file_name, period=period) director.add_command(data_command) @@ -200,7 +200,7 @@ def AssignCan(can_run, reload=True, period=ALL_PERIODS): file_name = find_sans_file(can_run) # Set the command - data_command = DataCommand(command_id=DataCommandId.can_scatter, file_name=file_name, period=period) + data_command = DataCommand(command_id=DataCommandId.CAN_SCATTER, file_name=file_name, period=period) director.add_command(data_command) @@ -229,9 +229,9 @@ def TransmissionSample(sample, direct, reload=True, direct_file_name = find_sans_file(direct) # Set the command - trans_command = DataCommand(command_id=DataCommandId.sample_transmission, file_name=trans_file_name, + trans_command = DataCommand(command_id=DataCommandId.SAMPLE_TRANSMISSION, file_name=trans_file_name, period=period_t) - direct_command = DataCommand(command_id=DataCommandId.sample_direct, file_name=direct_file_name, 
period=period_d) + direct_command = DataCommand(command_id=DataCommandId.SAMPLE_DIRECT, file_name=direct_file_name, period=period_d) director.add_command(trans_command) director.add_command(direct_command) @@ -259,8 +259,8 @@ def TransmissionCan(can, direct, reload=True, period_t=-1, period_d=-1): direct_file_name = find_sans_file(direct) # Set the command - trans_command = DataCommand(command_id=DataCommandId.can_transmission, file_name=trans_file_name, period=period_t) - direct_command = DataCommand(command_id=DataCommandId.can_direct, file_name=direct_file_name, period=period_d) + trans_command = DataCommand(command_id=DataCommandId.CAN_TRANSMISSION, file_name=trans_file_name, period=period_t) + direct_command = DataCommand(command_id=DataCommandId.CAN_DIRECT, file_name=direct_file_name, period=period_d) director.add_command(trans_command) director.add_command(direct_command) @@ -277,7 +277,7 @@ def Clean(): """ Removes all previous settings. """ - clean_command = NParameterCommand(command_id=NParameterCommandId.clean, values=[]) + clean_command = NParameterCommand(command_id=NParameterCommandId.CLEAN, values=[]) director.add_command(clean_command) @@ -286,8 +286,8 @@ def Set1D(): Sets the reduction dimensionality to 1D """ print_message('Set1D()') - set_1d_command = NParameterCommand(command_id=NParameterCommandId.reduction_dimensionality, - values=[ReductionDimensionality.OneDim]) + set_1d_command = NParameterCommand(command_id=NParameterCommandId.REDUCTION_DIMENSIONALITY, + values=[ReductionDimensionality.ONE_DIM]) director.add_command(set_1d_command) @@ -296,8 +296,8 @@ def Set2D(): Sets the reduction dimensionality to 2D """ print_message('Set2D()') - set_2d_command = NParameterCommand(command_id=NParameterCommandId.reduction_dimensionality, - values=[ReductionDimensionality.TwoDim]) + set_2d_command = NParameterCommand(command_id=NParameterCommandId.REDUCTION_DIMENSIONALITY, + values=[ReductionDimensionality.TWO_DIM]) director.add_command(set_2d_command) 
@@ -305,7 +305,7 @@ def UseCompatibilityMode(): """ Sets the compatibility mode to True """ - set_2d_command = NParameterCommand(command_id=NParameterCommandId.compatibility_mode, + set_2d_command = NParameterCommand(command_id=NParameterCommandId.COMPATIBILITY_MODE, values=[True]) director.add_command(set_2d_command) @@ -323,7 +323,7 @@ def MaskFile(file_name): # Get the full file path file_name_full = find_full_file_path(file_name) - user_file_command = NParameterCommand(command_id=NParameterCommandId.user_file, values=[file_name_full]) + user_file_command = NParameterCommand(command_id=NParameterCommandId.USER_FILE, values=[file_name_full]) director.add_command(user_file_command) @@ -334,7 +334,7 @@ def Mask(details): @param details: a string that specifies masking as it would appear in a mask file """ print_message('Mask("' + details + '")') - mask_command = NParameterCommand(command_id=NParameterCommandId.mask, values=[details]) + mask_command = NParameterCommand(command_id=NParameterCommandId.MASK, values=[details]) director.add_command(mask_command) @@ -345,7 +345,7 @@ def SetSampleOffset(value): @param value: the offset in mm """ value = float(value) - sample_offset_command = NParameterCommand(command_id=NParameterCommandId.sample_offset, values=[value]) + sample_offset_command = NParameterCommand(command_id=NParameterCommandId.SAMPLE_OFFSET, values=[value]) director.add_command(sample_offset_command) @@ -359,8 +359,8 @@ def Detector(det_name): """ print_message('Detector("' + det_name + '")') detector_type = convert_bank_name_to_detector_type_isis(det_name) - reduction_mode = ISISReductionMode.HAB if detector_type is DetectorType.HAB else ISISReductionMode.LAB - detector_command = NParameterCommand(command_id=NParameterCommandId.detector, values=[reduction_mode]) + reduction_mode = ReductionMode.HAB if detector_type is DetectorType.HAB else ReductionMode.LAB + detector_command = NParameterCommand(command_id=NParameterCommandId.DETECTOR, 
values=[reduction_mode]) director.add_command(detector_command) @@ -368,7 +368,7 @@ def SetEventSlices(input_str): """ Sets the events slices """ - event_slices_command = NParameterCommand(command_id=NParameterCommandId.event_slices, values=input_str) + event_slices_command = NParameterCommand(command_id=NParameterCommandId.EVENT_SLICES, values=input_str) director.add_command(event_slices_command) @@ -384,7 +384,7 @@ def SetMonitorSpectrum(specNum, interp=False): """ specNum = int(specNum) is_trans = False - monitor_spectrum_command = NParameterCommand(command_id=NParameterCommandId.incident_spectrum, values=[specNum, + monitor_spectrum_command = NParameterCommand(command_id=NParameterCommandId.INCIDENT_SPECTRUM, values=[specNum, interp, is_trans]) director.add_command(monitor_spectrum_command) @@ -400,7 +400,7 @@ def SetTransSpectrum(specNum, interp=False): """ specNum = int(specNum) is_trans = True - transmission_spectrum_command = NParameterCommand(command_id=NParameterCommandId.incident_spectrum, + transmission_spectrum_command = NParameterCommand(command_id=NParameterCommandId.INCIDENT_SPECTRUM, values=[specNum, interp, is_trans]) director.add_command(transmission_spectrum_command) @@ -414,7 +414,7 @@ def Gravity(flag, extra_length=0.0): """ extra_length = float(extra_length) print_message('Gravity(' + str(flag) + ', ' + str(extra_length) + ')') - gravity_command = NParameterCommand(command_id=NParameterCommandId.gravity, values=[flag, extra_length]) + gravity_command = NParameterCommand(command_id=NParameterCommandId.GRAVITY, values=[flag, extra_length]) director.add_command(gravity_command) @@ -427,7 +427,7 @@ def SetDetectorFloodFile(filename, detector_name="REAR"): """ file_name = find_full_file_path(filename) detector_name = convert_bank_name_to_detector_type_isis(detector_name) - flood_command = NParameterCommand(command_id=NParameterCommandId.flood_file, values=[file_name, detector_name]) + flood_command = 
NParameterCommand(command_id=NParameterCommandId.FLOOD_FILE, values=[file_name, detector_name]) director.add_command(flood_command) @@ -445,7 +445,7 @@ def SetCorrectionFile(bank, filename): print_message("SetCorrectionFile(" + str(bank) + ', ' + filename + ')') detector_type = convert_bank_name_to_detector_type_isis(bank) file_name = find_full_file_path(filename) - flood_command = NParameterCommand(command_id=NParameterCommandId.wavelength_correction_file, + flood_command = NParameterCommand(command_id=NParameterCommandId.WAVELENGTH_CORRECTION_FILE, values=[file_name, detector_type]) director.add_command(flood_command) @@ -468,7 +468,7 @@ def SetCentre(xcoord, ycoord, bank='rear'): ycoord = float(ycoord) print_message('SetCentre(' + str(xcoord) + ', ' + str(ycoord) + ')') detector_type = convert_bank_name_to_detector_type_isis(bank) - centre_command = NParameterCommand(command_id=NParameterCommandId.centre, values=[xcoord, ycoord, detector_type]) + centre_command = NParameterCommand(command_id=NParameterCommandId.CENTRE, values=[xcoord, ycoord, detector_type]) director.add_command(centre_command) @@ -486,7 +486,7 @@ def SetPhiLimit(phimin, phimax, use_mirror=True): # a beam centre of [0,0,0] makes sense if the detector has been moved such that beam centre is at [0,0,0] phimin = float(phimin) phimax = float(phimax) - centre_command = NParameterCommand(command_id=NParameterCommandId.phi_limit, values=[phimin, phimax, use_mirror]) + centre_command = NParameterCommand(command_id=NParameterCommandId.PHI_LIMIT, values=[phimin, phimax, use_mirror]) director.add_command(centre_command) @@ -497,7 +497,7 @@ def set_save(save_algorithms, save_as_zero_error_free): @param save_algorithms: A list of SaveType enums. @param save_as_zero_error_free: True if a zero error correction should be performed. 
""" - save_command = NParameterCommand(command_id=NParameterCommandId.save, values=[save_algorithms, + save_command = NParameterCommand(command_id=NParameterCommandId.SAVE, values=[save_algorithms, save_as_zero_error_free]) director.add_command(save_command) @@ -531,14 +531,14 @@ def TransFit(mode, lambdamin=None, lambdamax=None, selector='BOTH'): mode = str(mode).strip().upper() if mode == "LINEAR" or mode == "STRAIGHT" or mode == "LIN": - fit_type = FitType.Linear + fit_type = FitType.LINEAR elif mode == "LOGARITHMIC" or mode == "LOG" or mode == "YLOG": - fit_type = FitType.Logarithmic + fit_type = FitType.LOGARITHMIC elif does_pattern_match(polynomial_pattern, mode): - fit_type = FitType.Polynomial + fit_type = FitType.POLYNOMIAL polynomial_order = extract_polynomial_order(mode) else: - fit_type = FitType.NoFit + fit_type = FitType.NO_FIT # Get the selected detector to which the fit settings apply selector = str(selector).strip().upper() @@ -565,7 +565,7 @@ def TransFit(mode, lambdamin=None, lambdamax=None, selector='BOTH'): # Configure fit settings polynomial_order = polynomial_order if polynomial_order is not None else 0 - fit_command = NParameterCommand(command_id=NParameterCommandId.centre, values=[fit_data, lambdamin, lambdamax, + fit_command = NParameterCommand(command_id=NParameterCommandId.CENTRE, values=[fit_data, lambdamin, lambdamax, fit_type, polynomial_order]) director.add_command(fit_command) @@ -586,7 +586,7 @@ def LimitsR(rmin, rmax, quiet=False, reducer=None): print_message('LimitsR(' + str(rmin) + ', ' + str(rmax) + ')', reducer) rmin /= 1000. rmax /= 1000. 
- radius_command = NParameterCommand(command_id=NParameterCommandId.mask_radius, values=[rmin, rmax]) + radius_command = NParameterCommand(command_id=NParameterCommandId.MASK_RADIUS, values=[rmin, rmax]) director.add_command(radius_command) @@ -606,9 +606,9 @@ def LimitsWav(lmin, lmax, step, bin_type): print_message('LimitsWav(' + str(lmin) + ', ' + str(lmax) + ', ' + str(step) + ', ' + bin_type + ')') rebin_string = bin_type.strip().upper() - rebin_type = RangeStepType.Log if rebin_string == "LOGARITHMIC" else RangeStepType.Lin + rebin_type = RangeStepType.LOG if rebin_string == "LOGARITHMIC" else RangeStepType.LIN - wavelength_command = NParameterCommand(command_id=NParameterCommandId.wavelength_limit, + wavelength_command = NParameterCommand(command_id=NParameterCommandId.WAVELENGTH_LIMIT, values=[lmin, lmax, step, rebin_type]) director.add_command(wavelength_command) @@ -628,10 +628,10 @@ def LimitsQXY(qmin, qmax, step, type): print_message('LimitsQXY(' + str(qmin) + ', ' + str(qmax) + ', ' + str(step) + ', ' + str(type) + ')') step_type_string = type.strip().upper() if step_type_string == "LOGARITHMIC" or step_type_string == "LOG": - step_type = RangeStepType.Log + step_type = RangeStepType.LOG else: - step_type = RangeStepType.Lin - qxy_command = NParameterCommand(command_id=NParameterCommandId.qxy_limit, values=[qmin, qmax, step, step_type]) + step_type = RangeStepType.LIN + qxy_command = NParameterCommand(command_id=NParameterCommandId.QXY_LIMIT, values=[qmin, qmax, step, step_type]) director.add_command(qxy_command) @@ -660,7 +660,7 @@ def SetFrontDetRescaleShift(scale=1.0, shift=0.0, fitScale=False, fitShift=False qMax = float(qMax) print_message('Set front detector rescale/shift values to {0} and {1}'.format(scale, shift)) - front_command = NParameterCommand(command_id=NParameterCommandId.front_detector_rescale, values=[scale, shift, + front_command = NParameterCommand(command_id=NParameterCommandId.FRONT_DETECTOR_RESCALE, values=[scale, shift, fitScale, 
fitShift, qMin, qMax]) director.add_command(front_command) @@ -702,7 +702,7 @@ def SetDetectorOffsets(bank, x, y, z, rot, radius, side, xtilt=0.0, ytilt=0.0): + ',' + str(y) + ',' + str(z) + ',' + str(rot) + ',' + str(radius) + ',' + str(side) + ',' + str(xtilt) + ',' + str(ytilt) + ')') detector_type = convert_bank_name_to_detector_type_isis(bank) - detector_offsets = NParameterCommand(command_id=NParameterCommandId.detector_offsets, values=[detector_type, + detector_offsets = NParameterCommand(command_id=NParameterCommandId.DETECTOR_OFFSETS, values=[detector_type, x, y, z, rot, radius, side, xtilt, ytilt]) @@ -713,7 +713,7 @@ def SetDetectorOffsets(bank, x, y, z, rot, radius, side, xtilt=0.0, ytilt=0.0): # Commands which actually kick off a reduction # -------------------------------------------- def WavRangeReduction(wav_start=None, wav_end=None, full_trans_wav=None, name_suffix=None, combineDet=None, - resetSetup=True, out_fit_settings=None, output_name=None, output_mode=OutputMode.PublishToADS, + resetSetup=True, out_fit_settings=None, output_name=None, output_mode=OutputMode.PUBLISH_TO_ADS, use_reduction_mode_as_suffix=False): """ Run reduction from loading the raw data to calculating Q. Its optional arguments allows specifics @@ -753,13 +753,13 @@ def WavRangeReduction(wav_start=None, wav_end=None, full_trans_wav=None, name_su if combineDet is None: reduction_mode = None elif combineDet == 'rear': - reduction_mode = ISISReductionMode.LAB + reduction_mode = ReductionMode.LAB elif combineDet == "front": - reduction_mode = ISISReductionMode.HAB + reduction_mode = ReductionMode.HAB elif combineDet == "merged": - reduction_mode = ISISReductionMode.Merged + reduction_mode = ReductionMode.MERGED elif combineDet == "both": - reduction_mode = ISISReductionMode.All + reduction_mode = ReductionMode.ALL else: raise RuntimeError("WavRangeReduction: The combineDet input parameter was given a value of {0}. 
rear, front," " both, merged and no input are allowed".format(combineDet)) @@ -770,19 +770,19 @@ def WavRangeReduction(wav_start=None, wav_end=None, full_trans_wav=None, name_su if wav_end is not None: wav_end = float(wav_end) - wavelength_command = NParameterCommand(command_id=NParameterCommandId.wavrange_settings, + wavelength_command = NParameterCommand(command_id=NParameterCommandId.WAV_RANGE_SETTINGS, values=[wav_start, wav_end, full_trans_wav, reduction_mode]) director.add_command(wavelength_command) # Save options if output_name is not None: - director.add_command(NParameterCommand(command_id=NParameterCommandId.user_specified_output_name, + director.add_command(NParameterCommand(command_id=NParameterCommandId.USER_SPECIFIED_OUTPUT_NAME, values=[output_name])) if name_suffix is not None: - director.add_command(NParameterCommand(command_id=NParameterCommandId.user_specified_output_name_suffix, + director.add_command(NParameterCommand(command_id=NParameterCommandId.USER_SPECIFIED_OUTPUT_NAME_SUFFIX, values=[name_suffix])) if use_reduction_mode_as_suffix: - director.add_command(NParameterCommand(command_id=NParameterCommandId.use_reduction_mode_as_suffix, + director.add_command(NParameterCommand(command_id=NParameterCommandId.USE_REDUCTION_MODE_AS_SUFFIX, values=[use_reduction_mode_as_suffix])) # Get the states @@ -797,11 +797,11 @@ def WavRangeReduction(wav_start=None, wav_end=None, full_trans_wav=None, name_su # ----------------------------------------------------------- reduction_mode = state.reduction.reduction_mode is_group = is_part_of_reduced_output_workspace_group(state) - if reduction_mode != ISISReductionMode.All: + if reduction_mode != ReductionMode.ALL: _, output_workspace_base_name = get_output_name(state, reduction_mode, is_group) else: - _, output_workspace_base_name_hab = get_output_name(state, ISISReductionMode.HAB, is_group) - _, output_workspace_base_name_lab = get_output_name(state, ISISReductionMode.LAB, is_group) + _, 
output_workspace_base_name_hab = get_output_name(state, ReductionMode.HAB, is_group) + _, output_workspace_base_name_lab = get_output_name(state, ReductionMode.LAB, is_group) output_workspace_base_name = [output_workspace_base_name_lab, output_workspace_base_name_hab] return output_workspace_base_name @@ -842,20 +842,20 @@ def BatchReduce(filename, format, plotresults=False, saveAlgs=None, verbose=Fals if key == "SaveRKH": save_algs.append(SaveType.RKH) elif key == "SaveNexus": - save_algs.append(SaveType.Nexus) + save_algs.append(SaveType.NEXUS) elif key == "SaveNistQxy": - save_algs.append(SaveType.NistQxy) + save_algs.append(SaveType.NIST_QXY) elif key == "SaveCanSAS" or key == "SaveCanSAS1D": - save_algs.append(SaveType.CanSAS) + save_algs.append(SaveType.CAN_SAS) elif key == "SaveCSV": save_algs.append(SaveType.CSV) elif key == "SaveNXcanSAS": - save_algs.append(SaveType.NXcanSAS) + save_algs.append(SaveType.NX_CAN_SAS) else: raise RuntimeError("The save format {0} is not known.".format(key)) - output_mode = OutputMode.Both + output_mode = OutputMode.BOTH else: - output_mode = OutputMode.PublishToADS + output_mode = OutputMode.PUBLISH_TO_ADS # Get the information from the csv file batch_csv_parser = BatchCsvParser(filename) @@ -865,43 +865,43 @@ def BatchReduce(filename, format, plotresults=False, saveAlgs=None, verbose=Fals for parsed_batch_entry in parsed_batch_entries: # A new user file. If a new user file is provided then this will overwrite all other settings from, # otherwise we might have cross-talk between user files. 
- if BatchReductionEntry.UserFile in list(parsed_batch_entry.keys()): - user_file = parsed_batch_entry[BatchReductionEntry.UserFile] + if BatchReductionEntry.USER_FILE in list(parsed_batch_entry.keys()): + user_file = parsed_batch_entry[BatchReductionEntry.USER_FILE] MaskFile(user_file) # Sample scatter - sample_scatter = parsed_batch_entry[BatchReductionEntry.SampleScatter] - sample_scatter_period = parsed_batch_entry[BatchReductionEntry.SampleScatterPeriod] + sample_scatter = parsed_batch_entry[BatchReductionEntry.SAMPLE_SCATTER] + sample_scatter_period = parsed_batch_entry[BatchReductionEntry.SAMPLE_SCATTER_PERIOD] AssignSample(sample_run=sample_scatter, period=sample_scatter_period) # Sample transmission - if (BatchReductionEntry.SampleTransmission in list(parsed_batch_entry.keys()) and - BatchReductionEntry.SampleDirect in list(parsed_batch_entry.keys())): - sample_transmission = parsed_batch_entry[BatchReductionEntry.SampleTransmission] - sample_transmission_period = parsed_batch_entry[BatchReductionEntry.SampleTransmissionPeriod] - sample_direct = parsed_batch_entry[BatchReductionEntry.SampleDirect] - sample_direct_period = parsed_batch_entry[BatchReductionEntry.SampleDirectPeriod] + if (BatchReductionEntry.SAMPLE_TRANSMISSION in list(parsed_batch_entry.keys()) and + BatchReductionEntry.SAMPLE_DIRECT in list(parsed_batch_entry.keys())): + sample_transmission = parsed_batch_entry[BatchReductionEntry.SAMPLE_TRANSMISSION] + sample_transmission_period = parsed_batch_entry[BatchReductionEntry.SAMPLE_TRANSMISSION_PERIOD] + sample_direct = parsed_batch_entry[BatchReductionEntry.SAMPLE_DIRECT] + sample_direct_period = parsed_batch_entry[BatchReductionEntry.SAMPLE_DIRECT_PERIOD] TransmissionSample(sample=sample_transmission, direct=sample_direct, period_t=sample_transmission_period, period_d=sample_direct_period) # Can scatter - if BatchReductionEntry.CanScatter in list(parsed_batch_entry.keys()): - can_scatter = parsed_batch_entry[BatchReductionEntry.CanScatter] - 
can_scatter_period = parsed_batch_entry[BatchReductionEntry.CanScatterPeriod] + if BatchReductionEntry.CAN_SCATTER in list(parsed_batch_entry.keys()): + can_scatter = parsed_batch_entry[BatchReductionEntry.CAN_SCATTER] + can_scatter_period = parsed_batch_entry[BatchReductionEntry.CAN_SCATTER_PERIOD] AssignCan(can_run=can_scatter, period=can_scatter_period) # Can transmission - if (BatchReductionEntry.CanTransmission in list(parsed_batch_entry.keys()) and - BatchReductionEntry.CanDirect in list(parsed_batch_entry.keys())): - can_transmission = parsed_batch_entry[BatchReductionEntry.CanTransmission] - can_transmission_period = parsed_batch_entry[BatchReductionEntry.CanTransmissionPeriod] - can_direct = parsed_batch_entry[BatchReductionEntry.CanDirect] - can_direct_period = parsed_batch_entry[BatchReductionEntry.CanDirectPeriod] + if (BatchReductionEntry.CAN_TRANSMISSION in list(parsed_batch_entry.keys()) and + BatchReductionEntry.CAN_DIRECT in list(parsed_batch_entry.keys())): + can_transmission = parsed_batch_entry[BatchReductionEntry.CAN_TRANSMISSION] + can_transmission_period = parsed_batch_entry[BatchReductionEntry.CAN_TRANSMISSION_PERIOD] + can_direct = parsed_batch_entry[BatchReductionEntry.CAN_DIRECT] + can_direct_period = parsed_batch_entry[BatchReductionEntry.CAN_DIRECT_PERIOD] TransmissionCan(can=can_transmission, direct=can_direct, period_t=can_transmission_period, period_d=can_direct_period) # Name of the output. We need to modify the name according to the setup of the old reduction mechanism - output_name = parsed_batch_entry[BatchReductionEntry.Output] + output_name = parsed_batch_entry[BatchReductionEntry.OUTPUT] # In addition to the output name the user can specify with combineDet an additional suffix (in addition to the # suffix that the user can set already -- was there previously, so we have to provide that) @@ -922,19 +922,19 @@ def BatchReduce(filename, format, plotresults=False, saveAlgs=None, verbose=Fals # 3. 
The last scatter transmission and direct entry (if any were set) # 4. The last can scatter ( if any was set) # 5. The last can transmission and direct entry (if any were set) - if BatchReductionEntry.UserFile in list(parsed_batch_entry.keys()): + if BatchReductionEntry.USER_FILE in list(parsed_batch_entry.keys()): director.remove_last_user_file() director.remove_last_scatter_sample() - if (BatchReductionEntry.SampleTransmission in list(parsed_batch_entry.keys()) and - BatchReductionEntry.SampleDirect in list(parsed_batch_entry.keys())): # noqa + if (BatchReductionEntry.SAMPLE_TRANSMISSION in list(parsed_batch_entry.keys()) and + BatchReductionEntry.SAMPLE_DIRECT in list(parsed_batch_entry.keys())): # noqa director.remove_last_sample_transmission_and_direct() - if BatchReductionEntry.CanScatter in list(parsed_batch_entry.keys()): + if BatchReductionEntry.CAN_SCATTER in list(parsed_batch_entry.keys()): director.remove_last_scatter_can() - if (BatchReductionEntry.CanTransmission in list(parsed_batch_entry.keys()) and - BatchReductionEntry.CanDirect in list(parsed_batch_entry.keys())): + if (BatchReductionEntry.CAN_TRANSMISSION in list(parsed_batch_entry.keys()) and + BatchReductionEntry.CAN_DIRECT in list(parsed_batch_entry.keys())): director.remove_last_can_transmission_and_direct() # Plot the results if that was requested, the flag 1 is from the old version. @@ -1017,7 +1017,7 @@ def PhiRanges(phis, plot=True): def FindBeamCentre(rlow, rupp, MaxIter=10, xstart=None, ystart=None, tolerance=1.251e-4, - find_direction=FindDirectionEnum.All, reduction_method=True): + find_direction=FindDirectionEnum.ALL, reduction_method=True): state = director.process_commands() """ Finds the beam centre position. 
diff --git a/scripts/SANS/sans/command_interface/batch_csv_file_parser.py b/scripts/SANS/sans/command_interface/batch_csv_file_parser.py index d5c5c37faedae7485c56ff596b70432f04ce27fb..78d3c5a4226d30944eedfcb9440ae1ff3f46c479 100644 --- a/scripts/SANS/sans/command_interface/batch_csv_file_parser.py +++ b/scripts/SANS/sans/command_interface/batch_csv_file_parser.py @@ -13,28 +13,28 @@ from sans.common.constants import ALL_PERIODS class BatchCsvParser(object): - batch_file_keywords = {"sample_sans": BatchReductionEntry.SampleScatter, - "output_as": BatchReductionEntry.Output, - "sample_trans": BatchReductionEntry.SampleTransmission, - "sample_direct_beam": BatchReductionEntry.SampleDirect, - "can_sans": BatchReductionEntry.CanScatter, - "can_trans": BatchReductionEntry.CanTransmission, - "can_direct_beam": BatchReductionEntry.CanDirect, - "user_file": BatchReductionEntry.UserFile, - "sample_thickness": BatchReductionEntry.SampleThickness, - "sample_height": BatchReductionEntry.SampleHeight, - "sample_width": BatchReductionEntry.SampleWidth} + batch_file_keywords = {"sample_sans": BatchReductionEntry.SAMPLE_SCATTER, + "output_as": BatchReductionEntry.OUTPUT, + "sample_trans": BatchReductionEntry.SAMPLE_TRANSMISSION, + "sample_direct_beam": BatchReductionEntry.SAMPLE_DIRECT, + "can_sans": BatchReductionEntry.CAN_SCATTER, + "can_trans": BatchReductionEntry.CAN_TRANSMISSION, + "can_direct_beam": BatchReductionEntry.CAN_DIRECT, + "user_file": BatchReductionEntry.USER_FILE, + "sample_thickness": BatchReductionEntry.SAMPLE_THICKNESS, + "sample_height": BatchReductionEntry.SAMPLE_HEIGHT, + "sample_width": BatchReductionEntry.SAMPLE_WIDTH} batch_file_keywords_which_are_dropped = {"background_sans": None, "background_trans": None, "background_direct_beam": None, "": None} - data_keys = {BatchReductionEntry.SampleScatter: BatchReductionEntry.SampleScatterPeriod, - BatchReductionEntry.SampleTransmission: BatchReductionEntry.SampleTransmissionPeriod, - 
BatchReductionEntry.SampleDirect: BatchReductionEntry.SampleDirectPeriod, - BatchReductionEntry.CanScatter: BatchReductionEntry.CanScatterPeriod, - BatchReductionEntry.CanTransmission: BatchReductionEntry.CanTransmissionPeriod, - BatchReductionEntry.CanDirect: BatchReductionEntry.CanDirectPeriod} + data_keys = {BatchReductionEntry.SAMPLE_SCATTER: BatchReductionEntry.SAMPLE_SCATTER_PERIOD, + BatchReductionEntry.SAMPLE_TRANSMISSION: BatchReductionEntry.SAMPLE_TRANSMISSION_PERIOD, + BatchReductionEntry.SAMPLE_DIRECT: BatchReductionEntry.SAMPLE_DIRECT_PERIOD, + BatchReductionEntry.CAN_SCATTER: BatchReductionEntry.CAN_SCATTER_PERIOD, + BatchReductionEntry.CAN_TRANSMISSION: BatchReductionEntry.CAN_TRANSMISSION_PERIOD, + BatchReductionEntry.CAN_DIRECT: BatchReductionEntry.CAN_DIRECT_PERIOD} def __init__(self, batch_file_name): super(BatchCsvParser, self).__init__() @@ -104,13 +104,13 @@ class BatchCsvParser(object): raise RuntimeError("The key {0} is not part of the SANS batch csv file keywords".format(key)) # Ensure that sample_scatter was set - if BatchReductionEntry.SampleScatter not in output or not output[BatchReductionEntry.SampleScatter]: + if BatchReductionEntry.SAMPLE_SCATTER not in output or not output[BatchReductionEntry.SAMPLE_SCATTER]: raise RuntimeError("The sample_scatter entry in row {0} seems to be missing.".format(row_number)) # Ensure that the transmission data for the sample is specified either completely or not at all. 
- has_sample_transmission = BatchReductionEntry.SampleTransmission in output and \ - output[BatchReductionEntry.SampleTransmission] # noqa - has_sample_direct_beam = BatchReductionEntry.SampleDirect in output and output[BatchReductionEntry.SampleDirect] + has_sample_transmission = BatchReductionEntry.SAMPLE_TRANSMISSION in output and \ + output[BatchReductionEntry.SAMPLE_TRANSMISSION] # noqa + has_sample_direct_beam = BatchReductionEntry.SAMPLE_DIRECT in output and output[BatchReductionEntry.SAMPLE_DIRECT] if (not has_sample_transmission and has_sample_direct_beam) or \ (has_sample_transmission and not has_sample_direct_beam): @@ -118,9 +118,9 @@ class BatchCsvParser(object): "and the direct beam run are set or none.".format(row_number)) # Ensure that the transmission data for the can is specified either completely or not at all. - has_can_transmission = BatchReductionEntry.CanTransmission in output and \ - output[BatchReductionEntry.CanTransmission] # noqa - has_can_direct_beam = BatchReductionEntry.CanDirect in output and output[BatchReductionEntry.CanDirect] + has_can_transmission = BatchReductionEntry.CAN_TRANSMISSION in output and \ + output[BatchReductionEntry.CAN_TRANSMISSION] # noqa + has_can_direct_beam = BatchReductionEntry.CAN_DIRECT in output and output[BatchReductionEntry.CAN_DIRECT] if (not has_can_transmission and has_can_direct_beam) or \ (has_can_transmission and not has_can_direct_beam): @@ -128,7 +128,7 @@ class BatchCsvParser(object): "and the direct beam run are set or none.".format(row_number)) # Ensure that can scatter is specified if the transmissions are set - has_can_scatter = BatchReductionEntry.CanScatter in output and output[BatchReductionEntry.CanScatter] + has_can_scatter = BatchReductionEntry.CAN_SCATTER in output and output[BatchReductionEntry.CAN_SCATTER] if not has_can_scatter and has_can_transmission: raise RuntimeError("The can transmission was set but not the scatter file in row {0}.".format(row_number)) return output diff 
--git a/scripts/SANS/sans/command_interface/command_interface_state_director.py b/scripts/SANS/sans/command_interface/command_interface_state_director.py index 6f3417ab680657668419f9ceb2438c082091b0e2..b7daed509497e24442a7153e4739bb8a776b4a76 100644 --- a/scripts/SANS/sans/command_interface/command_interface_state_director.py +++ b/scripts/SANS/sans/command_interface/command_interface_state_director.py @@ -1,20 +1,23 @@ # Mantid Repository : https://github.com/mantidproject/mantid # -# Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI, +# Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, # NScD Oak Ridge National Laboratory, European Spallation Source # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) -from sans.common.enums import (serializable_enum, DataType) -from sans.user_file.state_director import StateDirectorISIS + +from mantid.py3compat import Enum +from sans.common.enums import DataType +from sans.common.file_information import SANSFileInformationFactory from sans.state.data import get_data_builder -from sans.user_file.user_file_parser import (UserFileParser) -from sans.user_file.user_file_reader import (UserFileReader) from sans.user_file.settings_tags import (MonId, monitor_spectrum, OtherId, SampleId, GravityId, SetId, position_entry, fit_general, FitId, monitor_file, mask_angle_entry, LimitsId, range_entry, simple_range, DetectorId, event_binning_string_values, det_fit_range, single_entry_with_detector) -from sans.common.file_information import SANSFileInformationFactory +from sans.user_file.state_director import StateDirectorISIS +from sans.user_file.user_file_parser import (UserFileParser) +from sans.user_file.user_file_reader import (UserFileReader) + # ---------------------------------------------------------------------------------------------------------------------- # Commands @@ -24,26 +27,42 @@ from sans.common.file_information import 
SANSFileInformationFactory # ------------------ # IDs for commands. We use here serializable_enum since enum is not available in the current Python configuration. # ------------------ -@serializable_enum("sample_scatter", "sample_transmission", "sample_direct", "can_scatter", "can_transmission", - "can_direct") -class DataCommandId(object): - pass - - -@serializable_enum("clean", "reduction_dimensionality", "compatibility_mode", # Null Parameter commands - "user_file", "mask", "sample_offset", "detector", "event_slices", # Single parameter commands - "flood_file", "wavelength_correction_file", # Single parameter commands - "user_specified_output_name", "user_specified_output_name_suffix", # Single parameter commands - "use_reduction_mode_as_suffix", # Single parameter commands - "incident_spectrum", "gravity", # Double parameter commands - "centre", "save", # Three parameter commands - "trans_fit", "phi_limit", "mask_radius", "wavelength_limit", "qxy_limit", # Four parameter commands - "wavrange_settings", # Five parameter commands - "front_detector_rescale", # Six parameter commands - "detector_offsets" # Nine parameter commands - ) +class DataCommandId(Enum): + CAN_DIRECT = "can_direct" + CAN_SCATTER = "can_scatter" + CAN_TRANSMISSION = "can_transmission" + + SAMPLE_SCATTER = "sample_scatter" + SAMPLE_TRANSMISSION = "sample_transmission" + SAMPLE_DIRECT = "sample_direct" + + class NParameterCommandId(object): - pass + CENTRE = "centre" + CLEAN = "clean" + COMPATIBILITY_MODE = "compatibility_mode" + DETECTOR = "detector" + DETECTOR_OFFSETS = "detector_offsets" + EVENT_SLICES = "event_slices" + FLOOD_FILE = "flood_file" + FRONT_DETECTOR_RESCALE = "front_detector_rescale" + GRAVITY = "gravity" + INCIDENT_SPECTRUM = "incident_spectrum" + MASK = "mask" + MASK_RADIUS = "mask_radius" + REDUCTION_DIMENSIONALITY = "reduction_dimensionality" + PHI_LIMIT = "phi_limit" + QXY_LIMIT = "qxy_limit" + SAMPLE_OFFSET = "sample_offset" + SAVE = "save" + TRANS_FIT = "trans_fit" + 
USER_FILE = "user_file" + USE_REDUCTION_MODE_AS_SUFFIX = "use_reduction_mode_as_suffix" + USER_SPECIFIED_OUTPUT_NAME = "user_specified_output_name" + USER_SPECIFIED_OUTPUT_NAME_SUFFIX = "user_specified_output_name_suffix" + WAVELENGTH_CORRECTION_FILE = "wavelength_correction_file" + WAVELENGTH_LIMIT = "wavelength_limit" + WAV_RANGE_SETTINGS = "wavrange_settings" class Command(object): @@ -147,7 +166,7 @@ class CommandInterfaceStateDirector(object): def _get_data_state(self): # Get the data commands data_commands = self._get_data_commands() - data_elements = self._get_elements_with_key(DataCommandId.sample_scatter, data_commands) + data_elements = self._get_elements_with_key(DataCommandId.SAMPLE_SCATTER, data_commands) data_element = data_elements[-1] file_name = data_element.file_name file_information_factory = SANSFileInformationFactory() @@ -156,17 +175,17 @@ class CommandInterfaceStateDirector(object): # Build the state data data_builder = get_data_builder(self._facility, file_information) self._set_data_element(data_builder.set_sample_scatter, data_builder.set_sample_scatter_period, - DataCommandId.sample_scatter, data_commands) + DataCommandId.SAMPLE_SCATTER, data_commands) self._set_data_element(data_builder.set_sample_transmission, data_builder.set_sample_transmission_period, - DataCommandId.sample_transmission, data_commands) + DataCommandId.SAMPLE_TRANSMISSION, data_commands) self._set_data_element(data_builder.set_sample_direct, data_builder.set_sample_direct_period, - DataCommandId.sample_direct, data_commands) + DataCommandId.SAMPLE_DIRECT, data_commands) self._set_data_element(data_builder.set_can_scatter, data_builder.set_can_scatter_period, - DataCommandId.can_scatter, data_commands) + DataCommandId.CAN_SCATTER, data_commands) self._set_data_element(data_builder.set_can_transmission, data_builder.set_can_transmission_period, - DataCommandId.can_transmission, data_commands) + DataCommandId.CAN_TRANSMISSION, data_commands) 
self._set_data_element(data_builder.set_can_direct, data_builder.set_can_direct_period, - DataCommandId.can_direct, data_commands) + DataCommandId.CAN_DIRECT, data_commands) return data_builder.build() @@ -253,32 +272,32 @@ class CommandInterfaceStateDirector(object): """ Sets up a mapping between command ids and the adequate processing methods which can handle the command. """ - self._method_map = {NParameterCommandId.user_file: self._process_user_file, - NParameterCommandId.mask: self._process_mask, - NParameterCommandId.incident_spectrum: self._process_incident_spectrum, - NParameterCommandId.clean: self._process_clean, - NParameterCommandId.reduction_dimensionality: self._process_reduction_dimensionality, - NParameterCommandId.sample_offset: self._process_sample_offset, - NParameterCommandId.detector: self._process_detector, - NParameterCommandId.gravity: self._process_gravity, - NParameterCommandId.centre: self._process_centre, - NParameterCommandId.trans_fit: self._process_trans_fit, - NParameterCommandId.front_detector_rescale: self._process_front_detector_rescale, - NParameterCommandId.event_slices: self._process_event_slices, - NParameterCommandId.flood_file: self._process_flood_file, - NParameterCommandId.phi_limit: self._process_phi_limit, - NParameterCommandId.wavelength_correction_file: self._process_wavelength_correction_file, - NParameterCommandId.mask_radius: self._process_mask_radius, - NParameterCommandId.wavelength_limit: self._process_wavelength_limit, - NParameterCommandId.qxy_limit: self._process_qxy_limit, - NParameterCommandId.wavrange_settings: self._process_wavrange, - NParameterCommandId.compatibility_mode: self._process_compatibility_mode, - NParameterCommandId.detector_offsets: self._process_detector_offsets, - NParameterCommandId.save: self._process_save, - NParameterCommandId.user_specified_output_name: self._process_user_specified_output_name, - NParameterCommandId.user_specified_output_name_suffix: + self._method_map = 
{NParameterCommandId.USER_FILE: self._process_user_file, + NParameterCommandId.MASK: self._process_mask, + NParameterCommandId.INCIDENT_SPECTRUM: self._process_incident_spectrum, + NParameterCommandId.CLEAN: self._process_clean, + NParameterCommandId.REDUCTION_DIMENSIONALITY: self._process_reduction_dimensionality, + NParameterCommandId.SAMPLE_OFFSET: self._process_sample_offset, + NParameterCommandId.DETECTOR: self._process_detector, + NParameterCommandId.GRAVITY: self._process_gravity, + NParameterCommandId.CENTRE: self._process_centre, + NParameterCommandId.TRANS_FIT: self._process_trans_fit, + NParameterCommandId.FRONT_DETECTOR_RESCALE: self._process_front_detector_rescale, + NParameterCommandId.EVENT_SLICES: self._process_event_slices, + NParameterCommandId.FLOOD_FILE: self._process_flood_file, + NParameterCommandId.PHI_LIMIT: self._process_phi_limit, + NParameterCommandId.WAVELENGTH_CORRECTION_FILE: self._process_wavelength_correction_file, + NParameterCommandId.MASK_RADIUS: self._process_mask_radius, + NParameterCommandId.WAVELENGTH_LIMIT: self._process_wavelength_limit, + NParameterCommandId.QXY_LIMIT: self._process_qxy_limit, + NParameterCommandId.WAV_RANGE_SETTINGS: self._process_wavrange, + NParameterCommandId.COMPATIBILITY_MODE: self._process_compatibility_mode, + NParameterCommandId.DETECTOR_OFFSETS: self._process_detector_offsets, + NParameterCommandId.SAVE: self._process_save, + NParameterCommandId.USER_SPECIFIED_OUTPUT_NAME: self._process_user_specified_output_name, + NParameterCommandId.USER_SPECIFIED_OUTPUT_NAME_SUFFIX: self._process_user_specified_output_name_suffix, - NParameterCommandId.use_reduction_mode_as_suffix: + NParameterCommandId.USE_REDUCTION_MODE_AS_SUFFIX: self._process_use_reduction_mode_as_suffix } @@ -351,7 +370,7 @@ class CommandInterfaceStateDirector(object): incident_monitor = command.values[0] interpolate = command.values[1] is_trans = command.values[2] - new_state_entries = {MonId.spectrum: 
monitor_spectrum(spectrum=incident_monitor, + new_state_entries = {MonId.SPECTRUM: monitor_spectrum(spectrum=incident_monitor, is_trans=is_trans, interpolate=interpolate)} self.add_to_processed_state_settings(new_state_entries) @@ -365,7 +384,7 @@ class CommandInterfaceStateDirector(object): index_first_clean_command = None for index in reversed(list(range(0, len(self._commands)))): element = self._commands[index] - if element.command_id == NParameterCommandId.clean: + if element.command_id == NParameterCommandId.CLEAN: index_first_clean_command = index break if index_first_clean_command is not None: @@ -380,36 +399,36 @@ class CommandInterfaceStateDirector(object): def _process_reduction_dimensionality(self, command): _ = command # noqa reduction_dimensionality = command.values[0] - new_state_entries = {OtherId.reduction_dimensionality: reduction_dimensionality} + new_state_entries = {OtherId.REDUCTION_DIMENSIONALITY: reduction_dimensionality} self.add_to_processed_state_settings(new_state_entries) def _process_sample_offset(self, command): sample_offset = command.values[0] - new_state_entries = {SampleId.offset: sample_offset} + new_state_entries = {SampleId.OFFSET: sample_offset} self.add_to_processed_state_settings(new_state_entries) def _process_detector(self, command): reduction_mode = command.values[0] - new_state_entries = {DetectorId.reduction_mode: reduction_mode} + new_state_entries = {DetectorId.REDUCTION_MODE: reduction_mode} self.add_to_processed_state_settings(new_state_entries) def _process_gravity(self, command): use_gravity = command.values[0] extra_length = command.values[1] - new_state_entries = {GravityId.on_off: use_gravity, - GravityId.extra_length: extra_length} + new_state_entries = {GravityId.ON_OFF: use_gravity, + GravityId.EXTRA_LENGTH: extra_length} self.add_to_processed_state_settings(new_state_entries) def _process_centre(self, command): pos1 = command.values[0] pos2 = command.values[1] detector_type = command.values[2] - 
new_state_entries = {SetId.centre: position_entry(pos1=pos1, pos2=pos2, detector_type=detector_type)} + new_state_entries = {SetId.CENTRE: position_entry(pos1=pos1, pos2=pos2, detector_type=detector_type)} self.add_to_processed_state_settings(new_state_entries) def _process_trans_fit(self, command): def fit_type_to_data_type(fit_type_to_convert): - return DataType.Can if fit_type_to_convert is FitData.Can else DataType.Sample + return DataType.CAN if fit_type_to_convert is FitData.Can else DataType.SAMPLE fit_data = command.values[0] wavelength_low = command.values[1] @@ -424,7 +443,7 @@ class CommandInterfaceStateDirector(object): new_state_entries = {} for element in data_to_fit: data_type = fit_type_to_data_type(element) - new_state_entries.update({FitId.general: fit_general(start=wavelength_low, stop=wavelength_high, + new_state_entries.update({FitId.GENERAL: fit_general(start=wavelength_low, stop=wavelength_high, fit_type=fit_type, data_type=data_type, polynomial_order=polynomial_order)}) self.add_to_processed_state_settings(new_state_entries) @@ -438,44 +457,44 @@ class CommandInterfaceStateDirector(object): q_max = command.values[5] # Set the scale and the shift - new_state_entries = {DetectorId.rescale: scale, DetectorId.shift: shift} + new_state_entries = {DetectorId.RESCALE: scale, DetectorId.SHIFT: shift} # Set the fit for the scale - new_state_entries.update({DetectorId.rescale_fit: det_fit_range(start=q_min, stop=q_max, use_fit=fit_scale)}) + new_state_entries.update({DetectorId.RESCALE_FIT: det_fit_range(start=q_min, stop=q_max, use_fit=fit_scale)}) # Set the fit for shift - new_state_entries.update({DetectorId.shift_fit: det_fit_range(start=q_min, stop=q_max, use_fit=fit_shift)}) + new_state_entries.update({DetectorId.SHIFT_FIT: det_fit_range(start=q_min, stop=q_max, use_fit=fit_shift)}) self.add_to_processed_state_settings(new_state_entries) def _process_event_slices(self, command): event_slice_value = command.values - new_state_entries = 
{OtherId.event_slices: event_binning_string_values(value=event_slice_value)} + new_state_entries = {OtherId.EVENT_SLICES: event_binning_string_values(value=event_slice_value)} self.add_to_processed_state_settings(new_state_entries) def _process_flood_file(self, command): file_path = command.values[0] detector_type = command.values[1] - new_state_entries = {MonId.flat: monitor_file(file_path=file_path, detector_type=detector_type)} + new_state_entries = {MonId.FLAT: monitor_file(file_path=file_path, detector_type=detector_type)} self.add_to_processed_state_settings(new_state_entries) def _process_phi_limit(self, command): phi_min = command.values[0] phi_max = command.values[1] use_phi_mirror = command.values[2] - new_state_entries = {LimitsId.angle: mask_angle_entry(min=phi_min, max=phi_max, use_mirror=use_phi_mirror)} + new_state_entries = {LimitsId.ANGLE: mask_angle_entry(min=phi_min, max=phi_max, use_mirror=use_phi_mirror)} self.add_to_processed_state_settings(new_state_entries) def _process_wavelength_correction_file(self, command): file_path = command.values[0] detector_type = command.values[1] - new_state_entries = {MonId.direct: monitor_file(file_path=file_path, detector_type=detector_type)} + new_state_entries = {MonId.DIRECT: monitor_file(file_path=file_path, detector_type=detector_type)} self.add_to_processed_state_settings(new_state_entries) def _process_mask_radius(self, command): radius_min = command.values[0] radius_max = command.values[1] - new_state_entries = {LimitsId.radius: range_entry(start=radius_min, stop=radius_max)} + new_state_entries = {LimitsId.RADIUS: range_entry(start=radius_min, stop=radius_max)} self.add_to_processed_state_settings(new_state_entries) def _process_wavelength_limit(self, command): @@ -483,7 +502,7 @@ class CommandInterfaceStateDirector(object): wavelength_high = command.values[1] wavelength_step = command.values[2] wavelength_step_type = command.values[3] - new_state_entries = {LimitsId.wavelength: 
simple_range(start=wavelength_low, stop=wavelength_high, + new_state_entries = {LimitsId.WAVELENGTH: simple_range(start=wavelength_low, stop=wavelength_high, step=wavelength_step, step_type=wavelength_step_type)} self.add_to_processed_state_settings(new_state_entries) @@ -498,8 +517,8 @@ class CommandInterfaceStateDirector(object): # is not nice but the command interface forces us to do so. We take a copy of the last LimitsId.wavelength # entry, we copy it and then change the desired settings. This means it has to be set at this point, else # something is wrong - if LimitsId.wavelength in self._processed_state_settings: - last_entry = self._processed_state_settings[LimitsId.wavelength][-1] + if LimitsId.WAVELENGTH in self._processed_state_settings: + last_entry = self._processed_state_settings[LimitsId.WAVELENGTH][-1] new_wavelength_low = wavelength_low if wavelength_low is not None else last_entry.start new_wavelength_high = wavelength_high if wavelength_high is not None else last_entry.stop @@ -507,18 +526,18 @@ class CommandInterfaceStateDirector(object): step_type=last_entry.step_type) if wavelength_low is not None or wavelength_high is not None: - copied_entry = {LimitsId.wavelength: new_range} + copied_entry = {LimitsId.WAVELENGTH: new_range} self.add_to_processed_state_settings(copied_entry) else: raise RuntimeError("CommandInterfaceStateDirector: Setting the lower and upper wavelength bounds is not" " possible. 
We require also a step and step range") if full_wavelength_range is not None: - full_wavelength_range_entry = {OtherId.use_full_wavelength_range: full_wavelength_range} + full_wavelength_range_entry = {OtherId.USE_FULL_WAVELENGTH_RANGE: full_wavelength_range} self.add_to_processed_state_settings(full_wavelength_range_entry) if reduction_mode is not None: - reduction_mode_entry = {DetectorId.reduction_mode: reduction_mode} + reduction_mode_entry = {DetectorId.REDUCTION_MODE: reduction_mode} self.add_to_processed_state_settings(reduction_mode_entry) def _process_qxy_limit(self, command): @@ -526,12 +545,12 @@ class CommandInterfaceStateDirector(object): q_max = command.values[1] q_step = command.values[2] q_step_type = command.values[3] - new_state_entries = {LimitsId.qxy: simple_range(start=q_min, stop=q_max, step=q_step, step_type=q_step_type)} + new_state_entries = {LimitsId.QXY: simple_range(start=q_min, stop=q_max, step=q_step, step_type=q_step_type)} self.add_to_processed_state_settings(new_state_entries) def _process_compatibility_mode(self, command): use_compatibility_mode = command.values[0] - new_state_entries = {OtherId.use_compatibility_mode: use_compatibility_mode} + new_state_entries = {OtherId.USE_COMPATIBILITY_MODE: use_compatibility_mode} self.add_to_processed_state_settings(new_state_entries) def _process_detector_offsets(self, command): @@ -546,18 +565,18 @@ class CommandInterfaceStateDirector(object): y_tilt = command.values[8] # Set the offsets - new_state_entries = {DetectorId.correction_x: single_entry_with_detector(entry=x, detector_type=detector_type), - DetectorId.correction_y: single_entry_with_detector(entry=y, detector_type=detector_type), - DetectorId.correction_z: single_entry_with_detector(entry=z, detector_type=detector_type), - DetectorId.correction_rotation: + new_state_entries = {DetectorId.CORRECTION_X: single_entry_with_detector(entry=x, detector_type=detector_type), + DetectorId.CORRECTION_Y: single_entry_with_detector(entry=y, 
detector_type=detector_type), + DetectorId.CORRECTION_Z: single_entry_with_detector(entry=z, detector_type=detector_type), + DetectorId.CORRECTION_ROTATION: single_entry_with_detector(entry=rotation, detector_type=detector_type), - DetectorId.correction_radius: + DetectorId.CORRECTION_RADIUS: single_entry_with_detector(entry=radius, detector_type=detector_type), - DetectorId.correction_translation: + DetectorId.CORRECTION_TRANSLATION: single_entry_with_detector(entry=side, detector_type=detector_type), - DetectorId.correction_x_tilt: + DetectorId.CORRECTION_X_TILT: single_entry_with_detector(entry=x_tilt, detector_type=detector_type), - DetectorId.correction_y_tilt: + DetectorId.CORRECTION_Y_TILT: single_entry_with_detector(entry=y_tilt, detector_type=detector_type), } self.add_to_processed_state_settings(new_state_entries) @@ -565,23 +584,23 @@ class CommandInterfaceStateDirector(object): def _process_save(self, command): save_algorithms = command.values[0] save_as_zero_error_free = command.values[1] - new_state_entries = {OtherId.save_types: save_algorithms, - OtherId.save_as_zero_error_free: save_as_zero_error_free} + new_state_entries = {OtherId.SAVE_TYPES: save_algorithms, + OtherId.SAVE_AS_ZERO_ERROR_FREE: save_as_zero_error_free} self.add_to_processed_state_settings(new_state_entries, treat_list_as_element=True) def _process_user_specified_output_name(self, command): user_specified_output_name = command.values[0] - new_state_entry = {OtherId.user_specified_output_name: user_specified_output_name} + new_state_entry = {OtherId.USER_SPECIFIED_OUTPUT_NAME: user_specified_output_name} self.add_to_processed_state_settings(new_state_entry) def _process_user_specified_output_name_suffix(self, command): user_specified_output_name_suffix = command.values[0] - new_state_entry = {OtherId.user_specified_output_name_suffix: user_specified_output_name_suffix} + new_state_entry = {OtherId.USER_SPECIFIED_OUTPUT_NAME_SUFFIX: user_specified_output_name_suffix} 
self.add_to_processed_state_settings(new_state_entry) def _process_use_reduction_mode_as_suffix(self, command): use_reduction_mode_as_suffix = command.values[0] - new_state_entry = {OtherId.use_reduction_mode_as_suffix: use_reduction_mode_as_suffix} + new_state_entry = {OtherId.USE_REDUCTION_MODE_AS_SUFFIX: use_reduction_mode_as_suffix} self.add_to_processed_state_settings(new_state_entry) def remove_last_user_file(self): @@ -590,7 +609,7 @@ class CommandInterfaceStateDirector(object): See _remove_last_element for further explanation. """ - self._remove_last_element(NParameterCommandId.user_file) + self._remove_last_element(NParameterCommandId.USER_FILE) def remove_last_scatter_sample(self): """ @@ -598,7 +617,7 @@ class CommandInterfaceStateDirector(object): See _remove_last_element for further explanation. """ - self._remove_last_element(DataCommandId.sample_scatter) + self._remove_last_element(DataCommandId.SAMPLE_SCATTER) def remove_last_sample_transmission_and_direct(self): """ @@ -606,8 +625,8 @@ class CommandInterfaceStateDirector(object): See _remove_last_element for further explanation. """ - self._remove_last_element(DataCommandId.sample_transmission) - self._remove_last_element(DataCommandId.sample_direct) + self._remove_last_element(DataCommandId.SAMPLE_TRANSMISSION) + self._remove_last_element(DataCommandId.SAMPLE_DIRECT) def remove_last_scatter_can(self): """ @@ -615,7 +634,7 @@ class CommandInterfaceStateDirector(object): See _remove_last_element for further explanation. """ - self._remove_last_element(DataCommandId.can_scatter) + self._remove_last_element(DataCommandId.CAN_SCATTER) def remove_last_can_transmission_and_direct(self): """ @@ -623,8 +642,8 @@ class CommandInterfaceStateDirector(object): See _remove_last_element for further explanation. 
""" - self._remove_last_element(DataCommandId.can_transmission) - self._remove_last_element(DataCommandId.can_direct) + self._remove_last_element(DataCommandId.CAN_TRANSMISSION) + self._remove_last_element(DataCommandId.CAN_DIRECT) def _remove_last_element(self, command_id): """ diff --git a/scripts/SANS/sans/common/constant_containers.py b/scripts/SANS/sans/common/constant_containers.py index 3c67f56aa739db446ab96556784af9babd751115..1a64b2d64c49eb893abd62132b2fbaf1bbbc6c80 100644 --- a/scripts/SANS/sans/common/constant_containers.py +++ b/scripts/SANS/sans/common/constant_containers.py @@ -17,7 +17,8 @@ SANSInstrument_string_as_key = {LARMOR: SANSInstrument.LARMOR, # Include NoInstrument in a dict SANSInstrument_string_as_key_NoInstrument = SANSInstrument_string_as_key.copy() -SANSInstrument_string_as_key_NoInstrument.update({"NoInstrument": SANSInstrument.NoInstrument}) +SANSInstrument_string_as_key_NoInstrument.update({"NoInstrument": SANSInstrument.NO_INSTRUMENT}) +SANSInstrument_string_as_key_NoInstrument.update({"No Instrument": SANSInstrument.NO_INSTRUMENT}) SANSInstrument_enum_as_key = {SANSInstrument.LARMOR: LARMOR, SANSInstrument.LOQ: LOQ, @@ -25,7 +26,7 @@ SANSInstrument_enum_as_key = {SANSInstrument.LARMOR: LARMOR, SANSInstrument.ZOOM: ZOOM} SANSInstrument_enum_as_key_NoInstrument = SANSInstrument_enum_as_key.copy() -SANSInstrument_enum_as_key_NoInstrument.update({SANSInstrument.NoInstrument: "NoInstrument"}) +SANSInstrument_enum_as_key_NoInstrument.update({SANSInstrument.NO_INSTRUMENT: "NoInstrument"}) SANSInstrument_string_list = [LARMOR, LOQ, SANS2D, ZOOM] SANSInstrument_enum_list = [SANSInstrument.LARMOR, SANSInstrument.LOQ, SANSInstrument.SANS2D, SANSInstrument.ZOOM] diff --git a/scripts/SANS/sans/common/enums.py b/scripts/SANS/sans/common/enums.py index 723771b6085aea25f1a67192b1b69dfa4b1785df..7e690c32e482487415617e715c5749e89d253c0c 100644 --- a/scripts/SANS/sans/common/enums.py +++ b/scripts/SANS/sans/common/enums.py @@ -1,415 +1,248 @@ # 
Mantid Repository : https://github.com/mantidproject/mantid # -# Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI, +# Copyright © 2019 ISIS Rutherford Appleton Laboratory UKRI, # NScD Oak Ridge National Laboratory, European Spallation Source # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + """ The elements of this module define typed enums which are used in the SANS reduction framework.""" -# pylint: disable=too-few-public-methods, invalid-name - from __future__ import (absolute_import, division, print_function) -from inspect import isclass -from functools import partial -from six import PY2 - - -# ---------------------------------------------------------------------------------------------------------------------- -# Serializable Enum decorator -# ---------------------------------------------------------------------------------------------------------------------- -def serializable_enum(*inner_classes): - """ - Class decorator which changes the name of an inner class to include the name of the outer class. The inner class - gets a method to determine the name of the outer class. This information is needed for serialization at the - algorithm input boundary. - """ - def inner_class_builder(cls): - # Add each inner class to the outer class - for inner_class in inner_classes: - new_class = type(inner_class, (cls, ), {"outer_class_name": cls.__name__}) - # We set the module of the inner class to the module of the outer class. We have to do this since we - # are dynamically adding the inner class which gets its module name from the module where it was added, - # but not where the outer class lives. 
- module_of_outer_class = getattr(cls, "__module__") - setattr(new_class, "__module__", module_of_outer_class) - # Add the inner class to the outer class - setattr(cls, inner_class, new_class) - return cls - return inner_class_builder - - -# ---------------------------------------------------------------------------------------------------------------------- -# String conversion decorator -# ---------------------------------------------------------------------------------------------------------------------- -def string_convertible(cls): - """ - Class decorator to make the enum/sub-class entries string convertible. - - We do this by creating a static from_string and to_string method on the class. - IMPORTANT: It is important that the enum values are added to the class before applying this decorator. In general - the order has to be: - @string_convertible - @serializable_enum - class MyClass(object): - ... - :param cls: a reference to the class - :return: the class - """ - def to_string(elements, convert_to_string): - for key, value in list(elements.items()): - if convert_to_string is value: - return key - raise RuntimeError("Could not convert {0} to string. Unknown value.".format(convert_to_string)) - - def from_string(elements, convert_from_string): - if not PY2 and isinstance(convert_from_string, bytes): - convert_from_string = convert_from_string.decode() - for key, value in list(elements.items()): - if convert_from_string == key: - return value - raise RuntimeError("Could not convert {0} from string. 
Unknown value.".format(convert_from_string)) - - def has_member(elements, convert): - if not PY2 and isinstance(convert, bytes): - convert = convert.decode() - for key, value in list(elements.items()): - if convert == key or convert == value: - return True - return False - - # First get all enum/sub-class elements - convertible_elements = {} - for attribute_name, attribute_value in list(cls.__dict__.items()): - if isclass(attribute_value) and issubclass(attribute_value, cls): - convertible_elements.update({attribute_name: attribute_value}) - - # Add the new static methods to the class - partial_to_string = partial(to_string, convertible_elements) - partial_from_string = partial(from_string, convertible_elements) - partial_has_member = partial(has_member, convertible_elements) - setattr(cls, "to_string", staticmethod(partial_to_string)) - setattr(cls, "from_string", staticmethod(partial_from_string)) - setattr(cls, "has_member", staticmethod(partial_has_member)) - return cls - - -# -------------------------------- -# Instrument and facility types -# -------------------------------- -@string_convertible -@serializable_enum("LOQ", "LARMOR", "SANS2D", "ZOOM", "NoInstrument") -class SANSInstrument(object): - pass - - -@string_convertible -@serializable_enum("ISIS", "NoFacility") -class SANSFacility(object): - pass - - -# ------------------------------------ -# Data Types -# ------------------------------------ -@string_convertible -@serializable_enum("SampleScatter", "SampleTransmission", "SampleDirect", "CanScatter", "CanTransmission", "CanDirect", - "Calibration") -class SANSDataType(object): - """ - Defines the different data types which are required for the reduction. Besides the fundamental data of the - sample and the can, we can also specify a calibration. 
- """ - pass +from mantid.py3compat import Enum -# --------------------------- -# Coordinate Definitions (3D) -# -------------------------- -class Coordinates(object): - pass +class SANSInstrument(Enum): + NO_INSTRUMENT = "No Instrument" + LARMOR = "LARMOR" + LOQ = "LOQ" + SANS2D = "SANS2D" + ZOOM = "ZOOM" -@serializable_enum("X", "Y", "Z") -class CanonicalCoordinates(Coordinates): - pass +class SANSFacility(Enum): + NO_FACILITY = "No Facility" + ISIS = "ISIS" -# -------------------------- -# ReductionMode -# -------------------------- -@string_convertible -@serializable_enum("Merged", "All") -class ReductionMode(object): + +class SANSDataType(Enum): """ - Defines the reduction modes which should be common to all implementations, namely All and Merged. + Defines the different data types which are required for the reduction. Besides the fundamental data of the + sample and the can, we can also specify a calibration. """ - pass + CAN_DIRECT = "Can Direct" + CAN_TRANSMISSION = "Can Transmission" + CAN_SCATTER = "Can Scatter" + CALIBRATION = "Calibration" + SAMPLE_DIRECT = "Sample Direct" + SAMPLE_SCATTER = "Sample Scatter" + SAMPLE_TRANSMISSION = "Sample Transmission" -@string_convertible -@serializable_enum("HAB", "LAB") -class ISISReductionMode(ReductionMode): - """ - Defines the different reduction modes. This can be the high-angle bank, the low-angle bank - """ - pass +class CanonicalCoordinates(Enum): + X = "X" + Y = "Y" + Z = "Z" -# -------------------------- -# Reduction dimensionality -# -------------------------- -@serializable_enum("OneDim", "TwoDim") -class ReductionDimensionality(object): - """ - Defines the dimensionality for reduction. 
This can be either 1D or 2D - """ - pass +class ReductionMode(Enum): + NOT_SET = "Not Set" + ALL = "All" + MERGED = "Merged" + HAB = "HAB" + LAB = "LAB" + +class ReductionDimensionality(Enum): + ONE_DIM = "OneDim" + TWO_DIM = "TwoDim" -# -------------------------- -# Reduction data -# -------------------------- -@serializable_enum("Scatter", "Transmission", "Direct") -class ReductionData(object): + +class ReductionData(Enum): """ Defines the workspace type of the reduction data. For all known instances this can be scatter, transmission or direct """ - pass + DIRECT = "Direct" + SCATTER = "Scatter" + TRANSMISSION = "Transmission" -# -------------------------- -# Type of data -# -------------------------- -@string_convertible -@serializable_enum("Sample", "Can") -class DataType(object): +class DataType(Enum): """ Defines the type of reduction data. This can either the sample or only the can. """ - pass + CAN = "Can" + SAMPLE = "Sample" -# --------------------------------- -# Partial reduction output setting -# --------------------------------- -@serializable_enum("Count", "Norm") -class OutputParts(object): +class OutputParts(Enum): """ Defines the partial outputs of a reduction. They are the numerator (Count) and denominator (Norm) of a division. """ - pass + COUNT = "Count" + NORM = "Norm" -# ----------------------------------------------------- -# The fit type during merge of HAB and LAB reductions -# ----------------------------------------------------- -@string_convertible -@serializable_enum("Both", "NoFit", "ShiftOnly", "ScaleOnly") -class FitModeForMerge(object): +class FitModeForMerge(Enum): """ Defines which fit operation to use during the merge of two reductions. """ - pass - - -# -------------------------- -# Detectors -# -------------------------- -@serializable_enum("Horizontal", "Vertical", "Rotated") -class DetectorOrientation(object): - """ - Defines the detector orientation. 
- """ - pass + BOTH = "Both" + NO_FIT = "NoFit" + SCALE_ONLY = "ScaleOnly" + SHIFT_ONLY = "ShiftOnly" -# -------------------------- -# Detector Type -# -------------------------- -@string_convertible -@serializable_enum("HAB", "LAB") -class DetectorType(object): +class DetectorType(Enum): """ Defines the detector type """ - pass + HAB = "HAB" + LAB = "LAB" -# -------------------------- -# Transmission Type -# -------------------------- -@string_convertible -@serializable_enum("Calculated", "Unfitted") -class TransmissionType(object): +class TransmissionType(Enum): """ Defines the detector type """ - pass + CALCULATED = "Calculated" + UNFITTED = "Unfitted" -# -------------------------- -# Ranges -# -------------------------- -@string_convertible -@serializable_enum("Lin", "Log", "RangeLin", "RangeLog") -class RangeStepType(object): +class RangeStepType(Enum): """ Defines the step type of a range """ - pass + LIN = "Lin" + LOG = "Log" + NOT_SET = "NotSet" + RANGE_LIN = "RangeLin" + RANGE_LOG = "RangeLog" -# -------------------------- -# Rebin -# -------------------------- -@string_convertible -@serializable_enum("Rebin", "InterpolatingRebin") -class RebinType(object): - """ - Defines the rebin types available - """ - pass +class RebinType(Enum): + INTERPOLATING_REBIN = "InterpolatingRebin" + REBIN = "Rebin" -# -------------------------- -# SaveType -# -------------------------- -@string_convertible -@serializable_enum("Nexus", "NistQxy", "CanSAS", "RKH", "CSV", "NXcanSAS", "Nexus", "NoType") -class SaveType(object): - """ - Defines the save types available - """ - pass +class SaveType(Enum): + CAN_SAS = "CanSAS" + CSV = "CSV" + NEXUS = "Nexus" + NIST_QXY = "NistQxy" + NO_TYPE = "NoType" + NX_CAN_SAS = "NXcanSAS" + RKH = "RKH" -# ------------------------------------------ -# Fit type for the transmission calculation -# ------------------------------------------ -@string_convertible -@serializable_enum("Linear", "Logarithmic", "Polynomial", "NoFit") -class 
FitType(object): +class FitType(Enum): """ - Defines possible fit types + Defines possible fit types for the transmission calculation """ - pass + LINEAR = "Linear" + LOGARITHMIC = "Logarithmic" + POLYNOMIAL = "Polynomial" + NO_FIT = "NotFit" -# -------------------------- -# SampleShape -# -------------------------- -@string_convertible -@serializable_enum("Cylinder", "FlatPlate", "Disc") -class SampleShape(object): +class SampleShape(Enum): """ Defines the sample shape types """ - pass - - -def convert_int_to_shape(shape_int): - """ - Note that we convert the sample shape to an integer here. This is required for the workspace, hence we don't - use the string_convertible decorator. - """ - if shape_int == 1: - as_type = SampleShape.CylinderAxisUp - elif shape_int == 2: - as_type = SampleShape.Cuboid - elif shape_int == 3: - as_type = SampleShape.CylinderAxisAlong - else: - raise ValueError("SampleShape: Cannot convert unknown sample shape integer: {0}".format(shape_int)) - return as_type + CYLINDER = "Cylinder" + DISC = "Disc" + FLAT_PLATE = "FlatPlate" + NOT_SET = "NotSet" -# --------------------------- -# FileTypes -# --------------------------- -@serializable_enum("ISISNexus", "ISISNexusAdded", "ISISRaw", "NoFileType") -class FileType(object): - pass +class FileType(Enum): + ISIS_NEXUS = "ISISNexus" + ISIS_NEXUS_ADDED = "ISISNexusAdded" + ISIS_RAW = "ISISRaw" + NO_FILE_TYPE = "NoFileType" -# --------------------------- -# OutputMode -# --------------------------- -@string_convertible -@serializable_enum("PublishToADS", "SaveToFile", "Both") -class OutputMode(object): +class OutputMode(Enum): """ Defines the output modes of a batch reduction. 
""" - pass + BOTH = "Both" + PUBLISH_TO_ADS = "PublishToADS" + SAVE_TO_FILE = "SaveToFile" -# ------------------------------ -# Entries of batch reduction file -# ------------------------------- -@string_convertible -@serializable_enum("SampleScatter", "SampleTransmission", "SampleDirect", "CanScatter", "CanTransmission", "CanDirect", - "Output", "UserFile", "SampleScatterPeriod", "SampleTransmissionPeriod", "SampleDirectPeriod", - "CanScatterPeriod", "CanTransmissionPeriod", "CanDirectPeriod", "SampleThickness", "SampleHeight", - "SampleWidth") -class BatchReductionEntry(object): +class BatchReductionEntry(Enum): """ Defines the entries of a batch reduction file. """ - pass + CAN_DIRECT = "CanDirect" + CAN_DIRECT_PERIOD = "CanDirectPeriod" + + CAN_SCATTER = "CanScatter" + CAN_SCATTER_PERIOD = "CanScatterPeriod" + + CAN_TRANSMISSION = "CanTransmission" + CAN_TRANSMISSION_PERIOD = "CanTransmissionPeriod" + + OUTPUT = "Output" + + SAMPLE_DIRECT = "SampleDirect" + SAMPLE_DIRECT_PERIOD = "SampleDirectPeriod" + + SAMPLE_SCATTER = "SampleScatter" + SAMPLE_SCATTER_PERIOD = "SampleScatterPeriod" + + SAMPLE_TRANSMISSION = "SampleTransmission" + SAMPLE_TRANSMISSION_PERIOD = "SampleTransmissionPeriod" + + SAMPLE_HEIGHT = "SampleHeight" + SAMPLE_THICKNESS = "SampleThickness" + SAMPLE_WIDTH = "SampleWidth" + + USER_FILE = "UserFile" -# ------------------------------ -# Quadrants for beam centre finder -# ------------------------------- -@string_convertible -@serializable_enum("Left", "Right", "Top", "Bottom") -class MaskingQuadrant(object): +class MaskingQuadrant(Enum): """ Defines the entries of a batch reduction file. 
""" - pass + BOTTOM = "Bottom" + LEFT = "Left" + RIGHT = "Right" + TOP = "Top" -# ------------------------------ -# Directions for Beam centre finder -# ------------------------------- -@string_convertible -@serializable_enum("All", "Up_Down", "Left_Right") -class FindDirectionEnum(object): +class FindDirectionEnum(Enum): """ Defines the entries of a batch reduction file. """ - pass + ALL = "All" + UP_DOWN = "Up_Down" + LEFT_RIGHT = "Left_Right" -# ------------------------------ -# Integrals for diagnostic tab -# ------------------------------- -@string_convertible -@serializable_enum("Horizontal", "Vertical", "Time") -class IntegralEnum(object): +class IntegralEnum(Enum): """ Defines the entries of a batch reduction file. """ - pass + Horizontal = "Horizontal" + Time = "Time" + Vertical = "Vertical" -@string_convertible -@serializable_enum("Unprocessed", "Processed", "Error") -class RowState(object): +class RowState(Enum): """ Defines the entries of a batch reduction file. """ - pass + ERROR = "Error" + PROCESSED = "Processed" + UNPROCESSED = "Unprocessed" -# ------------------------------ -# Binning Types for AddRuns -# ------------------------------- -@string_convertible -@serializable_enum("SaveAsEventData", "Custom", "FromMonitors") -class BinningType(object): +class BinningType(Enum): """ Defines the types of binning when adding runs together """ + CUSTOM = "Custom" + FROM_MONITORS = "FromMonitors" + SAVE_AS_EVENT_DATA = "SaveAsEventData" diff --git a/scripts/SANS/sans/common/file_information.py b/scripts/SANS/sans/common/file_information.py index 885bceca537dce0c8d8f27e80a201eb0c885bd5f..3a2bf92c1b89b564cda35387112acde509083bc7 100644 --- a/scripts/SANS/sans/common/file_information.py +++ b/scripts/SANS/sans/common/file_information.py @@ -128,9 +128,9 @@ def get_extension_for_file_type(file_info): :param file_info: a SANSFileInformation object. :return: the extension a string. This can be either nxs or raw. 
""" - if file_info.get_type() is FileType.ISISNexus or file_info.get_type() is FileType.ISISNexusAdded: + if file_info.get_type() is FileType.ISIS_NEXUS or file_info.get_type() is FileType.ISIS_NEXUS_ADDED: extension = NXS_EXTENSION - elif file_info.get_type() is FileType.ISISRaw: + elif file_info.get_type() is FileType.ISIS_RAW: extension = RAW_EXTENSION else: raise RuntimeError("The file extension type for a file of type {0} is unknown" @@ -236,7 +236,7 @@ def get_instrument_paths_for_sans_file(file_name=None, file_information=None): # Get the instrument instrument = file_information.get_instrument() - instrument_as_string = SANSInstrument.to_string(instrument) + instrument_as_string = instrument.value # Get the idf file path # IMPORTANT NOTE: I profiled the call to ExperimentInfo.getInstrumentFilename and it dominates @@ -282,11 +282,11 @@ def convert_to_shape(shape_flag): :return: a shape object """ if shape_flag == 1: - shape = SampleShape.Cylinder + shape = SampleShape.CYLINDER elif shape_flag == 2: - shape = SampleShape.FlatPlate + shape = SampleShape.FLAT_PLATE elif shape_flag == 3: - shape = SampleShape.Disc + shape = SampleShape.DISC else: shape = None return shape @@ -412,11 +412,11 @@ def get_geometry_information_isis_nexus(file_name): thickness = float(sample[THICKNESS][0]) shape_as_string = sample[SHAPE][0].upper().decode("utf-8") if shape_as_string == CYLINDER: - shape = SampleShape.Cylinder + shape = SampleShape.CYLINDER elif shape_as_string == FLAT_PLATE: - shape = SampleShape.FlatPlate + shape = SampleShape.FLAT_PLATE elif shape_as_string == DISC: - shape = SampleShape.Disc + shape = SampleShape.DISC else: shape = None return height, width, thickness, shape @@ -885,7 +885,7 @@ class SANSFileInformationISISNexus(SANSFileInformation): super(SANSFileInformationISISNexus, self).__init__(file_name) # Setup instrument name instrument_name = get_instrument_name_for_isis_nexus(self._full_file_name) - self._instrument = 
SANSInstrument.from_string(instrument_name) + self._instrument = SANSInstrument[instrument_name] # Setup the facility self._facility = get_facility(self._instrument) @@ -904,7 +904,7 @@ class SANSFileInformationISISNexus(SANSFileInformation): self._height = height if height is not None else 1. self._width = width if width is not None else 1. self._thickness = thickness if thickness is not None else 1. - self._shape = shape if shape is not None else SampleShape.Disc + self._shape = shape if shape is not None else SampleShape.DISC def get_file_name(self): return self._full_file_name @@ -922,7 +922,7 @@ class SANSFileInformationISISNexus(SANSFileInformation): return self._number_of_periods def get_type(self): - return FileType.ISISNexus + return FileType.ISIS_NEXUS def is_event_mode(self): return self._is_event_mode @@ -967,7 +967,7 @@ class SANSFileInformationISISAdded(SANSFileInformation): self._height = height if height is not None else 1. self._width = width if width is not None else 1. self._thickness = thickness if thickness is not None else 1. 
- self._shape = shape if shape is not None else SampleShape.Disc + self._shape = shape if shape is not None else SampleShape.DISC def get_file_name(self): return self._full_file_name @@ -985,7 +985,7 @@ class SANSFileInformationISISAdded(SANSFileInformation): return self._number_of_periods def get_type(self): - return FileType.ISISNexusAdded + return FileType.ISIS_NEXUS_ADDED def is_event_mode(self): return self._is_event_mode @@ -1029,7 +1029,7 @@ class SANSFileInformationRaw(SANSFileInformation): super(SANSFileInformationRaw, self).__init__(file_name) # Setup instrument name instrument_name = get_instrument_name_for_raw(self._full_file_name) - self._instrument = SANSInstrument.from_string(instrument_name) + self._instrument = SANSInstrument[instrument_name] # Setup the facility self._facility = get_facility(self._instrument) @@ -1046,7 +1046,7 @@ class SANSFileInformationRaw(SANSFileInformation): self._height = height if height is not None else 1. self._width = width if width is not None else 1. self._thickness = thickness if thickness is not None else 1. 
- self._shape = shape if shape is not None else SampleShape.Disc + self._shape = shape if shape is not None else SampleShape.DISC def get_file_name(self): return self._full_file_name @@ -1064,7 +1064,7 @@ class SANSFileInformationRaw(SANSFileInformation): return self._number_of_periods def get_type(self): - return FileType.ISISRaw + return FileType.ISIS_RAW def is_event_mode(self): return False diff --git a/scripts/SANS/sans/common/general_functions.py b/scripts/SANS/sans/common/general_functions.py index f404821335573977cba8e2a6eff95d672c7c2912..d79cf3335aadc083abb21b0300851a9fd36fe26f 100644 --- a/scripts/SANS/sans/common/general_functions.py +++ b/scripts/SANS/sans/common/general_functions.py @@ -20,8 +20,8 @@ from sans.common.constant_containers import (SANSInstrument_enum_list, SANSInstr from sans.common.constants import (SANS_FILE_TAG, ALL_PERIODS, SANS2D, EMPTY_NAME, REDUCED_CAN_TAG) from sans.common.log_tagger import (get_tag, has_tag, set_tag, has_hash, get_hash_value, set_hash) -from sans.common.enums import (DetectorType, RangeStepType, ReductionDimensionality, OutputParts, ISISReductionMode, - SANSInstrument, SANSFacility, DataType, TransmissionType) +from sans.common.enums import (DetectorType, RangeStepType, ReductionDimensionality, OutputParts, ReductionMode, + SANSFacility, DataType, TransmissionType, SANSInstrument) # ------------------------------------------- # Constants @@ -614,7 +614,7 @@ def get_bins_for_rebin_setting(min_value, max_value, step_value, step_type): bins.append(lower_bound) # We can either have linear or logarithmic steps. The logarithmic step depends on the lower bound. - if step_type is RangeStepType.Lin: + if step_type is RangeStepType.LIN: step = step_value else: step = lower_bound * step_value @@ -658,7 +658,7 @@ def get_ranges_for_rebin_array(rebin_array): min_value = rebin_array[0] step_value = rebin_array[1] max_value = rebin_array[2] - step_type = RangeStepType.Lin if step_value >= 0. 
else RangeStepType.Log + step_type = RangeStepType.LIN if step_value >= 0. else RangeStepType.LOG step_value = abs(step_value) return get_ranges_for_rebin_setting(min_value, max_value, step_value, step_type) @@ -705,13 +705,13 @@ def get_standard_output_workspace_name(state, reduction_data_type, # 3. Detector name move = state.move detectors = move.detectors - if reduction_data_type is ISISReductionMode.Merged: + if reduction_data_type is ReductionMode.MERGED: detector_name_short = "merged" - elif reduction_data_type is ISISReductionMode.HAB: - det_name = detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short + elif reduction_data_type is ReductionMode.HAB: + det_name = detectors[DetectorType.HAB.value].detector_name_short detector_name_short = det_name if det_name is not None else "hab" - elif reduction_data_type is ISISReductionMode.LAB: - det_name = detectors[DetectorType.to_string(DetectorType.LAB)].detector_name_short + elif reduction_data_type is ReductionMode.LAB: + det_name = detectors[DetectorType.LAB.value].detector_name_short detector_name_short = det_name if det_name is not None else "lab" else: raise RuntimeError("SANSStateFunctions: Unknown reduction data type {0} cannot be used to " @@ -719,7 +719,7 @@ def get_standard_output_workspace_name(state, reduction_data_type, # 4. Dimensionality reduction = state.reduction - if reduction.reduction_dimensionality is ReductionDimensionality.OneDim: + if reduction.reduction_dimensionality is ReductionDimensionality.ONE_DIM: dimensionality_as_string = "_1D" else: dimensionality_as_string = "_2D" @@ -730,7 +730,7 @@ def get_standard_output_workspace_name(state, reduction_data_type, # 6. 
Phi Limits mask = state.mask - if reduction.reduction_dimensionality is ReductionDimensionality.OneDim: + if reduction.reduction_dimensionality is ReductionDimensionality.ONE_DIM: if mask.phi_min and mask.phi_max and (abs(mask.phi_max - mask.phi_min) != 180.0): phi_limits_as_string = 'Phi' + str(mask.phi_min) + '_' + str(mask.phi_max) else: @@ -760,7 +760,7 @@ def get_standard_output_workspace_name(state, reduction_data_type, return output_workspace_name, output_workspace_base_name -def get_transmission_output_name(state, data_type=DataType.Sample, multi_reduction_type=None, fitted=True): +def get_transmission_output_name(state, data_type=DataType.SAMPLE, multi_reduction_type=None, fitted=True): user_specified_output_name = state.save.user_specified_output_name data = state.data @@ -768,10 +768,10 @@ def get_transmission_output_name(state, data_type=DataType.Sample, multi_reducti short_run_number_as_string = str(short_run_number) calculated_transmission_state = state.adjustment.calculate_transmission - fit = calculated_transmission_state.fit[DataType.to_string(DataType.Sample)] + fit = calculated_transmission_state.fit[DataType.SAMPLE] wavelength_range_string = "_" + str(fit.wavelength_low) + "_" + str(fit.wavelength_high) - trans_suffix = "_trans_Sample" if data_type == DataType.Sample else "_trans_Can" + trans_suffix = "_trans_Sample" if data_type == DataType.SAMPLE else "_trans_Can" trans_suffix = trans_suffix + '_unfitted' if not fitted else trans_suffix if user_specified_output_name: @@ -880,7 +880,7 @@ def get_facility(instrument): if instrument in SANSInstrument_enum_list: return SANSFacility.ISIS else: - return SANSFacility.NoFacility + return SANSFacility.NO_FACILITY # ---------------------------------------------------------------------------------------------------------------------- @@ -937,22 +937,22 @@ def get_state_hash_for_can_reduction(state, reduction_mode, partial_type=None): # Add a tag for the reduction mode state_string = 
str(new_state_serialized) - if reduction_mode is ISISReductionMode.LAB: + if reduction_mode is ReductionMode.LAB: state_string += "LAB" - elif reduction_mode is ISISReductionMode.HAB: + elif reduction_mode is ReductionMode.HAB: state_string += "HAB" else: raise RuntimeError("Only LAB and HAB reduction modes are allowed at this point." " {} was provided".format(reduction_mode)) # If we are dealing with a partial output workspace, then mark it as such - if partial_type is OutputParts.Count: + if partial_type is OutputParts.COUNT: state_string += "counts" - elif partial_type is OutputParts.Norm: + elif partial_type is OutputParts.NORM: state_string += "norm" - elif partial_type is TransmissionType.Calculated: + elif partial_type is TransmissionType.CALCULATED: state_string += "calculated_transmission" - elif partial_type is TransmissionType.Unfitted: + elif partial_type is TransmissionType.UNFITTED: state_string += "unfitted_transmission" return str(get_hash_value(state_string)) @@ -978,9 +978,9 @@ def get_reduced_can_workspace_from_ads(state, output_parts, reduction_mode): reduced_can_count = None reduced_can_norm = None if output_parts: - hashed_state_count = get_state_hash_for_can_reduction(state, reduction_mode, OutputParts.Count) + hashed_state_count = get_state_hash_for_can_reduction(state, reduction_mode, OutputParts.COUNT) reduced_can_count = get_workspace_from_ads_based_on_hash(hashed_state_count) - hashed_state_norm = get_state_hash_for_can_reduction(state, reduction_mode, OutputParts.Norm) + hashed_state_norm = get_state_hash_for_can_reduction(state, reduction_mode, OutputParts.NORM) reduced_can_norm = get_workspace_from_ads_based_on_hash(hashed_state_norm) return reduced_can, reduced_can_count, reduced_can_norm @@ -993,9 +993,9 @@ def get_transmission_workspaces_from_ads(state, reduction_mode): :param reduction_mode: the reduction mode which at this point is either HAB or LAB :return: a reduced transmission can object or None. 
""" - hashed_state = get_state_hash_for_can_reduction(state, reduction_mode, TransmissionType.Calculated) + hashed_state = get_state_hash_for_can_reduction(state, reduction_mode, TransmissionType.CALCULATED) calculated_transmission = get_workspace_from_ads_based_on_hash(hashed_state) - hashed_state = get_state_hash_for_can_reduction(state, reduction_mode, TransmissionType.Unfitted) + hashed_state = get_state_hash_for_can_reduction(state, reduction_mode, TransmissionType.UNFITTED) unfitted_transmission = get_workspace_from_ads_based_on_hash(hashed_state) return calculated_transmission, unfitted_transmission @@ -1043,6 +1043,7 @@ def get_bank_for_spectrum_number(spectrum_number, instrument): :returns: either LAB or HAB """ detector = DetectorType.LAB + if instrument is SANSInstrument.LOQ: if 16387 <= spectrum_number <= 17784: detector = DetectorType.HAB diff --git a/scripts/SANS/sans/gui_logic/gui_common.py b/scripts/SANS/sans/gui_logic/gui_common.py index c8f67c27e137f014f0bf26259434efbf7572137d..5524bc3ae2b1e0f51a12db82a45b735b8453b34a 100644 --- a/scripts/SANS/sans/gui_logic/gui_common.py +++ b/scripts/SANS/sans/gui_logic/gui_common.py @@ -10,7 +10,7 @@ from qtpy.QtWidgets import QFileDialog from sans.common.constant_containers import (SANSInstrument_enum_as_key, SANSInstrument_string_as_key_NoInstrument, SANSInstrument_string_list) -from sans.common.enums import SANSInstrument, ISISReductionMode, DetectorType +from sans.common.enums import ReductionMode, DetectorType, SANSInstrument # ---------------------------------------------------------------------------------------------------------------------- @@ -48,15 +48,16 @@ LAB_STRINGS = {SANSInstrument.SANS2D: "rear", SANSInstrument.LOQ: "main-detector", SANSInstrument.LARMOR: "DetectorBench", SANSInstrument.ZOOM: "rear-detector", - SANSInstrument.NoInstrument: ISISReductionMode.to_string(ISISReductionMode.LAB) + SANSInstrument.NO_INSTRUMENT: ReductionMode.LAB.value } HAB_STRINGS = {SANSInstrument.SANS2D: "front", 
SANSInstrument.LOQ: "Hab", - SANSInstrument.NoInstrument: ISISReductionMode.to_string(ISISReductionMode.HAB)} + SANSInstrument.NO_INSTRUMENT: ReductionMode.HAB.value} -MERGED = "Merged" -ALL = "All" +ALL = ReductionMode.ALL.value +DEFAULT_HAB = ReductionMode.HAB.value +MERGED = ReductionMode.MERGED.value GENERIC_SETTINGS = "Mantid/ISISSANS" @@ -71,8 +72,8 @@ def get_detector_strings_for_gui(instrument=None): return [LAB_STRINGS[instrument]] else: - return [LAB_STRINGS[SANSInstrument.NoInstrument], - HAB_STRINGS[SANSInstrument.NoInstrument]] + return [LAB_STRINGS[SANSInstrument.NO_INSTRUMENT], + HAB_STRINGS[SANSInstrument.NO_INSTRUMENT]] def get_detector_strings_for_diagnostic_page(instrument=None): @@ -83,8 +84,8 @@ def get_detector_strings_for_diagnostic_page(instrument=None): return [LAB_STRINGS[instrument]] else: - return [LAB_STRINGS[SANSInstrument.NoInstrument], - HAB_STRINGS[SANSInstrument.NoInstrument]] + return [LAB_STRINGS[SANSInstrument.NO_INSTRUMENT], + HAB_STRINGS[SANSInstrument.NO_INSTRUMENT]] def get_reduction_mode_strings_for_gui(instrument=None): @@ -98,8 +99,8 @@ def get_reduction_mode_strings_for_gui(instrument=None): return [LAB_STRINGS[instrument]] else: - return [LAB_STRINGS[SANSInstrument.NoInstrument], - HAB_STRINGS[SANSInstrument.NoInstrument], + return [LAB_STRINGS[SANSInstrument.NO_INSTRUMENT], + HAB_STRINGS[SANSInstrument.NO_INSTRUMENT], MERGED, ALL] @@ -108,19 +109,19 @@ def get_instrument_strings_for_gui(): def get_reduction_selection(instrument): - selection = {ISISReductionMode.Merged: MERGED, - ISISReductionMode.All: ALL} + selection = {ReductionMode.MERGED: MERGED, + ReductionMode.ALL: ALL} if any (instrument is x for x in [SANSInstrument.SANS2D, SANSInstrument.LOQ]): - selection.update({ISISReductionMode.LAB: LAB_STRINGS[instrument], - ISISReductionMode.HAB: HAB_STRINGS[instrument]}) + selection.update({ReductionMode.LAB: LAB_STRINGS[instrument], + ReductionMode.HAB: HAB_STRINGS[instrument]}) elif any(instrument is x for x in 
[SANSInstrument.LARMOR, SANSInstrument.ZOOM]): - selection = {ISISReductionMode.LAB: LAB_STRINGS[instrument]} + selection = {ReductionMode.LAB: LAB_STRINGS[instrument]} else: - selection.update({ISISReductionMode.LAB: LAB_STRINGS[SANSInstrument.NoInstrument], - ISISReductionMode.HAB: HAB_STRINGS[SANSInstrument.NoInstrument]}) + selection.update({ReductionMode.LAB: LAB_STRINGS[SANSInstrument.NO_INSTRUMENT], + ReductionMode.HAB: HAB_STRINGS[SANSInstrument.NO_INSTRUMENT]}) return selection @@ -143,16 +144,16 @@ def get_reduction_mode_from_gui_selection(gui_selection): # TODO when we hit only Python 3 this should use casefold rather than lower case_folded_selection = gui_selection.lower() if case_folded_selection == MERGED.lower(): - return ISISReductionMode.Merged + return ReductionMode.MERGED elif case_folded_selection == ALL.lower(): - return ISISReductionMode.All + return ReductionMode.ALL elif any(case_folded_selection == lab.lower() for lab in LAB_STRINGS.values()): - return ISISReductionMode.LAB + return ReductionMode.LAB elif any(case_folded_selection == hab.lower() for hab in HAB_STRINGS.values()): - return ISISReductionMode.HAB + return ReductionMode.HAB else: raise RuntimeError("Reduction mode selection {0} is not valid.".format(gui_selection)) diff --git a/scripts/SANS/sans/gui_logic/models/batch_process_runner.py b/scripts/SANS/sans/gui_logic/models/batch_process_runner.py index 574d360b88ce90573913e2078dd9d64c62d13e10..095e8519785bc745bbeaa9eac225ce1232976606 100644 --- a/scripts/SANS/sans/gui_logic/models/batch_process_runner.py +++ b/scripts/SANS/sans/gui_logic/models/batch_process_runner.py @@ -8,7 +8,7 @@ from qtpy.QtCore import Slot, QThreadPool, Signal, QObject from sans.sans_batch import SANSBatchReduction from sans.algorithm_detail.batch_execution import load_workspaces_from_states from ui.sans_isis.worker import Worker -from sans.common.enums import ISISReductionMode +from sans.common.enums import ReductionMode class BatchProcessRunner(QObject): 
@@ -67,7 +67,7 @@ class BatchProcessRunner(QObject): try: out_scale_factors, out_shift_factors = \ self.batch_processor([state], use_optimizations, output_mode, plot_results, output_graph, save_can) - if state.reduction.reduction_mode == ISISReductionMode.Merged: + if state.reduction.reduction_mode == ReductionMode.MERGED: out_shift_factors = out_shift_factors[0] out_scale_factors = out_scale_factors[0] else: diff --git a/scripts/SANS/sans/gui_logic/models/beam_centre_model.py b/scripts/SANS/sans/gui_logic/models/beam_centre_model.py index 1cc6707eada529473f43add71846008c2844eb74..8916014f168f9204f1a14d4dd3a730de40344fd1 100644 --- a/scripts/SANS/sans/gui_logic/models/beam_centre_model.py +++ b/scripts/SANS/sans/gui_logic/models/beam_centre_model.py @@ -4,7 +4,7 @@ # NScD Oak Ridge National Laboratory, European Spallation Source # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + -from sans.common.enums import (SANSInstrument, FindDirectionEnum, DetectorType) +from sans.common.enums import (FindDirectionEnum, DetectorType, SANSInstrument) from mantid.kernel import (Logger) from sans.common.file_information import get_instrument_paths_for_sans_file from sans.common.xml_parsing import get_named_elements_from_ipf_file @@ -55,6 +55,7 @@ class BeamCentreModel(object): def set_scaling(self, instrument): self.scale_1 = 1000 self.scale_2 = 1000 + if instrument == SANSInstrument.LARMOR: self.scale_1 = 1.0 @@ -69,11 +70,11 @@ class BeamCentreModel(object): centre_finder = self.SANSCentreFinder() find_direction = None if self.up_down and self.left_right: - find_direction = FindDirectionEnum.All + find_direction = FindDirectionEnum.ALL elif self.up_down: - find_direction = FindDirectionEnum.Up_Down + find_direction = FindDirectionEnum.UP_DOWN elif self.left_right: - find_direction = FindDirectionEnum.Left_Right + find_direction = FindDirectionEnum.LEFT_RIGHT else: logger = Logger("CentreFinder") logger.notice("Have chosen no find direction exiting early") 
diff --git a/scripts/SANS/sans/gui_logic/models/diagnostics_page_model.py b/scripts/SANS/sans/gui_logic/models/diagnostics_page_model.py index e6d7c73d81d8aaec06539037a6eab1a87161ddc6..5971b0704ab885426f206f9fd97a74bf09e3faa9 100644 --- a/scripts/SANS/sans/gui_logic/models/diagnostics_page_model.py +++ b/scripts/SANS/sans/gui_logic/models/diagnostics_page_model.py @@ -31,10 +31,10 @@ def run_integral(integral_ranges, mask, integral, detector, state): input_workspace_name = input_workspace.name() if is_multi_range: AnalysisDataService.remove(input_workspace_name + '_ranges') - input_workspace = crop_workspace(DetectorType.to_string(detector), input_workspace) + input_workspace = crop_workspace(detector.value, input_workspace) if mask: - input_workspace = apply_mask(state, input_workspace, DetectorType.to_string(detector)) + input_workspace = apply_mask(state, input_workspace, detector.value) x_dim, y_dim = get_detector_size_from_sans_file(state, detector) @@ -69,24 +69,24 @@ def parse_range(range): def load_workspace(state): - workspace_to_name = {SANSDataType.SampleScatter: "SampleScatterWorkspace", - SANSDataType.SampleTransmission: "SampleTransmissionWorkspace", - SANSDataType.SampleDirect: "SampleDirectWorkspace", - SANSDataType.CanScatter: "CanScatterWorkspace", - SANSDataType.CanTransmission: "CanTransmissionWorkspace", - SANSDataType.CanDirect: "CanDirectWorkspace"} + workspace_to_name = {SANSDataType.SAMPLE_SCATTER: "SampleScatterWorkspace", + SANSDataType.SAMPLE_TRANSMISSION: "SampleTransmissionWorkspace", + SANSDataType.SAMPLE_DIRECT: "SampleDirectWorkspace", + SANSDataType.CAN_SCATTER: "CanScatterWorkspace", + SANSDataType.CAN_TRANSMISSION: "CanTransmissionWorkspace", + SANSDataType.CAN_DIRECT: "CanDirectWorkspace"} - workspace_to_monitor = {SANSDataType.SampleScatter: "SampleScatterMonitorWorkspace", - SANSDataType.CanScatter: "CanScatterMonitorWorkspace"} + workspace_to_monitor = {SANSDataType.SAMPLE_SCATTER: "SampleScatterMonitorWorkspace", + 
SANSDataType.CAN_SCATTER: "CanScatterMonitorWorkspace"} workspaces, monitors = provide_loaded_data(state, False, workspace_to_name, workspace_to_monitor) - return workspaces[SANSDataType.SampleScatter] + return workspaces[SANSDataType.SAMPLE_SCATTER] def crop_workspace(component, workspace): crop_name = "CropToComponent" - component_to_crop = DetectorType.from_string(component) + component_to_crop = DetectorType(component) component_to_crop = get_component_name(workspace, component_to_crop) crop_options = {"InputWorkspace": workspace, "OutputWorkspace": EMPTY_NAME, @@ -121,8 +121,8 @@ def run_algorithm(input_workspace, range, integral, output_workspace, x_dim, y_d def generate_output_workspace_name(range, integral, mask, detector, input_workspace_name): - integral_string = IntegralEnum.to_string(integral) - detector_string = DetectorType.to_string(detector) + integral_string = integral.value + detector_string = detector.value return 'Run:{}, Range:{}, Direction:{}, Detector:{}, Mask:{}'.format(input_workspace_name, range, integral_string, diff --git a/scripts/SANS/sans/gui_logic/models/model_common.py b/scripts/SANS/sans/gui_logic/models/model_common.py index 1dd9a07b4636e6b9fdc3e5c35bda58e49c34784f..c800bfeeb3968945e8e76592fa1e683baf8dde0f 100644 --- a/scripts/SANS/sans/gui_logic/models/model_common.py +++ b/scripts/SANS/sans/gui_logic/models/model_common.py @@ -26,15 +26,15 @@ class ModelCommon(with_metaclass(ABCMeta)): @property def instrument(self): - return self.get_simple_element(element_id=DetectorId.instrument, default_value=SANSInstrument.NoInstrument) + return self.get_simple_element(element_id=DetectorId.INSTRUMENT, default_value=SANSInstrument.NO_INSTRUMENT) @instrument.setter def instrument(self, value): - self.set_simple_element(element_id=DetectorId.instrument, value=value) + self.set_simple_element(element_id=DetectorId.INSTRUMENT, value=value) def _get_incident_spectrum_info(self, default_value, attribute, is_trans): - if MonId.spectrum in 
self._user_file_items: - settings = self._user_file_items[MonId.spectrum] + if MonId.SPECTRUM in self._user_file_items: + settings = self._user_file_items[MonId.SPECTRUM] if is_trans: settings = [setting for setting in settings if setting.is_trans] else: @@ -45,8 +45,8 @@ class ModelCommon(with_metaclass(ABCMeta)): return default_value def _update_incident_spectrum_info(self, spectrum=None, interpolate=False, is_trans=False): - if MonId.spectrum in self._user_file_items: - settings = self._user_file_items[MonId.spectrum] + if MonId.SPECTRUM in self._user_file_items: + settings = self._user_file_items[MonId.SPECTRUM] else: # If the entry does not already exist, then add it. settings = [monitor_spectrum(spectrum=spectrum, is_trans=is_trans, interpolate=interpolate)] @@ -63,7 +63,7 @@ class ModelCommon(with_metaclass(ABCMeta)): new_settings.append(new_setting) else: new_settings.append(setting) - self._user_file_items.update({MonId.spectrum: new_settings}) + self._user_file_items.update({MonId.SPECTRUM: new_settings}) def get_simple_element(self, element_id, default_value): return self.get_simple_element_with_attribute(element_id, default_value) diff --git a/scripts/SANS/sans/gui_logic/models/settings_adjustment_model.py b/scripts/SANS/sans/gui_logic/models/settings_adjustment_model.py index baa289c76a9654bb6eb6256e5e04003ec156922c..ad5789aa973be63e1b936ff5ae312128f80eda98 100644 --- a/scripts/SANS/sans/gui_logic/models/settings_adjustment_model.py +++ b/scripts/SANS/sans/gui_logic/models/settings_adjustment_model.py @@ -37,11 +37,11 @@ class SettingsAdjustmentModel(ModelCommon): return True if self.instrument is SANSInstrument.ZOOM else False def has_transmission_fit_got_separate_settings_for_sample_and_can(self): - if FitId.general in self._user_file_items: - settings = self._user_file_items[FitId.general] + if FitId.GENERAL in self._user_file_items: + settings = self._user_file_items[FitId.GENERAL] if settings: - settings_sample = [setting for setting in settings 
if setting.data_type is DataType.Sample] - settings_can = [setting for setting in settings if setting.data_type is DataType.Can] + settings_sample = [setting for setting in settings if setting.data_type is DataType.SAMPLE] + settings_can = [setting for setting in settings if setting.data_type is DataType.CAN] # If we have either one or the other if settings_sample or settings_can: return True @@ -50,8 +50,8 @@ class SettingsAdjustmentModel(ModelCommon): # =================== Property helper methods ================ def _get_transmission_fit(self, data_type, attribute, default_value): - if FitId.general in self._user_file_items: - settings = self._user_file_items[FitId.general] + if FitId.GENERAL in self._user_file_items: + settings = self._user_file_items[FitId.GENERAL] # Check first if there are data type specific settings, else check if there are general settings extracted_settings = [setting for setting in settings if setting.data_type is data_type] if not extracted_settings: @@ -62,9 +62,9 @@ class SettingsAdjustmentModel(ModelCommon): return default_value def _set_transmission_fit(self, data_type, start=None, stop=None, fit_type=None, polynomial_order=None): - if FitId.general in self._user_file_items: + if FitId.GENERAL in self._user_file_items: # Gather all settings which correspond to the data type and where the data type is none - settings = self._user_file_items[FitId.general] + settings = self._user_file_items[FitId.GENERAL] settings_general = [setting for setting in settings if setting.data_type is None] settings_for_data_type = [setting for setting in settings if setting.data_type is data_type] # We check if there are data-type specific settings. 
@@ -79,10 +79,10 @@ class SettingsAdjustmentModel(ModelCommon): polynomial_order=setting_general.polynomial_order)) else: settings.append(fit_general(start=None, stop=None, data_type=data_type, - fit_type=FitType.NoFit, polynomial_order=2)) + fit_type=FitType.NO_FIT, polynomial_order=2)) else: settings = [fit_general(start=None, stop=None, data_type=data_type, - fit_type=FitType.NoFit, polynomial_order=2)] + fit_type=FitType.NO_FIT, polynomial_order=2)] new_settings = [] for setting in settings: @@ -97,7 +97,7 @@ class SettingsAdjustmentModel(ModelCommon): data_type=setting.data_type, polynomial_order=new_polynomial_order)) else: new_settings.append(setting) - self._user_file_items.update({FitId.general: new_settings}) + self._user_file_items.update({FitId.GENERAL: new_settings}) # ------------------------------------------------------------------------------------------------------------------ # Wavelength- and pixel-adjustment files @@ -141,39 +141,39 @@ class SettingsAdjustmentModel(ModelCommon): @property def pixel_adjustment_det_1(self): - return self._get_adjustment_file_setting(element_id=MonId.flat, detector_type=DetectorType.LAB, + return self._get_adjustment_file_setting(element_id=MonId.FLAT, detector_type=DetectorType.LAB, default_value="") @pixel_adjustment_det_1.setter def pixel_adjustment_det_1(self, value): - self._set_adjustment_file_setting(element_id=MonId.flat, detector_type=DetectorType.LAB, file_path=value) + self._set_adjustment_file_setting(element_id=MonId.FLAT, detector_type=DetectorType.LAB, file_path=value) @property def pixel_adjustment_det_2(self): - return self._get_adjustment_file_setting(element_id=MonId.flat, detector_type=DetectorType.HAB, + return self._get_adjustment_file_setting(element_id=MonId.FLAT, detector_type=DetectorType.HAB, default_value="") @pixel_adjustment_det_2.setter def pixel_adjustment_det_2(self, value): - self._set_adjustment_file_setting(element_id=MonId.flat, detector_type=DetectorType.HAB, file_path=value) 
+ self._set_adjustment_file_setting(element_id=MonId.FLAT, detector_type=DetectorType.HAB, file_path=value) @property def wavelength_adjustment_det_1(self): - return self._get_adjustment_file_setting(element_id=MonId.direct, detector_type=DetectorType.LAB, + return self._get_adjustment_file_setting(element_id=MonId.DIRECT, detector_type=DetectorType.LAB, default_value="") @wavelength_adjustment_det_1.setter def wavelength_adjustment_det_1(self, value): - self._set_adjustment_file_setting(element_id=MonId.direct, detector_type=DetectorType.LAB, file_path=value) + self._set_adjustment_file_setting(element_id=MonId.DIRECT, detector_type=DetectorType.LAB, file_path=value) @property def wavelength_adjustment_det_2(self): - return self._get_adjustment_file_setting(element_id=MonId.direct, detector_type=DetectorType.HAB, + return self._get_adjustment_file_setting(element_id=MonId.DIRECT, detector_type=DetectorType.HAB, default_value="") @wavelength_adjustment_det_2.setter def wavelength_adjustment_det_2(self, value): - self._set_adjustment_file_setting(element_id=MonId.direct, detector_type=DetectorType.HAB, file_path=value) + self._set_adjustment_file_setting(element_id=MonId.DIRECT, detector_type=DetectorType.HAB, file_path=value) # ------------------------------------------------------------------------------------------------------------------ # Transmission Fitting @@ -181,69 +181,69 @@ class SettingsAdjustmentModel(ModelCommon): @property def transmission_sample_fit_type(self): - return self._get_transmission_fit(data_type=DataType.Sample, attribute="fit_type", default_value=FitType.NoFit) + return self._get_transmission_fit(data_type=DataType.SAMPLE, attribute="fit_type", default_value=FitType.NO_FIT) @transmission_sample_fit_type.setter def transmission_sample_fit_type(self, value): - self._set_transmission_fit(data_type=DataType.Sample, fit_type=value) + self._set_transmission_fit(data_type=DataType.SAMPLE, fit_type=value) @property def 
transmission_can_fit_type(self): - return self._get_transmission_fit(data_type=DataType.Can, attribute="fit_type", default_value=FitType.NoFit) + return self._get_transmission_fit(data_type=DataType.CAN, attribute="fit_type", default_value=FitType.NO_FIT) @transmission_can_fit_type.setter def transmission_can_fit_type(self, value): - self._set_transmission_fit(data_type=DataType.Can, fit_type=value) + self._set_transmission_fit(data_type=DataType.CAN, fit_type=value) @property def transmission_sample_polynomial_order(self): - return self._get_transmission_fit(data_type=DataType.Sample, attribute="polynomial_order", + return self._get_transmission_fit(data_type=DataType.SAMPLE, attribute="polynomial_order", default_value=2) @transmission_sample_polynomial_order.setter def transmission_sample_polynomial_order(self, value): - self._set_transmission_fit(data_type=DataType.Sample, polynomial_order=value) + self._set_transmission_fit(data_type=DataType.SAMPLE, polynomial_order=value) @property def transmission_can_polynomial_order(self): - return self._get_transmission_fit(data_type=DataType.Can, attribute="polynomial_order", + return self._get_transmission_fit(data_type=DataType.CAN, attribute="polynomial_order", default_value=2) @transmission_can_polynomial_order.setter def transmission_can_polynomial_order(self, value): - self._set_transmission_fit(data_type=DataType.Can, polynomial_order=value) + self._set_transmission_fit(data_type=DataType.CAN, polynomial_order=value) @property def transmission_sample_wavelength_min(self): - return self._get_transmission_fit(data_type=DataType.Sample, attribute="start", default_value="") + return self._get_transmission_fit(data_type=DataType.SAMPLE, attribute="start", default_value="") @transmission_sample_wavelength_min.setter def transmission_sample_wavelength_min(self, value): - self._set_transmission_fit(data_type=DataType.Sample, start=value) + self._set_transmission_fit(data_type=DataType.SAMPLE, start=value) @property def 
transmission_sample_wavelength_max(self): - return self._get_transmission_fit(data_type=DataType.Sample, attribute="stop", default_value="") + return self._get_transmission_fit(data_type=DataType.SAMPLE, attribute="stop", default_value="") @transmission_sample_wavelength_max.setter def transmission_sample_wavelength_max(self, value): - self._set_transmission_fit(data_type=DataType.Sample, stop=value) + self._set_transmission_fit(data_type=DataType.SAMPLE, stop=value) @property def transmission_can_wavelength_min(self): - return self._get_transmission_fit(data_type=DataType.Can, attribute="start", default_value="") + return self._get_transmission_fit(data_type=DataType.CAN, attribute="start", default_value="") @transmission_can_wavelength_min.setter def transmission_can_wavelength_min(self, value): - self._set_transmission_fit(data_type=DataType.Can, start=value) + self._set_transmission_fit(data_type=DataType.CAN, start=value) @property def transmission_can_wavelength_max(self): - return self._get_transmission_fit(data_type=DataType.Can, attribute="stop", default_value="") + return self._get_transmission_fit(data_type=DataType.CAN, attribute="stop", default_value="") @transmission_can_wavelength_max.setter def transmission_can_wavelength_max(self, value): - self._set_transmission_fit(data_type=DataType.Can, stop=value) + self._set_transmission_fit(data_type=DataType.CAN, stop=value) # ------------------------------------------------------------------------------------------------------------------ # Monitor normalization @@ -287,52 +287,52 @@ class SettingsAdjustmentModel(ModelCommon): @property def transmission_roi_files(self): - return self.get_simple_element(element_id=TransId.roi, default_value="") + return self.get_simple_element(element_id=TransId.ROI, default_value="") @transmission_roi_files.setter def transmission_roi_files(self, value): - self.set_simple_element(element_id=TransId.roi, value=value) + self.set_simple_element(element_id=TransId.ROI, 
value=value) @property def transmission_mask_files(self): - return self.get_simple_element(element_id=TransId.mask, default_value="") + return self.get_simple_element(element_id=TransId.MASK, default_value="") @transmission_mask_files.setter def transmission_mask_files(self, value): - self.set_simple_element(element_id=TransId.mask, value=value) + self.set_simple_element(element_id=TransId.MASK, value=value) @property def transmission_radius(self): - return self.get_simple_element(element_id=TransId.radius, default_value="") + return self.get_simple_element(element_id=TransId.RADIUS, default_value="") @transmission_radius.setter def transmission_radius(self, value): - self.set_simple_element(element_id=TransId.radius, value=value) + self.set_simple_element(element_id=TransId.RADIUS, value=value) @property def transmission_monitor(self): - return self.get_simple_element(element_id=TransId.spec, default_value=3) + return self.get_simple_element(element_id=TransId.SPEC, default_value=3) @transmission_monitor.setter def transmission_monitor(self, value): - self.set_simple_element(element_id=TransId.spec, value=value) + self.set_simple_element(element_id=TransId.SPEC, value=value) @property def transmission_mn_4_shift(self): # Note that this is actually part of the move operation, but is conceptually part of transmission - return self.get_simple_element(element_id=TransId.spec_4_shift, default_value="") + return self.get_simple_element(element_id=TransId.SPEC_4_SHIFT, default_value="") @transmission_mn_4_shift.setter def transmission_mn_4_shift(self, value): # Note that this is actually part of the move operation, but is conceptually part of transmission - self.set_simple_element(element_id=TransId.spec_4_shift, value=value) + self.set_simple_element(element_id=TransId.SPEC_4_SHIFT, value=value) @property def transmission_mn_5_shift(self): # Note that this is actually part of the move operation, but is conceptually part of transmission - return 
self.get_simple_element(element_id=TransId.spec_5_shift, default_value="") + return self.get_simple_element(element_id=TransId.SPEC_5_SHIFT, default_value="") @transmission_mn_5_shift.setter def transmission_mn_5_shift(self, value): # Note that this is actually part of the move operation, but is conceptually part of transmission - self.set_simple_element(element_id=TransId.spec_5_shift, value=value) + self.set_simple_element(element_id=TransId.SPEC_5_SHIFT, value=value) diff --git a/scripts/SANS/sans/gui_logic/models/state_gui_model.py b/scripts/SANS/sans/gui_logic/models/state_gui_model.py index 4066f661abe6a6410564e0be87cb82e21e532c8e..4e01b2c05fe273f0d8b19a5059d45642e7c97a50 100644 --- a/scripts/SANS/sans/gui_logic/models/state_gui_model.py +++ b/scripts/SANS/sans/gui_logic/models/state_gui_model.py @@ -13,7 +13,7 @@ are not available in the model associated with the data table. from __future__ import (absolute_import, division, print_function) from mantid.py3compat import ensure_str -from sans.common.enums import (ReductionDimensionality, ISISReductionMode, RangeStepType, SaveType, +from sans.common.enums import (ReductionDimensionality, ReductionMode, RangeStepType, SaveType, DetectorType) from sans.gui_logic.models.model_common import ModelCommon from sans.user_file.settings_tags import (OtherId, DetectorId, LimitsId, SetId, SampleId, GravityId, @@ -45,27 +45,27 @@ class StateGuiModel(ModelCommon): # ------------------------------------------------------------------------------------------------------------------ @property def compatibility_mode(self): - return self.get_simple_element(element_id=OtherId.use_compatibility_mode, default_value=True) + return self.get_simple_element(element_id=OtherId.USE_COMPATIBILITY_MODE, default_value=True) @compatibility_mode.setter def compatibility_mode(self, value): - self.set_simple_element(element_id=OtherId.use_compatibility_mode, value=value) + self.set_simple_element(element_id=OtherId.USE_COMPATIBILITY_MODE, 
value=value) @property def event_slice_optimisation(self): - return self.get_simple_element(element_id=OtherId.use_event_slice_optimisation, default_value=False) + return self.get_simple_element(element_id=OtherId.USE_EVENT_SLICE_OPTIMISATION, default_value=False) @event_slice_optimisation.setter def event_slice_optimisation(self, value): - self.set_simple_element(element_id=OtherId.use_event_slice_optimisation, value=value) + self.set_simple_element(element_id=OtherId.USE_EVENT_SLICE_OPTIMISATION, value=value) # ------------------------------------------------------------------------------------------------------------------ # Save Options # ------------------------------------------------------------------------------------------------------------------ @property def zero_error_free(self): - if OtherId.save_as_zero_error_free in self._user_file_items: - return self._user_file_items[OtherId.save_as_zero_error_free][-1] + if OtherId.SAVE_AS_ZERO_ERROR_FREE in self._user_file_items: + return self._user_file_items[OtherId.SAVE_AS_ZERO_ERROR_FREE][-1] else: # Turn on zero error free saving by default return True @@ -74,18 +74,18 @@ class StateGuiModel(ModelCommon): def zero_error_free(self, value): if value is None: return - if OtherId.save_as_zero_error_free in self._user_file_items: - del self._user_file_items[OtherId.save_as_zero_error_free] - new_state_entries = {OtherId.save_as_zero_error_free: [value]} + if OtherId.SAVE_AS_ZERO_ERROR_FREE in self._user_file_items: + del self._user_file_items[OtherId.SAVE_AS_ZERO_ERROR_FREE] + new_state_entries = {OtherId.SAVE_AS_ZERO_ERROR_FREE: [value]} self._user_file_items.update(new_state_entries) @property def save_types(self): - return self.get_simple_element(element_id=OtherId.save_types, default_value=[SaveType.NXcanSAS]) + return self.get_simple_element(element_id=OtherId.SAVE_TYPES, default_value=[SaveType.NX_CAN_SAS]) @save_types.setter def save_types(self, value): - 
self.set_simple_element(element_id=OtherId.save_types, value=value) + self.set_simple_element(element_id=OtherId.SAVE_TYPES, value=value) # ================================================================================================================== # ================================================================================================================== @@ -94,7 +94,7 @@ class StateGuiModel(ModelCommon): # ================================================================================================================== @property def lab_pos_1(self): - return self.get_simple_element_with_attribute(element_id=SetId.centre, default_value='', attribute="pos1") + return self.get_simple_element_with_attribute(element_id=SetId.CENTRE, default_value='', attribute="pos1") @lab_pos_1.setter def lab_pos_1(self, value): @@ -102,7 +102,7 @@ class StateGuiModel(ModelCommon): @property def lab_pos_2(self): - return self.get_simple_element_with_attribute(element_id=SetId.centre, default_value='', attribute="pos2") + return self.get_simple_element_with_attribute(element_id=SetId.CENTRE, default_value='', attribute="pos2") @lab_pos_2.setter def lab_pos_2(self, value): @@ -110,7 +110,7 @@ class StateGuiModel(ModelCommon): @property def hab_pos_1(self): - return self.get_simple_element_with_attribute(element_id=SetId.centre_HAB, default_value='', attribute="pos1") + return self.get_simple_element_with_attribute(element_id=SetId.CENTRE_HAB, default_value='', attribute="pos1") @hab_pos_1.setter def hab_pos_1(self, value): @@ -118,15 +118,15 @@ class StateGuiModel(ModelCommon): @property def hab_pos_2(self): - return self.get_simple_element_with_attribute(element_id=SetId.centre_HAB, default_value='', attribute="pos2") + return self.get_simple_element_with_attribute(element_id=SetId.CENTRE_HAB, default_value='', attribute="pos2") @hab_pos_2.setter def hab_pos_2(self, value): self._update_centre(pos_2=value) def _update_centre(self, pos_1=None, pos_2=None, 
detector_type=None): - if SetId.centre in self._user_file_items: - settings = self._user_file_items[SetId.centre] + if SetId.CENTRE in self._user_file_items: + settings = self._user_file_items[SetId.CENTRE] else: # If the entry does not already exist, then add it. The -1. is an illegal input which should get overridden # and if not we want it to fail. @@ -139,7 +139,7 @@ class StateGuiModel(ModelCommon): new_detector_type = detector_type if detector_type else setting.detector_type new_setting = position_entry(pos1=new_pos1, pos2=new_pos2, detector_type=new_detector_type) new_settings.append(new_setting) - self._user_file_items.update({SetId.centre: new_settings}) + self._user_file_items.update({SetId.CENTRE: new_settings}) # ================================================================================================================== # ================================================================================================================== # General TAB @@ -151,7 +151,7 @@ class StateGuiModel(ModelCommon): # ------------------------------------------------------------------------------------------------------------------ @property def event_slices(self): - return self.get_simple_element_with_attribute(element_id=OtherId.event_slices, + return self.get_simple_element_with_attribute(element_id=OtherId.EVENT_SLICES, default_value="", attribute="value") @@ -159,9 +159,9 @@ class StateGuiModel(ModelCommon): def event_slices(self, value): if not value: return - if OtherId.event_slices in self._user_file_items: - del self._user_file_items[OtherId.event_slices] - new_state_entries = {OtherId.event_slices: [event_binning_string_values(value=value)]} + if OtherId.EVENT_SLICES in self._user_file_items: + del self._user_file_items[OtherId.EVENT_SLICES] + new_state_entries = {OtherId.EVENT_SLICES: [event_binning_string_values(value=value)]} self._user_file_items.update(new_state_entries) # 
------------------------------------------------------------------------------------------------------------------ @@ -169,15 +169,15 @@ class StateGuiModel(ModelCommon): # ------------------------------------------------------------------------------------------------------------------ @property def reduction_dimensionality(self): - return self.get_simple_element_with_attribute(element_id=OtherId.reduction_dimensionality, - default_value=ReductionDimensionality.OneDim) + return self.get_simple_element_with_attribute(element_id=OtherId.REDUCTION_DIMENSIONALITY, + default_value=ReductionDimensionality.ONE_DIM) @reduction_dimensionality.setter def reduction_dimensionality(self, value): - if value is ReductionDimensionality.OneDim or value is ReductionDimensionality.TwoDim: - if OtherId.reduction_dimensionality in self._user_file_items: - del self._user_file_items[OtherId.reduction_dimensionality] - new_state_entries = {OtherId.reduction_dimensionality: [value]} + if value is ReductionDimensionality.ONE_DIM or value is ReductionDimensionality.TWO_DIM: + if OtherId.REDUCTION_DIMENSIONALITY in self._user_file_items: + del self._user_file_items[OtherId.REDUCTION_DIMENSIONALITY] + new_state_entries = {OtherId.REDUCTION_DIMENSIONALITY: [value]} self._user_file_items.update(new_state_entries) else: raise ValueError("A reduction dimensionality was expected, got instead {}".format(value)) @@ -211,10 +211,10 @@ class StateGuiModel(ModelCommon): q_stop = [] settings = [] - if DetectorId.rescale_fit in self._user_file_items: - settings.extend(self._user_file_items[DetectorId.rescale_fit]) - if DetectorId.shift_fit in self._user_file_items: - settings.extend(self._user_file_items[DetectorId.shift_fit]) + if DetectorId.RESCALE_FIT in self._user_file_items: + settings.extend(self._user_file_items[DetectorId.RESCALE_FIT]) + if DetectorId.SHIFT_FIT in self._user_file_items: + settings.extend(self._user_file_items[DetectorId.SHIFT_FIT]) for setting in settings: if setting.start is not 
None: @@ -232,8 +232,8 @@ class StateGuiModel(ModelCommon): q_stop = [] settings = [] - if DetectorId.merge_range in self._user_file_items: - settings.extend(self._user_file_items[DetectorId.merge_range]) + if DetectorId.MERGE_RANGE in self._user_file_items: + settings.extend(self._user_file_items[DetectorId.MERGE_RANGE]) for setting in settings: if setting.start is not None: @@ -248,55 +248,55 @@ class StateGuiModel(ModelCommon): @property def reduction_mode(self): - return self.get_simple_element_with_attribute(element_id=DetectorId.reduction_mode, - default_value=ISISReductionMode.LAB) + return self.get_simple_element_with_attribute(element_id=DetectorId.REDUCTION_MODE, + default_value=ReductionMode.LAB) @reduction_mode.setter def reduction_mode(self, value): - if (value is ISISReductionMode.LAB or value is ISISReductionMode.HAB or - value is ISISReductionMode.Merged or value is ISISReductionMode.All): # noqa - if DetectorId.reduction_mode in self._user_file_items: - del self._user_file_items[DetectorId.reduction_mode] - new_state_entries = {DetectorId.reduction_mode: [value]} + if (value is ReductionMode.LAB or value is ReductionMode.HAB or + value is ReductionMode.MERGED or value is ReductionMode.ALL): # noqa + if DetectorId.REDUCTION_MODE in self._user_file_items: + del self._user_file_items[DetectorId.REDUCTION_MODE] + new_state_entries = {DetectorId.REDUCTION_MODE: [value]} self._user_file_items.update(new_state_entries) else: raise ValueError("A reduction mode was expected, got instead {}".format(value)) @property def merge_scale(self): - return self.get_simple_element(element_id=DetectorId.rescale, default_value="1.0") + return self.get_simple_element(element_id=DetectorId.RESCALE, default_value="1.0") @merge_scale.setter def merge_scale(self, value): - self.set_simple_element(element_id=DetectorId.rescale, value=value) + self.set_simple_element(element_id=DetectorId.RESCALE, value=value) @property def merge_shift(self): - return 
self.get_simple_element(element_id=DetectorId.shift, default_value="0.0") + return self.get_simple_element(element_id=DetectorId.SHIFT, default_value="0.0") @merge_shift.setter def merge_shift(self, value): - self.set_simple_element(element_id=DetectorId.shift, value=value) + self.set_simple_element(element_id=DetectorId.SHIFT, value=value) @property def merge_scale_fit(self): - return self.get_simple_element_with_attribute(element_id=DetectorId.rescale_fit, + return self.get_simple_element_with_attribute(element_id=DetectorId.RESCALE_FIT, default_value=False, attribute="use_fit") @merge_scale_fit.setter def merge_scale_fit(self, value): - self._update_merged_fit(element_id=DetectorId.rescale_fit, use_fit=value) + self._update_merged_fit(element_id=DetectorId.RESCALE_FIT, use_fit=value) @property def merge_shift_fit(self): - return self.get_simple_element_with_attribute(element_id=DetectorId.shift_fit, + return self.get_simple_element_with_attribute(element_id=DetectorId.SHIFT_FIT, default_value=False, attribute="use_fit") @merge_shift_fit.setter def merge_shift_fit(self, value): - self._update_merged_fit(element_id=DetectorId.shift_fit, use_fit=value) + self._update_merged_fit(element_id=DetectorId.SHIFT_FIT, use_fit=value) @property def merge_q_range_start(self): @@ -306,9 +306,9 @@ class StateGuiModel(ModelCommon): @merge_q_range_start.setter def merge_q_range_start(self, value): # Update for the shift - self._update_merged_fit(element_id=DetectorId.shift_fit, q_start=value) + self._update_merged_fit(element_id=DetectorId.SHIFT_FIT, q_start=value) # Update for the scale - self._update_merged_fit(element_id=DetectorId.rescale_fit, q_start=value) + self._update_merged_fit(element_id=DetectorId.RESCALE_FIT, q_start=value) @property def merge_q_range_stop(self): @@ -318,19 +318,19 @@ class StateGuiModel(ModelCommon): @merge_q_range_stop.setter def merge_q_range_stop(self, value): # Update for the shift - self._update_merged_fit(element_id=DetectorId.shift_fit, 
q_stop=value) + self._update_merged_fit(element_id=DetectorId.SHIFT_FIT, q_stop=value) # Update for the scale - self._update_merged_fit(element_id=DetectorId.rescale_fit, q_stop=value) + self._update_merged_fit(element_id=DetectorId.RESCALE_FIT, q_stop=value) @property def merge_mask(self): - return self.get_simple_element_with_attribute(element_id=DetectorId.merge_range, + return self.get_simple_element_with_attribute(element_id=DetectorId.MERGE_RANGE, default_value=False, attribute="use_fit") @merge_mask.setter def merge_mask(self, value): - self._update_merged_fit(element_id=DetectorId.merge_range, use_fit=value) + self._update_merged_fit(element_id=DetectorId.MERGE_RANGE, use_fit=value) @property def merge_max(self): @@ -339,7 +339,7 @@ class StateGuiModel(ModelCommon): @merge_max.setter def merge_max(self, value): - self._update_merged_fit(element_id=DetectorId.merge_range, q_stop=value) + self._update_merged_fit(element_id=DetectorId.MERGE_RANGE, q_stop=value) @property def merge_min(self): @@ -348,18 +348,18 @@ class StateGuiModel(ModelCommon): @merge_min.setter def merge_min(self, value): - self._update_merged_fit(element_id=DetectorId.merge_range, q_start=value) + self._update_merged_fit(element_id=DetectorId.MERGE_RANGE, q_start=value) # ------------------------------------------------------------------------------------------------------------------ # Event binning for compatibility mode # ------------------------------------------------------------------------------------------------------------------ @property def event_binning(self): - return self.get_simple_element(element_id=LimitsId.events_binning, default_value="") + return self.get_simple_element(element_id=LimitsId.EVENTS_BINNING, default_value="") @event_binning.setter def event_binning(self, value): - self.set_simple_element(element_id=LimitsId.events_binning, value=value) + self.set_simple_element(element_id=LimitsId.EVENTS_BINNING, value=value) # 
------------------------------------------------------------------------------------------------------------------ # Wavelength properties @@ -371,12 +371,12 @@ class StateGuiModel(ModelCommon): # This is not something that needs to be known at this point, but it is good to know. # ------------------------------------------------------------------------------------------------------------------ def _update_wavelength(self, min_value=None, max_value=None, step=None, step_type=None, wavelength_range=None): - if LimitsId.wavelength in self._user_file_items: - settings = self._user_file_items[LimitsId.wavelength] + if LimitsId.WAVELENGTH in self._user_file_items: + settings = self._user_file_items[LimitsId.WAVELENGTH] else: # If the entry does not already exist, then add it. The -1. is an illegal input which should get overridden # and if not we want it to fail. - settings = [simple_range(start=-1., stop=-1., step=-1., step_type=RangeStepType.Lin)] + settings = [simple_range(start=-1., stop=-1., step=-1., step_type=RangeStepType.LIN)] new_settings = [] for setting in settings: @@ -386,11 +386,11 @@ class StateGuiModel(ModelCommon): new_step_type = step_type if step_type else setting.step_type new_setting = simple_range(start=new_min, stop=new_max, step=new_step, step_type=new_step_type) new_settings.append(new_setting) - self._user_file_items.update({LimitsId.wavelength: new_settings}) + self._user_file_items.update({LimitsId.WAVELENGTH: new_settings}) if wavelength_range: - if OtherId.wavelength_range in self._user_file_items: - settings = self._user_file_items[OtherId.wavelength_range] + if OtherId.WAVELENGTH_RANGE in self._user_file_items: + settings = self._user_file_items[OtherId.WAVELENGTH_RANGE] else: settings = [""] @@ -398,11 +398,11 @@ class StateGuiModel(ModelCommon): for setting in settings: new_range = wavelength_range if wavelength_range else setting new_settings.append(new_range) - self._user_file_items.update({OtherId.wavelength_range: new_settings}) + 
self._user_file_items.update({OtherId.WAVELENGTH_RANGE: new_settings}) @property def wavelength_step_type(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.wavelength, default_value=RangeStepType.Lin, + return self.get_simple_element_with_attribute(element_id=LimitsId.WAVELENGTH, default_value=RangeStepType.LIN, attribute="step_type") @wavelength_step_type.setter @@ -411,7 +411,7 @@ class StateGuiModel(ModelCommon): @property def wavelength_min(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.wavelength, + return self.get_simple_element_with_attribute(element_id=LimitsId.WAVELENGTH, default_value="", attribute="start") @@ -421,7 +421,7 @@ class StateGuiModel(ModelCommon): @property def wavelength_max(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.wavelength, + return self.get_simple_element_with_attribute(element_id=LimitsId.WAVELENGTH, default_value="", attribute="stop") @@ -431,7 +431,7 @@ class StateGuiModel(ModelCommon): @property def wavelength_step(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.wavelength, + return self.get_simple_element_with_attribute(element_id=LimitsId.WAVELENGTH, default_value="", attribute="step") @@ -441,7 +441,7 @@ class StateGuiModel(ModelCommon): @property def wavelength_range(self): - return self.get_simple_element(element_id=OtherId.wavelength_range, default_value="") + return self.get_simple_element(element_id=OtherId.WAVELENGTH_RANGE, default_value="") @wavelength_range.setter def wavelength_range(self, value): @@ -453,14 +453,14 @@ class StateGuiModel(ModelCommon): # ------------------------------------------------------------------------------------------------------------------ @property def absolute_scale(self): - return self.get_simple_element_with_attribute(element_id=SetId.scales, + return self.get_simple_element_with_attribute(element_id=SetId.SCALES, default_value="", attribute="s") @absolute_scale.setter def 
absolute_scale(self, value): - if SetId.scales in self._user_file_items: - settings = self._user_file_items[SetId.scales] + if SetId.SCALES in self._user_file_items: + settings = self._user_file_items[SetId.SCALES] else: settings = [set_scales_entry(s=100., a=0., b=0., c=0., d=0.)] @@ -468,50 +468,50 @@ class StateGuiModel(ModelCommon): for setting in settings: s_parameter = value if value else setting.s new_settings.append(set_scales_entry(s=s_parameter, a=0., b=0., c=0., d=0.)) - self._user_file_items.update({SetId.scales: new_settings}) + self._user_file_items.update({SetId.SCALES: new_settings}) @property def sample_height(self): - return self.get_simple_element(element_id=OtherId.sample_height, default_value="") + return self.get_simple_element(element_id=OtherId.SAMPLE_HEIGHT, default_value="") @sample_height.setter def sample_height(self, value): - self.set_simple_element(element_id=OtherId.sample_height, value=value) + self.set_simple_element(element_id=OtherId.SAMPLE_HEIGHT, value=value) @property def sample_width(self): - return self.get_simple_element(element_id=OtherId.sample_width, default_value="") + return self.get_simple_element(element_id=OtherId.SAMPLE_WIDTH, default_value="") @sample_width.setter def sample_width(self, value): - self.set_simple_element(element_id=OtherId.sample_width, value=value) + self.set_simple_element(element_id=OtherId.SAMPLE_WIDTH, value=value) @property def sample_thickness(self): - return self.get_simple_element(element_id=OtherId.sample_thickness, default_value="") + return self.get_simple_element(element_id=OtherId.SAMPLE_THICKNESS, default_value="") @sample_thickness.setter def sample_thickness(self, value): - self.set_simple_element(element_id=OtherId.sample_thickness, value=value) + self.set_simple_element(element_id=OtherId.SAMPLE_THICKNESS, value=value) @property def sample_shape(self): - return self.get_simple_element(element_id=OtherId.sample_shape, default_value=None) + return 
self.get_simple_element(element_id=OtherId.SAMPLE_SHAPE, default_value=None) @sample_shape.setter def sample_shape(self, value): # We only set the value if it is not None. Note that it can be None if the sample shape selection # is "Read from file" if value is not None: - self.set_simple_element(element_id=OtherId.sample_shape, value=value) + self.set_simple_element(element_id=OtherId.SAMPLE_SHAPE, value=value) @property def z_offset(self): - return self.get_simple_element(element_id=SampleId.offset, default_value="") + return self.get_simple_element(element_id=SampleId.OFFSET, default_value="") @z_offset.setter def z_offset(self, value): - self.set_simple_element(element_id=SampleId.offset, value=value) + self.set_simple_element(element_id=SampleId.OFFSET, value=value) # ================================================================================================================== # ================================================================================================================== @@ -523,7 +523,7 @@ class StateGuiModel(ModelCommon): # Q Limits # ------------------------------------------------------------------------------------------------------------------ def _set_q_1d_limits(self, min_value=None, max_value=None, rebin_string=None): - element_id = LimitsId.q + element_id = LimitsId.Q if element_id in self._user_file_items: settings = self._user_file_items[element_id] else: @@ -539,7 +539,7 @@ class StateGuiModel(ModelCommon): self._user_file_items.update({element_id: new_settings}) def _set_q_xy_limits(self, stop_value=None, step_value=None, step_type_value=None): - element_id = LimitsId.qxy + element_id = LimitsId.QXY if element_id in self._user_file_items: settings = self._user_file_items[element_id] else: @@ -556,7 +556,7 @@ class StateGuiModel(ModelCommon): @property def q_1d_rebin_string(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.q, default_value="", + return 
self.get_simple_element_with_attribute(element_id=LimitsId.Q, default_value="", attribute="rebin_string") @q_1d_rebin_string.setter @@ -565,7 +565,7 @@ class StateGuiModel(ModelCommon): @property def q_xy_max(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.qxy, default_value="", + return self.get_simple_element_with_attribute(element_id=LimitsId.QXY, default_value="", attribute="stop") @q_xy_max.setter @@ -574,7 +574,7 @@ class StateGuiModel(ModelCommon): @property def q_xy_step(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.qxy, default_value="", + return self.get_simple_element_with_attribute(element_id=LimitsId.QXY, default_value="", attribute="step") @q_xy_step.setter @@ -583,7 +583,7 @@ class StateGuiModel(ModelCommon): @property def q_xy_step_type(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.qxy, default_value=None, + return self.get_simple_element_with_attribute(element_id=LimitsId.QXY, default_value=None, attribute="step_type") @q_xy_step_type.setter @@ -592,121 +592,121 @@ class StateGuiModel(ModelCommon): @property def r_cut(self): - return self.get_simple_element(element_id=LimitsId.radius_cut, default_value="") + return self.get_simple_element(element_id=LimitsId.RADIUS_CUT, default_value="") @r_cut.setter def r_cut(self, value): - self.set_simple_element(element_id=LimitsId.radius_cut, value=value) + self.set_simple_element(element_id=LimitsId.RADIUS_CUT, value=value) @property def w_cut(self): - return self.get_simple_element(element_id=LimitsId.wavelength_cut, default_value="") + return self.get_simple_element(element_id=LimitsId.WAVELENGTH_CUT, default_value="") @w_cut.setter def w_cut(self, value): - self.set_simple_element(element_id=LimitsId.wavelength_cut, value=value) + self.set_simple_element(element_id=LimitsId.WAVELENGTH_CUT, value=value) # ------------------------------------------------------------------------------------------------------------------ # 
Gravity # ------------------------------------------------------------------------------------------------------------------ @property def gravity_on_off(self): - return self.get_simple_element(element_id=GravityId.on_off, default_value=True) + return self.get_simple_element(element_id=GravityId.ON_OFF, default_value=True) @gravity_on_off.setter def gravity_on_off(self, value): - self.set_simple_element(element_id=GravityId.on_off, value=value) + self.set_simple_element(element_id=GravityId.ON_OFF, value=value) @property def gravity_extra_length(self): - return self.get_simple_element(element_id=GravityId.extra_length, default_value="") + return self.get_simple_element(element_id=GravityId.EXTRA_LENGTH, default_value="") @gravity_extra_length.setter def gravity_extra_length(self, value): - self.set_simple_element(element_id=GravityId.extra_length, value=value) + self.set_simple_element(element_id=GravityId.EXTRA_LENGTH, value=value) # ------------------------------------------------------------------------------------------------------------------ # QResolution # ------------------------------------------------------------------------------------------------------------------ @property def use_q_resolution(self): - return self.get_simple_element(element_id=QResolutionId.on, default_value=False) + return self.get_simple_element(element_id=QResolutionId.ON, default_value=False) @use_q_resolution.setter def use_q_resolution(self, value): - self.set_simple_element(element_id=QResolutionId.on, value=value) + self.set_simple_element(element_id=QResolutionId.ON, value=value) @property def q_resolution_source_a(self): - return self.get_simple_element(element_id=QResolutionId.a1, default_value="") + return self.get_simple_element(element_id=QResolutionId.A1, default_value="") @q_resolution_source_a.setter def q_resolution_source_a(self, value): - self.set_simple_element(element_id=QResolutionId.a1, value=value) + self.set_simple_element(element_id=QResolutionId.A1, 
value=value) @property def q_resolution_sample_a(self): - return self.get_simple_element(element_id=QResolutionId.a2, default_value="") + return self.get_simple_element(element_id=QResolutionId.A2, default_value="") @q_resolution_sample_a.setter def q_resolution_sample_a(self, value): - self.set_simple_element(element_id=QResolutionId.a2, value=value) + self.set_simple_element(element_id=QResolutionId.A2, value=value) @property def q_resolution_source_h(self): - return self.get_simple_element(element_id=QResolutionId.h1, default_value="") + return self.get_simple_element(element_id=QResolutionId.H1, default_value="") @q_resolution_source_h.setter def q_resolution_source_h(self, value): - self.set_simple_element(element_id=QResolutionId.h1, value=value) + self.set_simple_element(element_id=QResolutionId.H1, value=value) @property def q_resolution_sample_h(self): - return self.get_simple_element(element_id=QResolutionId.h2, default_value="") + return self.get_simple_element(element_id=QResolutionId.H2, default_value="") @q_resolution_sample_h.setter def q_resolution_sample_h(self, value): - self.set_simple_element(element_id=QResolutionId.h2, value=value) + self.set_simple_element(element_id=QResolutionId.H2, value=value) @property def q_resolution_source_w(self): - return self.get_simple_element(element_id=QResolutionId.w1, default_value="") + return self.get_simple_element(element_id=QResolutionId.W1, default_value="") @q_resolution_source_w.setter def q_resolution_source_w(self, value): - self.set_simple_element(element_id=QResolutionId.w1, value=value) + self.set_simple_element(element_id=QResolutionId.W1, value=value) @property def q_resolution_sample_w(self): - return self.get_simple_element(element_id=QResolutionId.w2, default_value="") + return self.get_simple_element(element_id=QResolutionId.W2, default_value="") @q_resolution_sample_w.setter def q_resolution_sample_w(self, value): - self.set_simple_element(element_id=QResolutionId.w2, value=value) + 
self.set_simple_element(element_id=QResolutionId.W2, value=value) @property def q_resolution_delta_r(self): - return self.get_simple_element(element_id=QResolutionId.delta_r, default_value="") + return self.get_simple_element(element_id=QResolutionId.DELTA_R, default_value="") @q_resolution_delta_r.setter def q_resolution_delta_r(self, value): - self.set_simple_element(element_id=QResolutionId.delta_r, value=value) + self.set_simple_element(element_id=QResolutionId.DELTA_R, value=value) @property def q_resolution_moderator_file(self): - return self.get_simple_element(element_id=QResolutionId.moderator, default_value="") + return self.get_simple_element(element_id=QResolutionId.MODERATOR, default_value="") @q_resolution_moderator_file.setter def q_resolution_moderator_file(self, value): - self.set_simple_element(element_id=QResolutionId.moderator, value=value) + self.set_simple_element(element_id=QResolutionId.MODERATOR, value=value) @property def q_resolution_collimation_length(self): - return self.get_simple_element(element_id=QResolutionId.collimation_length, default_value="") + return self.get_simple_element(element_id=QResolutionId.COLLIMATION_LENGTH, default_value="") @q_resolution_collimation_length.setter def q_resolution_collimation_length(self, value): - self.set_simple_element(element_id=QResolutionId.collimation_length, value=value) + self.set_simple_element(element_id=QResolutionId.COLLIMATION_LENGTH, value=value) # ================================================================================================================== # ================================================================================================================== @@ -718,8 +718,8 @@ class StateGuiModel(ModelCommon): # Phi limit # ------------------------------------------------------------------------------------------------------------------ def _set_phi_limit(self, min_value=None, max_value=None, use_mirror=None): - if LimitsId.angle in self._user_file_items: - settings 
= self._user_file_items[LimitsId.angle] + if LimitsId.ANGLE in self._user_file_items: + settings = self._user_file_items[LimitsId.ANGLE] else: settings = [mask_angle_entry(min=None, max=None, use_mirror=False)] @@ -729,11 +729,11 @@ class StateGuiModel(ModelCommon): new_max = max_value if max_value is not None else setting.max new_use_mirror = use_mirror if use_mirror is not None else setting.use_mirror new_settings.append(mask_angle_entry(min=new_min, max=new_max, use_mirror=new_use_mirror)) - self._user_file_items.update({LimitsId.angle: new_settings}) + self._user_file_items.update({LimitsId.ANGLE: new_settings}) @property def phi_limit_min(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.angle, attribute="min", default_value="-90") + return self.get_simple_element_with_attribute(element_id=LimitsId.ANGLE, attribute="min", default_value="-90") @phi_limit_min.setter def phi_limit_min(self, value): @@ -741,7 +741,7 @@ class StateGuiModel(ModelCommon): @property def phi_limit_max(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.angle, attribute="max", default_value="90") + return self.get_simple_element_with_attribute(element_id=LimitsId.ANGLE, attribute="max", default_value="90") @phi_limit_max.setter def phi_limit_max(self, value): @@ -749,7 +749,7 @@ class StateGuiModel(ModelCommon): @property def phi_limit_use_mirror(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.angle, attribute="use_mirror", default_value=True) # noqa + return self.get_simple_element_with_attribute(element_id=LimitsId.ANGLE, attribute="use_mirror", default_value=True) # noqa @phi_limit_use_mirror.setter def phi_limit_use_mirror(self, value): @@ -759,8 +759,8 @@ class StateGuiModel(ModelCommon): # Radius limit # ------------------------------------------------------------------------------------------------------------------ def _set_radius_limit(self, min_value=None, max_value=None): - if LimitsId.radius in 
self._user_file_items: - settings = self._user_file_items[LimitsId.radius] + if LimitsId.RADIUS in self._user_file_items: + settings = self._user_file_items[LimitsId.RADIUS] else: settings = [range_entry(start=None, stop=None)] @@ -769,11 +769,11 @@ class StateGuiModel(ModelCommon): new_min = min_value if min_value is not None else setting.start new_max = max_value if max_value is not None else setting.stop new_settings.append(range_entry(start=new_min, stop=new_max)) - self._user_file_items.update({LimitsId.radius: new_settings}) + self._user_file_items.update({LimitsId.RADIUS: new_settings}) @property def radius_limit_min(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.radius, attribute="start", default_value="") + return self.get_simple_element_with_attribute(element_id=LimitsId.RADIUS, attribute="start", default_value="") @radius_limit_min.setter def radius_limit_min(self, value): @@ -781,7 +781,7 @@ class StateGuiModel(ModelCommon): @property def radius_limit_max(self): - return self.get_simple_element_with_attribute(element_id=LimitsId.radius, attribute="stop", default_value="") + return self.get_simple_element_with_attribute(element_id=LimitsId.RADIUS, attribute="stop", default_value="") @radius_limit_max.setter def radius_limit_max(self, value): @@ -792,17 +792,17 @@ class StateGuiModel(ModelCommon): # ------------------------------------------------------------------------------------------------------------------ @property def mask_files(self): - if MaskId.file in self._user_file_items: - return self._user_file_items[MaskId.file] + if MaskId.FILE in self._user_file_items: + return self._user_file_items[MaskId.FILE] return [] @mask_files.setter def mask_files(self, value): if value is None: return - if MaskId.file in self._user_file_items: - del self._user_file_items[MaskId.file] - new_state_entries = {MaskId.file: value} + if MaskId.FILE in self._user_file_items: + del self._user_file_items[MaskId.FILE] + new_state_entries = 
{MaskId.FILE: value} self._user_file_items.update(new_state_entries) # ------------------------------------------------------------------------------------------------------------------ @@ -810,8 +810,8 @@ class StateGuiModel(ModelCommon): # ------------------------------------------------------------------------------------------------------------------ @property def output_name(self): - return self.get_simple_element(element_id=OtherId.user_specified_output_name, default_value="") + return self.get_simple_element(element_id=OtherId.USER_SPECIFIED_OUTPUT_NAME, default_value="") @output_name.setter def output_name(self, value): - self.set_simple_element(element_id=OtherId.user_specified_output_name, value=value) + self.set_simple_element(element_id=OtherId.USER_SPECIFIED_OUTPUT_NAME, value=value) diff --git a/scripts/SANS/sans/gui_logic/models/summation_settings.py b/scripts/SANS/sans/gui_logic/models/summation_settings.py index 56c6616bde985c50ca60c48d745b68ff9e21e6bd..b155b1cb2aa2af9ce36151c8673deea97b29d675 100644 --- a/scripts/SANS/sans/gui_logic/models/summation_settings.py +++ b/scripts/SANS/sans/gui_logic/models/summation_settings.py @@ -81,9 +81,9 @@ class BinningFromMonitors(object): class SummationSettings(object): def __init__(self, initial_type): self._save_directory = ConfigService.getString('defaultsave.directory') - self._type_factory_dict = {BinningType.SaveAsEventData: SaveAsEventData(), - BinningType.Custom: CustomBinning(), - BinningType.FromMonitors: BinningFromMonitors()} + self._type_factory_dict = {BinningType.SAVE_AS_EVENT_DATA: SaveAsEventData(), + BinningType.CUSTOM: CustomBinning(), + BinningType.FROM_MONITORS: BinningFromMonitors()} self._settings, self._type = self._settings_from_type(initial_type) def instrument(self): @@ -146,4 +146,4 @@ class SummationSettings(object): return self._type_factory_dict[type], type except KeyError: # default - return self._type_factory_dict[BinningType.Custom], BinningType.Custom + return 
self._type_factory_dict[BinningType.CUSTOM], BinningType.CUSTOM diff --git a/scripts/SANS/sans/gui_logic/models/table_model.py b/scripts/SANS/sans/gui_logic/models/table_model.py index 437372ee02392f1268080416481fddc31f16b724..5dfab8947fc8c672a857011bae3e9384fea11c25 100644 --- a/scripts/SANS/sans/gui_logic/models/table_model.py +++ b/scripts/SANS/sans/gui_logic/models/table_model.py @@ -15,6 +15,7 @@ from __future__ import (absolute_import, division, print_function) import os import re +from mantid.py3compat import Enum from mantid.kernel import Logger from sans.common.constants import ALL_PERIODS from sans.common.enums import RowState, SampleShape @@ -125,7 +126,7 @@ class TableModel(object): def update_table_entry(self, row, column, value): self._table_entries[row].update_attribute(self.column_name_converter[column], value) - self._table_entries[row].update_attribute('row_state', RowState.Unprocessed) + self._table_entries[row].update_attribute('row_state', RowState.UNPROCESSED) self._table_entries[row].update_attribute('tool_tip', '') self.notify_subscribers() @@ -147,16 +148,16 @@ class TableModel(object): return SampleShapeColumnModel.get_hint_strategy() def set_row_to_processed(self, row, tool_tip): - self._table_entries[row].update_attribute('row_state', RowState.Processed) + self._table_entries[row].update_attribute('row_state', RowState.PROCESSED) self._table_entries[row].update_attribute('tool_tip', tool_tip) self.notify_subscribers() def reset_row_state(self, row): - self._table_entries[row].update_attribute('row_state', RowState.Unprocessed) + self._table_entries[row].update_attribute('row_state', RowState.UNPROCESSED) self._table_entries[row].update_attribute('tool_tip', '') def set_row_to_error(self, row, tool_tip): - self._table_entries[row].update_attribute('row_state', RowState.Error) + self._table_entries[row].update_attribute('row_state', RowState.ERROR) self._table_entries[row].update_attribute('tool_tip', tool_tip) self.notify_subscribers() @@ 
-319,7 +320,7 @@ class TableIndexModel(object): self.sample_shape_model = SampleShapeColumnModel() self.sample_shape = sample_shape - self.row_state = RowState.Unprocessed + self.row_state = RowState.UNPROCESSED self.tool_tip = '' self.file_information = None @@ -498,26 +499,30 @@ class SampleShapeColumnModel(object): self.sample_shape_string = "" def __call__(self, original_value): - self._get_sample_shape(original_value) + self._set_sample_shape(original_value) - def _get_sample_shape(self, original_value): - try: - original_value = SampleShape.to_string(original_value) - except RuntimeError as e: - if not isinstance(original_value, str): - raise ValueError(str(e)) - - value = original_value.strip().lower() - if value == "": - self.sample_shape = "" - self.sample_shape_string = "" - else: - for shape in SampleShapeColumnModel.SAMPLE_SHAPES: - if shape.startswith(value): - shape_enum_string = SampleShapeColumnModel.SAMPLE_SHAPES_DICT[shape] - self.sample_shape = SampleShape.from_string(shape_enum_string) - self.sample_shape_string = shape_enum_string - break + def _set_sample_shape(self, original_value): + if isinstance(original_value, Enum): + self.sample_shape = original_value + self.sample_shape_string = original_value.value + return + + user_val = original_value.strip().lower() + + # Set it to none as our fallback + self.sample_shape = SampleShape.NOT_SET + self.sample_shape_string = "" + + if not user_val: + return # If we don't return here an empty string will match with the first shape + + for shape in SampleShape: + # Case insensitive lookup + value = str(shape.value) + if user_val in value.lower() : + self.sample_shape = shape + self.sample_shape_string = shape.value + return @staticmethod def get_hint_strategy(): diff --git a/scripts/SANS/sans/gui_logic/presenter/add_runs_presenter.py b/scripts/SANS/sans/gui_logic/presenter/add_runs_presenter.py index 5cd27b8c63a7d74d05bee330a496da7f53efb113..4e318c1c7330ed22be14f5127c81200058083828 100644 --- 
a/scripts/SANS/sans/gui_logic/presenter/add_runs_presenter.py +++ b/scripts/SANS/sans/gui_logic/presenter/add_runs_presenter.py @@ -7,7 +7,7 @@ import os from mantid.kernel import ConfigService -from sans.common.enums import SANSInstrument +from mantid.py3compat import Enum from sans.gui_logic.gui_common import SANSGuiPropertiesHandler @@ -16,7 +16,8 @@ class AddRunsFilenameManager(object): if isinstance(inst, str): self.instrument_string = inst else: - self.instrument_string = SANSInstrument.to_string(inst) + assert(isinstance(inst, Enum)) + self.instrument_string = inst.value def make_filename(self, runs): if runs: diff --git a/scripts/SANS/sans/gui_logic/presenter/masking_table_presenter.py b/scripts/SANS/sans/gui_logic/presenter/masking_table_presenter.py index e465bb8fd74cd6e35d6f1c6cb92bf4ea0f4108ef..d189039f697c311cd03593931a59c23bfc90e0a2 100644 --- a/scripts/SANS/sans/gui_logic/presenter/masking_table_presenter.py +++ b/scripts/SANS/sans/gui_logic/presenter/masking_table_presenter.py @@ -54,9 +54,9 @@ def load_workspace(state, workspace_name): def run_mask_workspace(state, workspace_to_mask): mask_info = state.mask - detectors = [DetectorType.to_string(DetectorType.LAB), DetectorType.to_string(DetectorType.HAB)] \ - if DetectorType.to_string(DetectorType.HAB) in mask_info.detectors else\ - [DetectorType.to_string(DetectorType.LAB)] # noqa + detectors = [DetectorType.LAB.value, DetectorType.HAB.value] \ + if DetectorType.HAB.value in mask_info.detectors else\ + [DetectorType.LAB.value] # noqa for detector in detectors: mask_workspace(component_as_string=detector, workspace=workspace_to_mask, state=state) @@ -402,8 +402,8 @@ class MaskingTablePresenter(object): mask_info = state.mask masks = [] - mask_info_lab = mask_info.detectors[DetectorType.to_string(DetectorType.LAB)] - mask_info_hab = mask_info.detectors[DetectorType.to_string(DetectorType.HAB)] if DetectorType.to_string(DetectorType.HAB) in mask_info.detectors else None # noqa + mask_info_lab = 
mask_info.detectors[DetectorType.LAB.value] + mask_info_hab = mask_info.detectors[DetectorType.HAB.value] if DetectorType.HAB.value in mask_info.detectors else None # noqa # Add the radius mask radius_mask = self._get_radius(mask_info) diff --git a/scripts/SANS/sans/gui_logic/presenter/run_tab_presenter.py b/scripts/SANS/sans/gui_logic/presenter/run_tab_presenter.py index 3e9830f07c373a97ccecd9df9225d89b00552f36..39bf9e78e31729d745ef1dbda1aa549c630d0953 100644 --- a/scripts/SANS/sans/gui_logic/presenter/run_tab_presenter.py +++ b/scripts/SANS/sans/gui_logic/presenter/run_tab_presenter.py @@ -24,8 +24,8 @@ from mantid.kernel import Logger, ConfigService, ConfigPropertyObserver from mantid.py3compat import csv_open_type from sans.command_interface.batch_csv_file_parser import BatchCsvParser from sans.common.constants import ALL_PERIODS -from sans.common.enums import (BatchReductionEntry, ISISReductionMode, RangeStepType, RowState, SampleShape, - SANSFacility, SaveType, SANSInstrument) +from sans.common.enums import (BatchReductionEntry, ReductionMode, RangeStepType, RowState, SampleShape, + SaveType, SANSInstrument) from sans.gui_logic.gui_common import (add_dir_to_datasearch, get_reduction_mode_from_gui_selection, get_reduction_mode_strings_for_gui, get_string_for_gui_from_instrument, remove_dir_from_datasearch, SANSGuiPropertiesHandler) @@ -59,8 +59,8 @@ if PYQT4: else: from mantidqt.plotting.functions import get_plot_fig -row_state_to_colour_mapping = {RowState.Unprocessed: '#FFFFFF', RowState.Processed: '#d0f4d0', - RowState.Error: '#accbff'} +row_state_to_colour_mapping = {RowState.UNPROCESSED: '#FFFFFF', RowState.PROCESSED: '#d0f4d0', + RowState.ERROR: '#accbff'} def log_times(func): @@ -247,24 +247,24 @@ class RunTabPresenter(PresenterCommon): self._view.set_reduction_modes(reduction_mode_list) # Set the step type options for wavelength - range_step_types = [RangeStepType.to_string(RangeStepType.Lin), - RangeStepType.to_string(RangeStepType.Log), - 
RangeStepType.to_string(RangeStepType.RangeLog), - RangeStepType.to_string(RangeStepType.RangeLin)] + range_step_types = [RangeStepType.LIN.value, + RangeStepType.LOG.value, + RangeStepType.RANGE_LOG.value, + RangeStepType.RANGE_LIN.value] self._view.wavelength_step_type = range_step_types # Set the geometry options. This needs to include the option to read the sample shape from file. sample_shape = ["Read from file", - SampleShape.Cylinder, - SampleShape.FlatPlate, - SampleShape.Disc] + SampleShape.CYLINDER, + SampleShape.FLAT_PLATE, + SampleShape.DISC] self._view.sample_shape = sample_shape # Set the q range - self._view.q_1d_step_type = [RangeStepType.to_string(RangeStepType.Lin), - RangeStepType.to_string(RangeStepType.Log)] - self._view.q_xy_step_type = [RangeStepType.to_string(RangeStepType.Lin), - RangeStepType.to_string(RangeStepType.Log)] + self._view.q_1d_step_type = [RangeStepType.LIN.value, + RangeStepType.LOG.value] + self._view.q_xy_step_type = [RangeStepType.LIN.value, + RangeStepType.LOG.value] def _handle_output_directory_changed(self, new_directory): """ @@ -384,7 +384,7 @@ class RunTabPresenter(PresenterCommon): self._workspace_diagnostic_presenter.on_user_file_load(user_file_path) # 6. 
Warning if user file did not contain a recognised instrument - if self._view.instrument == SANSInstrument.NoInstrument: + if self._view.instrument == SANSInstrument.NO_INSTRUMENT: raise RuntimeError("User file did not contain a SANS Instrument.") except RuntimeError as instrument_e: @@ -398,8 +398,8 @@ class RunTabPresenter(PresenterCommon): use_error_name=True) def _on_user_file_load_failure(self, e, message, use_error_name=False): - self._setup_instrument_specific_settings(SANSInstrument.NoInstrument) - self._view.instrument = SANSInstrument.NoInstrument + self._setup_instrument_specific_settings(SANSInstrument.NO_INSTRUMENT) + self._view.instrument = SANSInstrument.NO_INSTRUMENT self._view.on_user_file_load_failure() self.display_errors(e, message, use_error_name) @@ -462,33 +462,33 @@ class RunTabPresenter(PresenterCommon): # ----Pull out the entries---- # Run numbers - sample_scatter = get_string_entry(BatchReductionEntry.SampleScatter, row) + sample_scatter = get_string_entry(BatchReductionEntry.SAMPLE_SCATTER, row) sample_scatter_period = get_string_period( - get_string_entry(BatchReductionEntry.SampleScatterPeriod, row)) - sample_transmission = get_string_entry(BatchReductionEntry.SampleTransmission, row) + get_string_entry(BatchReductionEntry.SAMPLE_SCATTER_PERIOD, row)) + sample_transmission = get_string_entry(BatchReductionEntry.SAMPLE_TRANSMISSION, row) sample_transmission_period = \ - get_string_period(get_string_entry(BatchReductionEntry.SampleTransmissionPeriod, row)) - sample_direct = get_string_entry(BatchReductionEntry.SampleDirect, row) + get_string_period(get_string_entry(BatchReductionEntry.SAMPLE_TRANSMISSION_PERIOD, row)) + sample_direct = get_string_entry(BatchReductionEntry.SAMPLE_DIRECT, row) sample_direct_period = get_string_period( - get_string_entry(BatchReductionEntry.SampleDirectPeriod, row)) - can_scatter = get_string_entry(BatchReductionEntry.CanScatter, row) + get_string_entry(BatchReductionEntry.SAMPLE_DIRECT_PERIOD, row)) + 
can_scatter = get_string_entry(BatchReductionEntry.CAN_SCATTER, row) can_scatter_period = get_string_period( - get_string_entry(BatchReductionEntry.CanScatterPeriod, row)) - can_transmission = get_string_entry(BatchReductionEntry.CanTransmission, row) + get_string_entry(BatchReductionEntry.CAN_SCATTER_PERIOD, row)) + can_transmission = get_string_entry(BatchReductionEntry.CAN_TRANSMISSION, row) can_transmission_period = get_string_period( - get_string_entry(BatchReductionEntry.CanScatterPeriod, row)) - can_direct = get_string_entry(BatchReductionEntry.CanDirect, row) + get_string_entry(BatchReductionEntry.CAN_SCATTER_PERIOD, row)) + can_direct = get_string_entry(BatchReductionEntry.CAN_DIRECT, row) can_direct_period = get_string_period( - get_string_entry(BatchReductionEntry.CanDirectPeriod, row)) + get_string_entry(BatchReductionEntry.CAN_DIRECT_PERIOD, row)) # Other information - output_name = get_string_entry(BatchReductionEntry.Output, row) - user_file = get_string_entry(BatchReductionEntry.UserFile, row) + output_name = get_string_entry(BatchReductionEntry.OUTPUT, row) + user_file = get_string_entry(BatchReductionEntry.USER_FILE, row) # Sample geometries - sample_thickness = get_string_entry(BatchReductionEntry.SampleThickness, row) - sample_height = get_string_entry(BatchReductionEntry.SampleHeight, row) - sample_width = get_string_entry(BatchReductionEntry.SampleWidth, row) + sample_thickness = get_string_entry(BatchReductionEntry.SAMPLE_THICKNESS, row) + sample_height = get_string_entry(BatchReductionEntry.SAMPLE_HEIGHT, row) + sample_width = get_string_entry(BatchReductionEntry.SAMPLE_WIDTH, row) # ----Form a row---- row_entry = [sample_scatter, sample_scatter_period, sample_transmission, @@ -528,7 +528,7 @@ class RunTabPresenter(PresenterCommon): def on_data_changed(self, row, column, new_value, old_value): self._table_model.update_table_entry(row, column, new_value) - self._view.change_row_color(row_state_to_colour_mapping[RowState.Unprocessed], row) + 
self._view.change_row_color(row_state_to_colour_mapping[RowState.UNPROCESSED], row) self._view.set_row_tooltip('', row) self._beam_centre_presenter.on_update_rows() self._masking_table_presenter.on_update_rows() @@ -629,7 +629,7 @@ class RunTabPresenter(PresenterCommon): """ if (self._view.output_mode_file_radio_button.isChecked() or self._view.output_mode_both_radio_button.isChecked()): - if self._view.save_types == [SaveType.NoType]: + if self._view.save_types == [SaveType.NO_TYPE]: raise RuntimeError("You have selected an output mode which saves to file, " "but no file types have been selected.") @@ -903,9 +903,9 @@ class RunTabPresenter(PresenterCommon): if not selection: return mode = get_reduction_mode_from_gui_selection(selection) - if mode == ISISReductionMode.HAB: + if mode == ReductionMode.HAB: self._beam_centre_presenter.update_hab_selected() - elif mode == ISISReductionMode.LAB: + elif mode == ReductionMode.LAB: self._beam_centre_presenter.update_lab_selected() else: self._beam_centre_presenter.update_all_selected() @@ -1076,7 +1076,7 @@ class RunTabPresenter(PresenterCommon): if len(elements) == 3: step_element = float(elements[1]) step = abs(step_element) - step_type = RangeStepType.Lin if step_element >= 0 else RangeStepType.Log + step_type = RangeStepType.LIN if step_element >= 0 else RangeStepType.LOG # Set on the view self._view.q_1d_min_or_rebin_string = float(elements[0]) @@ -1182,7 +1182,7 @@ class RunTabPresenter(PresenterCommon): q_1d_step = self._view.q_1d_step if q_1d_min and q_1d_max and q_1d_step and q_1d_step_type: q_1d_rebin_string = str(q_1d_min) + "," - q_1d_step_type_factor = -1. if q_1d_step_type is RangeStepType.Log else 1. + q_1d_step_type_factor = -1. if q_1d_step_type is RangeStepType.LOG else 1. 
q_1d_rebin_string += str(q_1d_step_type_factor * q_1d_step) + "," q_1d_rebin_string += str(q_1d_max) state_model.q_1d_rebin_string = q_1d_rebin_string @@ -1227,14 +1227,14 @@ class RunTabPresenter(PresenterCommon): # Settings # ------------------------------------------------------------------------------------------------------------------ def _setup_instrument_specific_settings(self, instrument=None): - if ConfigService["default.facility"] != SANSFacility.to_string(self._facility): - ConfigService["default.facility"] = SANSFacility.to_string(self._facility) + if ConfigService["default.facility"] != self._facility.value: + ConfigService["default.facility"] = self._facility.value self.sans_logger.notice("Facility changed to ISIS.") if not instrument: instrument = self._view.instrument - if instrument == SANSInstrument.NoInstrument: + if instrument == SANSInstrument.NO_INSTRUMENT: self._view.disable_process_buttons() else: instrument_string = get_string_for_gui_from_instrument(instrument) diff --git a/scripts/SANS/sans/gui_logic/presenter/settings_adjustment_presenter.py b/scripts/SANS/sans/gui_logic/presenter/settings_adjustment_presenter.py index 1f8221c286e61cd156a1166eb9d5ff5076e026f8..0bee1c1567d2e42d6ceb3bf509f77857e141e350 100644 --- a/scripts/SANS/sans/gui_logic/presenter/settings_adjustment_presenter.py +++ b/scripts/SANS/sans/gui_logic/presenter/settings_adjustment_presenter.py @@ -25,9 +25,9 @@ class SettingsAdjustmentPresenter(PresenterCommon): def default_gui_setup(self): # Set the fit options - fit_types = [FitType.to_string(FitType.Linear), - FitType.to_string(FitType.Logarithmic), - FitType.to_string(FitType.Polynomial)] + fit_types = [FitType.LINEAR.value, + FitType.LOGARITHMIC.value, + FitType.POLYNOMIAL.value] self._view.transmission_sample_fit_type = fit_types self._view.transmission_can_fit_type = fit_types @@ -91,8 +91,8 @@ class SettingsAdjustmentPresenter(PresenterCommon): use_fit = self._view.transmission_sample_use_fit fit_type = 
self._view.transmission_sample_fit_type polynomial_order = self._view.transmission_sample_polynomial_order - state_model.transmission_sample_fit_type = fit_type if use_fit else FitType.NoFit - state_model.transmission_can_fit_type = fit_type if use_fit else FitType.NoFit + state_model.transmission_sample_fit_type = fit_type if use_fit else FitType.NO_FIT + state_model.transmission_can_fit_type = fit_type if use_fit else FitType.NO_FIT state_model.transmission_sample_polynomial_order = polynomial_order state_model.transmission_can_polynomial_order = polynomial_order @@ -109,7 +109,7 @@ class SettingsAdjustmentPresenter(PresenterCommon): use_fit_sample = self._view.transmission_sample_use_fit fit_type_sample = self._view.transmission_sample_fit_type polynomial_order_sample = self._view.transmission_sample_polynomial_order - state_model.transmission_sample_fit_type = fit_type_sample if use_fit_sample else FitType.NoFit + state_model.transmission_sample_fit_type = fit_type_sample if use_fit_sample else FitType.NO_FIT state_model.transmission_sample_polynomial_order = polynomial_order_sample # Wavelength settings @@ -123,7 +123,7 @@ class SettingsAdjustmentPresenter(PresenterCommon): use_fit_can = self._view.transmission_can_use_fit fit_type_can = self._view.transmission_can_fit_type polynomial_order_can = self._view.transmission_can_polynomial_order - state_model.transmission_can_fit_type = fit_type_can if use_fit_can else FitType.NoFit + state_model.transmission_can_fit_type = fit_type_can if use_fit_can else FitType.NO_FIT state_model.transmission_can_polynomial_order = polynomial_order_can # Wavelength settings @@ -136,15 +136,15 @@ class SettingsAdjustmentPresenter(PresenterCommon): def _set_on_view_transmission_fit_sample_settings(self): # Set transmission_sample_use_fit fit_type = self._model.transmission_sample_fit_type - use_fit = fit_type is not FitType.NoFit + use_fit = fit_type is not FitType.NO_FIT self._view.transmission_sample_use_fit = use_fit # Set the 
polynomial order for sample - polynomial_order = self._model.transmission_sample_polynomial_order if fit_type is FitType.Polynomial else 2 # noqa + polynomial_order = self._model.transmission_sample_polynomial_order if fit_type is FitType.POLYNOMIAL else 2 # noqa self._view.transmission_sample_polynomial_order = polynomial_order # Set the fit type for the sample - fit_type = fit_type if fit_type is not FitType.NoFit else FitType.Linear + fit_type = fit_type if fit_type is not FitType.NO_FIT else FitType.LINEAR self._view.transmission_sample_fit_type = fit_type # Set the wavelength @@ -167,15 +167,15 @@ class SettingsAdjustmentPresenter(PresenterCommon): # Set transmission_sample_can_fit fit_type_can = self._model.transmission_can_fit_type() - use_can_fit = fit_type_can is FitType.NoFit + use_can_fit = fit_type_can is FitType.NO_FIT self._view.transmission_can_use_fit = use_can_fit # Set the polynomial order for can - polynomial_order_can = self._model.transmission_can_polynomial_order if fit_type_can is FitType.Polynomial else 2 # noqa + polynomial_order_can = self._model.transmission_can_polynomial_order if fit_type_can is FitType.POLYNOMIAL else 2 # noqa self._view.transmission_can_polynomial_order = polynomial_order_can # Set the fit type for the can - fit_type_can = fit_type_can if fit_type_can is not FitType.NoFit else FitType.Linear + fit_type_can = fit_type_can if fit_type_can is not FitType.NO_FIT else FitType.LINEAR self.transmission_can_fit_type = fit_type_can # Set the wavelength diff --git a/scripts/SANS/sans/gui_logic/presenter/summation_settings_presenter.py b/scripts/SANS/sans/gui_logic/presenter/summation_settings_presenter.py index 9b6e3bb634d3e27384132dd19facaa4525a6cbcb..d3afe64ea31ba0407d4e68a51168070ebca3298a 100644 --- a/scripts/SANS/sans/gui_logic/presenter/summation_settings_presenter.py +++ b/scripts/SANS/sans/gui_logic/presenter/summation_settings_presenter.py @@ -36,11 +36,11 @@ class SummationSettingsPresenter(object): @staticmethod def 
_binning_type_index_to_type(index): if index == 0: - return BinningType.Custom + return BinningType.CUSTOM elif index == 1: - return BinningType.FromMonitors + return BinningType.FROM_MONITORS elif index == 2: - return BinningType.SaveAsEventData + return BinningType.SAVE_AS_EVENT_DATA def _switch_binning_type(self, type_of_binning): self._summation_settings.set_histogram_binning_type(type_of_binning) diff --git a/scripts/SANS/sans/gui_logic/sans_data_processor_gui_algorithm.py b/scripts/SANS/sans/gui_logic/sans_data_processor_gui_algorithm.py index 836a4b55769a76444ca4757980495aaf79eff9b1..ca2bd2ccfb63aa2e9456c7b3713d59e54774b260 100644 --- a/scripts/SANS/sans/gui_logic/sans_data_processor_gui_algorithm.py +++ b/scripts/SANS/sans/gui_logic/sans_data_processor_gui_algorithm.py @@ -195,7 +195,7 @@ def create_properties(show_periods=True): algorithm_property="OutputMode", description='The output mode.', show_value=False, - default=OutputMode.to_string(OutputMode.PublishToADS), + default=OutputMode.PUBLISH_TO_ADS.value, prefix='', property_type=bool), algorithm_list_entry(column_name="", @@ -296,7 +296,7 @@ def create_properties(show_periods=True): algorithm_property="OutputMode", description='The output mode.', show_value=False, - default=OutputMode.to_string(OutputMode.PublishToADS), + default=OutputMode.PUBLISH_TO_ADS.value, prefix='', property_type=bool), algorithm_list_entry(column_name="", @@ -381,7 +381,7 @@ class SANSGuiDataProcessorAlgorithm(DataProcessorAlgorithm): # 3. 
Get some global settings use_optimizations = self.getProperty("UseOptimizations").value output_mode_as_string = self.getProperty("OutputMode").value - output_mode = OutputMode.from_string(output_mode_as_string) + output_mode = OutputMode(output_mode_as_string) plot_results = self.getProperty('PlotResults').value output_graph = self.getProperty('OutputGraph').value diff --git a/scripts/SANS/sans/sans_batch.py b/scripts/SANS/sans/sans_batch.py index 77820a4d6711296e2cf3c4ed3b853dff3a40f34d..e8cee1b8e5d83cd7c933e6fa1745b264c7505cad 100644 --- a/scripts/SANS/sans/sans_batch.py +++ b/scripts/SANS/sans/sans_batch.py @@ -17,7 +17,7 @@ class SANSBatchReduction(object): def __init__(self): super(SANSBatchReduction, self).__init__() - def __call__(self, states, use_optimizations=True, output_mode=OutputMode.PublishToADS, plot_results = False, + def __call__(self, states, use_optimizations=True, output_mode=OutputMode.PUBLISH_TO_ADS, plot_results = False, output_graph='', save_can=False): """ This is the start of any reduction. @@ -71,8 +71,8 @@ class SANSBatchReduction(object): raise RuntimeError("The output_graph must be set if plot_results is true. The provided value is" " {0}".format(output_graph)) - if output_mode is not OutputMode.PublishToADS and output_mode is not OutputMode.SaveToFile and\ - output_mode is not OutputMode.Both: # noqa + if output_mode is not OutputMode.PUBLISH_TO_ADS and output_mode is not OutputMode.SAVE_TO_FILE and\ + output_mode is not OutputMode.BOTH: # noqa raise RuntimeError("The output mode has to be an enum of type OutputMode. 
The provided type is" " {0}".format(type(output_mode))) @@ -97,7 +97,7 @@ class SANSCentreFinder(object): super(SANSCentreFinder, self).__init__() def __call__(self, state, r_min = 60, r_max = 280, max_iter = 20, x_start = 0.0, y_start = 0.0, - tolerance = 1.251e-4, find_direction = FindDirectionEnum.All, reduction_method = True, verbose=False, + tolerance = 1.251e-4, find_direction = FindDirectionEnum.ALL, reduction_method = True, verbose=False, component=DetectorType.LAB): """ This is the start of the beam centre finder algorithm. diff --git a/scripts/SANS/sans/state/calculate_transmission.py b/scripts/SANS/sans/state/calculate_transmission.py index 0cbf4aa262e31418d9b8daf2c548d4e1bea74711..eb7911f07aeca46263d93ee05317085ea45ce87e 100644 --- a/scripts/SANS/sans/state/calculate_transmission.py +++ b/scripts/SANS/sans/state/calculate_transmission.py @@ -11,8 +11,11 @@ from __future__ import (absolute_import, division, print_function) import json import copy +import abc +import six + from sans.state.state_base import (StateBase, rename_descriptor_names, PositiveIntegerParameter, BoolParameter, - PositiveFloatParameter, ClassTypeParameter, FloatParameter, DictParameter, + PositiveFloatParameter, FloatParameter, DictParameter, StringListParameter, PositiveFloatWithNoneParameter, PositiveFloatListParameter) from sans.common.enums import (RebinType, RangeStepType, FitType, DataType, SANSInstrument) from sans.common.configurations import Configurations @@ -27,19 +30,19 @@ from sans.common.xml_parsing import get_named_elements_from_ipf_file # ---------------------------------------------------------------------------------------------------------------------- @rename_descriptor_names class StateTransmissionFit(StateBase): - fit_type = ClassTypeParameter(FitType) + fit_type = FitType.LOGARITHMIC polynomial_order = PositiveIntegerParameter() wavelength_low = PositiveFloatWithNoneParameter() wavelength_high = PositiveFloatWithNoneParameter() def __init__(self): 
super(StateTransmissionFit, self).__init__() - self.fit_type = FitType.Logarithmic + self.fit_type = FitType.LOGARITHMIC self.polynomial_order = 0 - def validate(self): # noqa + def validate(self): is_invalid = {} - if self.fit_type is FitType.Polynomial and self.polynomial_order == 0: + if self.fit_type is FitType.POLYNOMIAL and self.polynomial_order == 0: entry = validation_message("You can only select a polynomial fit if you set a polynomial order (2 to 6).", "Make sure that you select a polynomial order.", {"fit_type": self.fit_type, @@ -89,11 +92,11 @@ class StateCalculateTransmission(StateBase): # ---------------- # Wavelength rebin # ---------------- - rebin_type = ClassTypeParameter(RebinType) wavelength_low = PositiveFloatListParameter() wavelength_high = PositiveFloatListParameter() wavelength_step = PositiveFloatParameter() - wavelength_step_type = ClassTypeParameter(RangeStepType) + rebin_type = RebinType.REBIN + wavelength_step_type = RangeStepType.NOT_SET use_full_wavelength_range = BoolParameter() wavelength_full_range_low = PositiveFloatParameter() @@ -109,23 +112,18 @@ class StateCalculateTransmission(StateBase): background_TOF_roi_start = FloatParameter() background_TOF_roi_stop = FloatParameter() - # ----------------------- - # Fit - # ---------------------- - fit = DictParameter() + fit = {DataType.CAN : StateTransmissionFit(), + DataType.SAMPLE : StateTransmissionFit()} def __init__(self): super(StateCalculateTransmission, self).__init__() # The keys of this dictionaries are the spectrum number of the monitors (as a string) self.background_TOF_monitor_start = {} self.background_TOF_monitor_stop = {} - - self.fit = {DataType.to_string(DataType.Sample): StateTransmissionFit(), - DataType.to_string(DataType.Can): StateTransmissionFit()} self.use_full_wavelength_range = False # Default rebin type is a standard Rebin - self.rebin_type = RebinType.Rebin + self.rebin_type = RebinType.REBIN self.prompt_peak_correction_enabled = False @@ -177,8 +175,7 
@@ class StateCalculateTransmission(StateBase): # ----------------- # Wavelength rebin # ----------------- - if one_is_none([self.wavelength_low, self.wavelength_high, self.wavelength_step, self.wavelength_step_type, - self.wavelength_step_type, self.rebin_type]): + if one_is_none([self.wavelength_low, self.wavelength_high, self.wavelength_step, self.rebin_type]): entry = validation_message("A wavelength entry has not been set.", "Make sure that all entries are set.", {"wavelength_low": self.wavelength_low, @@ -188,6 +185,12 @@ class StateCalculateTransmission(StateBase): "rebin_type": self.rebin_type}) is_invalid.update(entry) + if self.wavelength_step_type is RangeStepType.NOT_SET: + entry = validation_message("A wavelength entry has not been set.", + "Make sure that all entries are set.", + {"wavelength_step_type" : self.wavelength_step_type}) + is_invalid.update(entry) + if is_not_none_and_first_larger_than_second([self.wavelength_low, self.wavelength_high]): entry = validation_message("Incorrect wavelength bounds.", "Make sure that lower wavelength bound is smaller then upper bound.", @@ -273,11 +276,8 @@ class StateCalculateTransmission(StateBase): "background_TOF_monitor_stop": self.background_TOF_monitor_stop}) is_invalid.update(entry) - # ----- - # Fit - # ----- - self.fit[DataType.to_string(DataType.Sample)].validate() - self.fit[DataType.to_string(DataType.Can)].validate() + for fit_type in six.itervalues(self.fit): + fit_type.validate() if is_invalid: raise ValueError("StateCalculateTransmission: The provided inputs are illegal. 
" @@ -299,6 +299,9 @@ class StateCalculateTransmissionLOQ(StateCalculateTransmission): def validate(self): super(StateCalculateTransmissionLOQ, self).validate() + def set_rebin_type(self, val): + self.state.rebin_type = val + class StateCalculateTransmissionSANS2D(StateCalculateTransmission): def __init__(self): @@ -356,12 +359,47 @@ def set_default_monitors(calculate_transmission_info, data_info): # --------------------------------------- # State builders # --------------------------------------- -class StateCalculateTransmissionBuilderLOQ(object): +@six.add_metaclass(abc.ABCMeta) +class StateCalculateTransmissionBuilderCommon(object): + def __init__(self, state): + self.state = state + + def set_wavelength_step_type(self, val): + self.state.wavelength_step_type = val + + def set_rebin_type(self, val): + self.state.rebin_type = val + + def set_can_fit_type(self, val): + self.state.fit[DataType.CAN].fit_type = val + + def set_can_polynomial_order(self, val): + self.state.fit[DataType.CAN].polynomial_order = val + + def set_can_wavelength_low(self, val): + self.state.fit[DataType.CAN].wavelength_low = val + + def set_can_wavelength_high(self, val): + self.state.fit[DataType.CAN].wavelength_high = val + + def set_sample_fit_type(self, val): + self.state.fit[DataType.SAMPLE].fit_type = val + + def set_sample_polynomial_order(self, val): + self.state.fit[DataType.SAMPLE].polynomial_order = val + + def set_sample_wavelength_low(self, val): + self.state.fit[DataType.SAMPLE].wavelength_low = val + + def set_sample_wavelength_high(self, val): + self.state.fit[DataType.SAMPLE].wavelength_high = val + + +class StateCalculateTransmissionBuilderLOQ(StateCalculateTransmissionBuilderCommon): @automatic_setters(StateCalculateTransmissionLOQ) def __init__(self, data_info): - super(StateCalculateTransmissionBuilderLOQ, self).__init__() + super(StateCalculateTransmissionBuilderLOQ, self).__init__(state=StateCalculateTransmissionLOQ()) self._data = data_info - self.state = 
StateCalculateTransmissionLOQ() set_default_monitors(self.state, self._data) def build(self): @@ -369,12 +407,11 @@ class StateCalculateTransmissionBuilderLOQ(object): return copy.copy(self.state) -class StateCalculateTransmissionBuilderSANS2D(object): +class StateCalculateTransmissionBuilderSANS2D(StateCalculateTransmissionBuilderCommon): @automatic_setters(StateCalculateTransmissionSANS2D) def __init__(self, data_info): - super(StateCalculateTransmissionBuilderSANS2D, self).__init__() + super(StateCalculateTransmissionBuilderSANS2D, self).__init__(state=StateCalculateTransmissionSANS2D()) self._data = data_info - self.state = StateCalculateTransmissionSANS2D() set_default_monitors(self.state, self._data) def build(self): @@ -382,12 +419,11 @@ class StateCalculateTransmissionBuilderSANS2D(object): return copy.copy(self.state) -class StateCalculateTransmissionBuilderLARMOR(object): +class StateCalculateTransmissionBuilderLARMOR(StateCalculateTransmissionBuilderCommon): @automatic_setters(StateCalculateTransmissionLARMOR) def __init__(self, data_info): - super(StateCalculateTransmissionBuilderLARMOR, self).__init__() + super(StateCalculateTransmissionBuilderLARMOR, self).__init__(state=StateCalculateTransmissionLARMOR()) self._data = data_info - self.state = StateCalculateTransmissionLARMOR() set_default_monitors(self.state, self._data) def build(self): @@ -395,12 +431,11 @@ class StateCalculateTransmissionBuilderLARMOR(object): return copy.copy(self.state) -class StateCalculateTransmissionBuilderZOOM(object): +class StateCalculateTransmissionBuilderZOOM(StateCalculateTransmissionBuilderCommon): @automatic_setters(StateCalculateTransmissionZOOM) def __init__(self, data_info): - super(StateCalculateTransmissionBuilderZOOM, self).__init__() + super(StateCalculateTransmissionBuilderZOOM, self).__init__(state=StateCalculateTransmissionZOOM()) self._data = data_info - self.state = StateCalculateTransmissionZOOM() set_default_monitors(self.state, self._data) def 
build(self): @@ -413,6 +448,7 @@ class StateCalculateTransmissionBuilderZOOM(object): # ------------------------------------------ def get_calculate_transmission_builder(data_info): instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR: return StateCalculateTransmissionBuilderLARMOR(data_info) elif instrument is SANSInstrument.SANS2D: diff --git a/scripts/SANS/sans/state/convert_to_q.py b/scripts/SANS/sans/state/convert_to_q.py index bf3a5bb0eb2557edbf37946c8785838695b8161b..75cdf108f85db6074cd62882a1836c55837d2726 100644 --- a/scripts/SANS/sans/state/convert_to_q.py +++ b/scripts/SANS/sans/state/convert_to_q.py @@ -12,7 +12,7 @@ from __future__ import (absolute_import, division, print_function) import json import copy from sans.state.state_base import (StateBase, rename_descriptor_names, BoolParameter, PositiveFloatParameter, - ClassTypeParameter, StringParameter) + StringParameter) from sans.common.enums import (ReductionDimensionality, RangeStepType, SANSFacility) from sans.state.state_functions import (is_pure_none_or_not_none, is_not_none_and_first_larger_than_second, validation_message) @@ -24,7 +24,7 @@ from sans.state.automatic_setters import (automatic_setters) # ---------------------------------------------------------------------------------------------------------------------- @rename_descriptor_names class StateConvertToQ(StateBase): - reduction_dimensionality = ClassTypeParameter(ReductionDimensionality) + reduction_dimensionality = ReductionDimensionality.ONE_DIM use_gravity = BoolParameter() gravity_extra_length = PositiveFloatParameter() radius_cutoff = PositiveFloatParameter() @@ -38,7 +38,7 @@ class StateConvertToQ(StateBase): # 2D settings q_xy_max = PositiveFloatParameter() q_xy_step = PositiveFloatParameter() - q_xy_step_type = ClassTypeParameter(RangeStepType) + q_xy_step_type = RangeStepType.LIN # ----------------------- # Q Resolution specific @@ -60,7 +60,7 @@ class StateConvertToQ(StateBase): def __init__(self): 
super(StateConvertToQ, self).__init__() - self.reduction_dimensionality = ReductionDimensionality.OneDim + self.reduction_dimensionality = ReductionDimensionality.ONE_DIM self.use_gravity = False self.gravity_extra_length = 0.0 self.use_q_resolution = False @@ -84,7 +84,7 @@ class StateConvertToQ(StateBase): "q_max": self.q_max}) is_invalid.update(entry) - if self.reduction_dimensionality is ReductionDimensionality.OneDim: + if self.reduction_dimensionality is ReductionDimensionality.ONE_DIM: if self.q_min is None or self.q_max is None: entry = validation_message("Q bounds not set for 1D reduction.", "Make sure to set the q boundaries when using a 1D reduction.", @@ -105,7 +105,7 @@ class StateConvertToQ(StateBase): is_invalid.update(entry) # QXY settings - if self.reduction_dimensionality is ReductionDimensionality.TwoDim: + if self.reduction_dimensionality is ReductionDimensionality.TWO_DIM: if self.q_xy_max is None or self.q_xy_step is None: entry = validation_message("Q bounds not set for 2D reduction.", "Make sure that the q_max value bound and the step for the 2D reduction.", @@ -168,6 +168,15 @@ class StateConvertToQBuilder(object): self.state.validate() return copy.copy(self.state) + def set_reduction_dimensionality(self, val): + self.state.reduction_dimensionality = val + + def set_wavelength_step_type(self, val): + self.state.wavelength_step_type = val + + def set_q_xy_step_type(self, val): + self.state.q_xy_step_type = val + # ------------------------------------------ # Factory method for StateConvertToQBuilder diff --git a/scripts/SANS/sans/state/data.py b/scripts/SANS/sans/state/data.py index 1930b314b42e2a03b010f1debceb7b8995822728..4a62bab729d14a6a77e95c1f2c54c148aebab029 100644 --- a/scripts/SANS/sans/state/data.py +++ b/scripts/SANS/sans/state/data.py @@ -12,8 +12,8 @@ import json import copy from sans.state.state_base import (StateBase, StringParameter, PositiveIntegerParameter, BoolParameter, - ClassTypeParameter, rename_descriptor_names) -from 
sans.common.enums import (SANSInstrument, SANSFacility) + rename_descriptor_names) +from sans.common.enums import SANSFacility, SANSInstrument import sans.common.constants from sans.state.state_functions import (is_pure_none_or_not_none, validation_message) from sans.state.automatic_setters import automatic_setters @@ -43,13 +43,13 @@ class StateData(StateBase): sample_scatter_run_number = PositiveIntegerParameter() sample_scatter_is_multi_period = BoolParameter() - instrument = ClassTypeParameter(SANSInstrument) - facility = ClassTypeParameter(SANSFacility) idf_file_path = StringParameter() ipf_file_path = StringParameter() - user_file = StringParameter() + instrument = SANSInstrument.NO_INSTRUMENT + facility = SANSFacility.NO_FACILITY + def __init__(self): super(StateData, self).__init__() @@ -64,8 +64,8 @@ class StateData(StateBase): # This should be reset by the builder. Setting this to NoInstrument ensure that we will trip early on, # in case this is not set, for example by not using the builders. 
- self.instrument = SANSInstrument.NoInstrument - self.facility = SANSFacility.NoFacility + self.instrument = SANSInstrument.NO_INSTRUMENT + self.facility = SANSFacility.NO_FACILITY self.user_file = "" def validate(self): diff --git a/scripts/SANS/sans/state/mask.py b/scripts/SANS/sans/state/mask.py index cace137608c5535147b3cf222c381faaf82dc8a5..3711f52926f11f5ae43e84ed4ecb1e91f4d94c63 100644 --- a/scripts/SANS/sans/state/mask.py +++ b/scripts/SANS/sans/state/mask.py @@ -17,7 +17,7 @@ from sans.state.state_base import (StateBase, BoolParameter, StringListParameter from sans.state.state_functions import (is_pure_none_or_not_none, validation_message, set_detector_names) from sans.state.automatic_setters import (automatic_setters) from sans.common.file_information import find_full_file_path -from sans.common.enums import (SANSInstrument, DetectorType) +from sans.common.enums import (DetectorType, SANSInstrument) from sans.common.general_functions import get_bank_for_spectrum_number @@ -262,8 +262,8 @@ class StateMaskSANS2D(StateMask): def __init__(self): super(StateMaskSANS2D, self).__init__() # Setup the detectors - self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMaskDetector(), - DetectorType.to_string(DetectorType.HAB): StateMaskDetector()} + self.detectors = {DetectorType.LAB.value: StateMaskDetector(), + DetectorType.HAB.value: StateMaskDetector()} def validate(self): super(StateMaskSANS2D, self).validate() @@ -274,8 +274,8 @@ class StateMaskLOQ(StateMask): def __init__(self): super(StateMaskLOQ, self).__init__() # Setup the detectors - self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMaskDetector(), - DetectorType.to_string(DetectorType.HAB): StateMaskDetector()} + self.detectors = {DetectorType.LAB.value: StateMaskDetector(), + DetectorType.HAB.value: StateMaskDetector()} def validate(self): super(StateMaskLOQ, self).validate() @@ -286,7 +286,7 @@ class StateMaskLARMOR(StateMask): def __init__(self): super(StateMaskLARMOR, 
self).__init__() # Setup the detectors - self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMaskDetector()} + self.detectors = {DetectorType.LAB.value: StateMaskDetector()} def validate(self): super(StateMaskLARMOR, self).validate() @@ -297,7 +297,7 @@ class StateMaskZOOM(StateMask): def __init__(self): super(StateMaskZOOM, self).__init__() # Setup the detectors - self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMaskDetector()} + self.detectors = {DetectorType.LAB.value: StateMaskDetector()} def validate(self): super(StateMaskZOOM, self).validate() @@ -334,12 +334,12 @@ class StateMaskBuilder(object): instrument = self._data.instrument for spectrum in single_spectra: detector = get_bank_for_spectrum_number(spectrum, instrument) - detector_mask_state = self.state.detectors[DetectorType.to_string(detector)] + detector_mask_state = self.state.detectors[detector.value] spectra = detector_mask_state.single_spectra if spectra is not None: spectra.append(spectrum) else: - self.state.detectors[DetectorType.to_string(detector)].single_spectra = [spectrum] + self.state.detectors[detector.value].single_spectra = [spectrum] def set_spectrum_range_on_detector(self, spectrum_range_start, spectrum_range_stop): """ @@ -357,20 +357,20 @@ class StateMaskBuilder(object): raise ValueError("The specified spectrum mask range S{0}{1} has spectra on more than one detector. 
" "Make sure that all spectra in the range are on a single detector".format(start, stop)) else: - detector_mask_state = self.state.detectors[DetectorType.to_string(detector_bank_start)] + detector_mask_state = self.state.detectors[detector_bank_start.value] spec_range_start = detector_mask_state.spectrum_range_start spec_range_stop = detector_mask_state.spectrum_range_stop # Set the start spectrum range if spec_range_start is not None: spec_range_start.append(start) else: - self.state.detectors[DetectorType.to_string(detector_bank_start)].spectrum_range_start = [start] + self.state.detectors[detector_bank_start.value].spectrum_range_start = [start] # Set the stop spectrum range if spec_range_stop is not None: spec_range_stop.append(stop) else: - self.state.detectors[DetectorType.to_string(detector_bank_start)].spectrum_range_stop = [stop] + self.state.detectors[detector_bank_start.value].spectrum_range_stop = [stop] def build(self): self.state.validate() @@ -381,6 +381,7 @@ def get_mask_builder(data_info): # The data state has most of the information that we require to define the mask. For the factory method, only # the facility/instrument is of relevance. 
instrument = data_info.instrument + if instrument is SANSInstrument.SANS2D: return StateMaskBuilder(data_info, StateMaskSANS2D()) elif instrument is SANSInstrument.LOQ: diff --git a/scripts/SANS/sans/state/move.py b/scripts/SANS/sans/state/move.py index fa752251a819f08f432ee6b32da4f38661f138d8..322517940278aca711bfc65a142e42c00bf7fe5d 100644 --- a/scripts/SANS/sans/state/move.py +++ b/scripts/SANS/sans/state/move.py @@ -13,9 +13,9 @@ from __future__ import (absolute_import, division, print_function) import copy import json -from sans.common.enums import (Coordinates, CanonicalCoordinates, SANSInstrument, DetectorType) +from sans.common.enums import (CanonicalCoordinates, SANSInstrument, DetectorType) from sans.state.automatic_setters import automatic_setters -from sans.state.state_base import (StateBase, FloatParameter, DictParameter, ClassTypeParameter, +from sans.state.state_base import (StateBase, FloatParameter, DictParameter, StringWithNoneParameter, rename_descriptor_names) from sans.state.state_functions import (validation_message, set_detector_names, set_monitor_names) @@ -83,10 +83,11 @@ class StateMoveDetector(StateBase): @rename_descriptor_names class StateMove(StateBase): sample_offset = FloatParameter() - sample_offset_direction = ClassTypeParameter(Coordinates) detectors = DictParameter() monitor_names = DictParameter() + sample_offset_direction = CanonicalCoordinates.Z + def __init__(self): super(StateMove, self).__init__() @@ -119,8 +120,8 @@ class StateMoveLOQ(StateMove): self.monitor_names = {} # Setup the detectors - self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMoveDetector(), - DetectorType.to_string(DetectorType.HAB): StateMoveDetector()} + self.detectors = {DetectorType.LAB.value: StateMoveDetector(), + DetectorType.HAB.value: StateMoveDetector()} def validate(self): # No validation of the descriptors on this level, let potential exceptions from detectors "bubble" up @@ -165,8 +166,8 @@ class StateMoveSANS2D(StateMove): 
self.monitor_4_offset = 0.0 # Setup the detectors - self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMoveDetector(), - DetectorType.to_string(DetectorType.HAB): StateMoveDetector()} + self.detectors = {DetectorType.LAB.value: StateMoveDetector(), + DetectorType.HAB.value: StateMoveDetector()} def validate(self): super(StateMoveSANS2D, self).validate() @@ -186,7 +187,7 @@ class StateMoveLARMOR(StateMove): self.monitor_names = {} # Setup the detectors - self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMoveDetector()} + self.detectors = {DetectorType.LAB.value: StateMoveDetector()} def validate(self): super(StateMoveLARMOR, self).validate() @@ -210,7 +211,7 @@ class StateMoveZOOM(StateMove): self.monitor_5_offset = 0.0 # Setup the detectors - self.detectors = {DetectorType.to_string(DetectorType.LAB): StateMoveDetector()} + self.detectors = {DetectorType.LAB.value: StateMoveDetector()} def validate(self): super(StateMoveZOOM, self).validate() @@ -329,6 +330,7 @@ def get_move_builder(data_info): # The data state has most of the information that we require to define the move. For the factory method, only # the instrument is of relevance. 
instrument = data_info.instrument + if instrument is SANSInstrument.LOQ: return StateMoveLOQBuilder(data_info) elif instrument is SANSInstrument.SANS2D: diff --git a/scripts/SANS/sans/state/normalize_to_monitor.py b/scripts/SANS/sans/state/normalize_to_monitor.py index b9348b0a0a2f99a6651e4ea475b932ff1d5be94b..9e0a68c1dac5b026b166654a6b07dab778b351fd 100644 --- a/scripts/SANS/sans/state/normalize_to_monitor.py +++ b/scripts/SANS/sans/state/normalize_to_monitor.py @@ -12,7 +12,7 @@ from __future__ import (absolute_import, division, print_function) import json import copy from sans.state.state_base import (StateBase, rename_descriptor_names, PositiveIntegerParameter, - PositiveFloatParameter, FloatParameter, ClassTypeParameter, DictParameter, + PositiveFloatParameter, FloatParameter, DictParameter, PositiveFloatWithNoneParameter, BoolParameter, PositiveFloatListParameter) from sans.state.automatic_setters import (automatic_setters) from sans.common.enums import (RebinType, RangeStepType, SANSInstrument) @@ -30,11 +30,11 @@ class StateNormalizeToMonitor(StateBase): prompt_peak_correction_max = PositiveFloatWithNoneParameter() prompt_peak_correction_enabled = BoolParameter() - rebin_type = ClassTypeParameter(RebinType) + rebin_type = RebinType.REBIN wavelength_low = PositiveFloatListParameter() wavelength_high = PositiveFloatListParameter() wavelength_step = PositiveFloatParameter() - wavelength_step_type = ClassTypeParameter(RangeStepType) + wavelength_step_type = RangeStepType.NOT_SET background_TOF_general_start = FloatParameter() background_TOF_general_stop = FloatParameter() @@ -50,7 +50,7 @@ class StateNormalizeToMonitor(StateBase): self.prompt_peak_correction_enabled = False # Default rebin type is a standard Rebin - self.rebin_type = RebinType.Rebin + self.rebin_type = RebinType.REBIN def validate(self): is_invalid = {} @@ -192,6 +192,12 @@ class StateNormalizeToMonitorBuilder(object): self.state.validate() return copy.copy(self.state) + def 
set_wavelength_step_type(self, val): + self.state.wavelength_step_type = val + + def set_rebin_type(self, val): + self.state.rebin_type = val + class StateNormalizeToMonitorBuilderLOQ(object): @automatic_setters(StateNormalizeToMonitorLOQ) @@ -205,9 +211,16 @@ class StateNormalizeToMonitorBuilderLOQ(object): self.state.validate() return copy.copy(self.state) + def set_wavelength_step_type(self, val): + self.state.wavelength_step_type = val + + def set_rebin_type(self, val): + self.state.rebin_type = val + def get_normalize_to_monitor_builder(data_info): instrument = data_info.instrument + if instrument is SANSInstrument.LARMOR or instrument is SANSInstrument.SANS2D or instrument is SANSInstrument.ZOOM: return StateNormalizeToMonitorBuilder(data_info) elif instrument is SANSInstrument.LOQ: diff --git a/scripts/SANS/sans/state/reduction_mode.py b/scripts/SANS/sans/state/reduction_mode.py index d2f09ca722c9a68e0547cc3005c93e5d12f2970d..6a98e407841cde44f688a95e7cbcb81ff4c9c2c2 100644 --- a/scripts/SANS/sans/state/reduction_mode.py +++ b/scripts/SANS/sans/state/reduction_mode.py @@ -4,21 +4,23 @@ # NScD Oak Ridge National Laboratory, European Spallation Source # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + -# pylint: disable=too-few-public-methods """ Defines the state of the reduction.""" from __future__ import (absolute_import, division, print_function) -from abc import (ABCMeta, abstractmethod) -from six import (with_metaclass) + import copy import json -from sans.state.state_base import (StateBase, ClassTypeParameter, FloatParameter, DictParameter, - FloatWithNoneParameter, rename_descriptor_names, BoolParameter) -from sans.common.enums import (ReductionMode, ISISReductionMode, ReductionDimensionality, FitModeForMerge, +from abc import (ABCMeta, abstractmethod) + +from six import (with_metaclass) + +from sans.common.enums import (ReductionMode, ReductionDimensionality, FitModeForMerge, SANSFacility, DetectorType) from sans.common.xml_parsing 
import get_named_elements_from_ipf_file from sans.state.automatic_setters import (automatic_setters) +from sans.state.state_base import (StateBase, FloatParameter, DictParameter, + FloatWithNoneParameter, rename_descriptor_names, BoolParameter) # ---------------------------------------------------------------------------------------------------------------------- @@ -40,14 +42,15 @@ class StateReductionBase(with_metaclass(ABCMeta, object)): @rename_descriptor_names class StateReductionMode(StateReductionBase, StateBase): - reduction_mode = ClassTypeParameter(ReductionMode) - reduction_dimensionality = ClassTypeParameter(ReductionDimensionality) + reduction_mode = ReductionMode.NOT_SET + + reduction_dimensionality = ReductionDimensionality.ONE_DIM merge_max = FloatWithNoneParameter() merge_min = FloatWithNoneParameter() merge_mask = BoolParameter() # Fitting - merge_fit_mode = ClassTypeParameter(FitModeForMerge) + merge_fit_mode = FitModeForMerge.NO_FIT merge_shift = FloatParameter() merge_scale = FloatParameter() merge_range_min = FloatWithNoneParameter() @@ -58,13 +61,13 @@ class StateReductionMode(StateReductionBase, StateBase): def __init__(self): super(StateReductionMode, self).__init__() - self.reduction_mode = ISISReductionMode.LAB - self.reduction_dimensionality = ReductionDimensionality.OneDim + self.reduction_mode = ReductionMode.LAB + self.reduction_dimensionality = ReductionDimensionality.ONE_DIM # Set the shifts to defaults which essentially don't do anything. 
self.merge_shift = 0.0 self.merge_scale = 1.0 - self.merge_fit_mode = FitModeForMerge.NoFit + self.merge_fit_mode = FitModeForMerge.NO_FIT self.merge_range_min = None self.merge_range_max = None self.merge_max = None @@ -72,20 +75,20 @@ class StateReductionMode(StateReductionBase, StateBase): self.merge_mask = False # Set the detector names to empty strings - self.detector_names = {DetectorType.to_string(DetectorType.LAB): "", - DetectorType.to_string(DetectorType.HAB): ""} + self.detector_names = {DetectorType.LAB.value: "", + DetectorType.HAB.value: ""} def get_merge_strategy(self): - return [ISISReductionMode.LAB, ISISReductionMode.HAB] + return [ReductionMode.LAB, ReductionMode.HAB] def get_all_reduction_modes(self): - return [ISISReductionMode.LAB, ISISReductionMode.HAB] + return [ReductionMode.LAB, ReductionMode.HAB] def get_detector_name_for_reduction_mode(self, reduction_mode): - if reduction_mode is ISISReductionMode.LAB: - bank_type = DetectorType.to_string(DetectorType.LAB) - elif reduction_mode is ISISReductionMode.HAB: - bank_type = DetectorType.to_string(DetectorType.HAB) + if reduction_mode is ReductionMode.LAB: + bank_type = DetectorType.LAB.value + elif reduction_mode is ReductionMode.HAB: + bank_type = DetectorType.HAB.value else: raise RuntimeError("SANStateReductionISIS: There is no detector available for the" " reduction mode {0}.".format(reduction_mode)) @@ -108,8 +111,8 @@ class StateReductionMode(StateReductionBase, StateBase): def setup_detectors_from_ipf(reduction_info, data_info): ipf_file_path = data_info.ipf_file_path - detector_names = {DetectorType.to_string(DetectorType.LAB): "low-angle-detector-name", - DetectorType.to_string(DetectorType.HAB): "high-angle-detector-name"} + detector_names = {DetectorType.LAB.value: "low-angle-detector-name", + DetectorType.HAB.value: "high-angle-detector-name"} names_to_search = [] names_to_search.extend(list(detector_names.values())) @@ -128,10 +131,19 @@ def 
setup_detectors_from_ipf(reduction_info, data_info): class StateReductionModeBuilder(object): @automatic_setters(StateReductionMode, exclusions=["detector_names"]) def __init__(self, data_info): - super(StateReductionModeBuilder, self).__init__() self.state = StateReductionMode() setup_detectors_from_ipf(self.state, data_info) + # TODO this whole class is a shim around state, so we should remove it at a later date + def set_reduction_mode(self, val): + self.state.reduction_mode = val + + def set_reduction_dimensionality(self, val): + self.state.reduction_dimensionality = val + + def set_merge_fit_mode(self, val): + self.state.merge_fit_mode = val + def build(self): self.state.validate() return copy.copy(self.state) diff --git a/scripts/SANS/sans/state/save.py b/scripts/SANS/sans/state/save.py index cd7659b5dfe6406758cb746e5a69d054fd930ba2..c97334e799a9c892b446817571fdc1ba1a9764ef 100644 --- a/scripts/SANS/sans/state/save.py +++ b/scripts/SANS/sans/state/save.py @@ -11,7 +11,7 @@ from __future__ import (absolute_import, division, print_function) import copy from sans.state.state_base import (StateBase, BoolParameter, StringParameter, StringWithNoneParameter, - ClassTypeListParameter, rename_descriptor_names) + rename_descriptor_names) from sans.common.enums import (SaveType, SANSFacility) from sans.state.automatic_setters import (automatic_setters) @@ -22,7 +22,7 @@ from sans.state.automatic_setters import (automatic_setters) @rename_descriptor_names class StateSave(StateBase): zero_free_correction = BoolParameter() - file_format = ClassTypeListParameter(SaveType) + file_format = SaveType.NO_TYPE # Settings for the output name user_specified_output_name = StringWithNoneParameter() @@ -50,6 +50,9 @@ class StateSaveBuilder(object): self.state.validate() return copy.copy(self.state) + def set_file_format(self, val): + self.state.file_format = val + def get_save_builder(data_info): # The data state has most of the information that we require to define the save. 
For the factory method, only diff --git a/scripts/SANS/sans/state/scale.py b/scripts/SANS/sans/state/scale.py index 84356c519958a026106e11a3b6f2359ea56dfafb..fb191a4449da2eb7f29814e9826ec9330e0635d8 100644 --- a/scripts/SANS/sans/state/scale.py +++ b/scripts/SANS/sans/state/scale.py @@ -8,7 +8,7 @@ from __future__ import (absolute_import, division, print_function) import copy -from sans.state.state_base import (StateBase, rename_descriptor_names, PositiveFloatParameter, ClassTypeParameter) +from sans.state.state_base import (StateBase, rename_descriptor_names, PositiveFloatParameter) from sans.common.enums import (SampleShape, SANSFacility) from sans.state.automatic_setters import (automatic_setters) @@ -18,24 +18,21 @@ from sans.state.automatic_setters import (automatic_setters) # ---------------------------------------------------------------------------------------------------------------------- @rename_descriptor_names class StateScale(StateBase): - shape = ClassTypeParameter(SampleShape) + shape = None + thickness = PositiveFloatParameter() width = PositiveFloatParameter() height = PositiveFloatParameter() scale = PositiveFloatParameter() # Geometry from the file - shape_from_file = ClassTypeParameter(SampleShape) + shape_from_file = SampleShape.DISC thickness_from_file = PositiveFloatParameter() width_from_file = PositiveFloatParameter() height_from_file = PositiveFloatParameter() def __init__(self): super(StateScale, self).__init__() - - # The default geometry - self.shape_from_file = SampleShape.Disc - # The default values are 1mm self.thickness_from_file = 1. self.width_from_file = 1. 
@@ -48,7 +45,7 @@ class StateScale(StateBase): # ---------------------------------------------------------------------------------------------------------------------- # Builder # ---------------------------------------------------------------------------------------------------------------------- -def set_geometry_from_file(state, date_info, file_information): +def set_geometry_from_file(state, file_information): # Get the geometry state.height_from_file = file_information.get_height() state.width_from_file = file_information.get_width() @@ -58,15 +55,21 @@ def set_geometry_from_file(state, date_info, file_information): class StateScaleBuilder(object): @automatic_setters(StateScale, exclusions=[]) - def __init__(self, data_info, file_information): + def __init__(self, file_information): super(StateScaleBuilder, self).__init__() self.state = StateScale() - set_geometry_from_file(self.state, data_info, file_information) + set_geometry_from_file(self.state, file_information) def build(self): self.state.validate() return copy.copy(self.state) + def set_shape(self, val): + self.state.shape = val + + def set_shape_from_file(self, val): + self.state.shape_from_file(val) + # --------------------------------------- # Factory method for SANStateScaleBuilder @@ -76,7 +79,7 @@ def get_scale_builder(data_info, file_information=None): # the facility/instrument is of relevance. 
facility = data_info.facility if facility is SANSFacility.ISIS: - return StateScaleBuilder(data_info, file_information) + return StateScaleBuilder(file_information) else: raise NotImplementedError("StateScaleBuilder: Could not find any valid scale builder for the " "specified StateData object {0}".format(str(data_info))) diff --git a/scripts/SANS/sans/state/state_base.py b/scripts/SANS/sans/state/state_base.py index 7723369488aff615f13c2843bd2f8e6befe20269..d4d01039b740727ac6c3f13e47e5dcd017149ff0 100644 --- a/scripts/SANS/sans/state/state_base.py +++ b/scripts/SANS/sans/state/state_base.py @@ -4,22 +4,28 @@ # NScD Oak Ridge National Laboratory, European Spallation Source # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + -#pylint: disable=too-few-public-methods, invalid-name +# pylint: disable=too-few-public-methods, invalid-name """ Fundamental classes and Descriptors for the State mechanism.""" from __future__ import (absolute_import, division, print_function) -from abc import (ABCMeta, abstractmethod) + import copy import inspect +from abc import (ABCMeta, abstractmethod) from functools import (partial) +from importlib import import_module + from six import string_types, with_metaclass from mantid.kernel import (PropertyManager, std_vector_dbl, std_vector_str, std_vector_int, std_vector_long) +from mantid.py3compat import Enum # --------------------------------------------------------------- # Validator functions # --------------------------------------------------------------- + + def is_not_none(value): return value is not None @@ -209,24 +215,6 @@ class DictFloatsParameter(TypedParameter): value, type(value))) -class ClassTypeParameter(TypedParameter): - """ - This TypedParameter variant allows for storing a class type. - - This could be for example something from the SANSType module, e.g. CanonicalCoordinates.X - It is something that is used frequently with the main of moving away from using strings where types - should be used instead. 
- """ - def __init__(self, class_type): - super(ClassTypeParameter, self).__init__(class_type, is_not_none) - - def _type_check(self, value): - if not issubclass(value, self.parameter_type): - raise TypeError("Trying to set {0} which expects a value of type {1}." - " Got a value of {2} which is of type: {3}".format(self.name, self.parameter_type, - value, type(value))) - - class FloatWithNoneParameter(TypedParameter): def __init__(self): super(FloatWithNoneParameter, self).__init__(float) @@ -276,7 +264,6 @@ class PositiveFloatListParameter(TypedParameter): super(PositiveFloatListParameter, self).__init__(list, all_list_elements_are_float_and_positive_and_not_empty) def _type_check(self, value): - if not isinstance(value, self.parameter_type) or not all_list_elements_are_float_and_not_empty(value): raise TypeError("Trying to set {0} which expects a value of type {1}." " Got a value of {2} which is of type: {3}".format(self.name, self.parameter_type, @@ -306,12 +293,6 @@ class PositiveIntegerListParameter(TypedParameter): value, type(value))) -class ClassTypeListParameter(TypedParameter): - def __init__(self, class_type): - typed_comparison = partial(all_list_elements_are_of_class_type_and_not_empty, comparison_type=class_type) - super(ClassTypeListParameter, self).__init__(list, typed_comparison) - - # ------------------------------------------------ # StateBase # ------------------------------------------------ @@ -330,6 +311,9 @@ class StateBase(with_metaclass(ABCMeta, object)): def validate(self): pass + def convert_to_dict(self): + return convert_state_to_dict(self) + def rename_descriptor_names(cls): """ @@ -355,12 +339,13 @@ def rename_descriptor_names(cls): # # During serialization we place identifier tags into the serialized object, e.g. we add a specifier if the item # is a State type at all and if so which state it is. 
+ENUM_TYPE_TAG = "EnumTag#" +INT_TAG = "int" STATE_NAME = "state_name" STATE_MODULE = "state_module" SEPARATOR_SERIAL = "#" -class_type_parameter_id = "ClassTypeParameterID#" MODULE = "__module__" @@ -389,19 +374,7 @@ def get_module_and_class_name(instance): def provide_class_from_module_and_class_name(module_name, class_name): - # Importlib seems to be missing on RHEL6, hence we resort to __import__ - try: - from importlib import import_module - module = import_module(module_name) - except ImportError: - if "." in module_name: - _, mod_name = module_name.rsplit(".", 1) - else: - mod_name = None - if not mod_name: - module = __import__(module_name) - else: - module = __import__(module_name, fromlist=[mod_name]) + module = import_module(module_name) return getattr(module, class_name) @@ -411,20 +384,18 @@ def provide_class(instance): return provide_class_from_module_and_class_name(module_name, class_name) -def is_class_type_parameter(value): - return isinstance(value, string_types) and class_type_parameter_id in value +def is_enum_type_parameter(value): + return isinstance(value, string_types) and ENUM_TYPE_TAG in value -def is_vector_with_class_type_parameter(value): - is_vector_with_class_type = True - contains_str = is_string_vector(value) - if contains_str: - for element in value: - if not is_class_type_parameter(element): - is_vector_with_class_type = False - else: - is_vector_with_class_type = False - return is_vector_with_class_type +def is_enum_list_parameter(value): + if isinstance(value, string_types): + return False + + try: + return all(ENUM_TYPE_TAG in s for s in value) + except TypeError: + return False def get_module_and_class_name_from_encoded_string(encoder, value): @@ -432,10 +403,6 @@ def get_module_and_class_name_from_encoded_string(encoder, value): return without_encoder.split(SEPARATOR_SERIAL) -def create_module_and_class_name_from_encoded_string(class_type_id, module_name, class_name): - return class_type_id + module_name + SEPARATOR_SERIAL + 
class_name - - def create_sub_state(value): # We are dealing with a sub state. We first have to create it and then populate it sub_state_class = provide_class(value) @@ -446,22 +413,27 @@ def create_sub_state(value): def get_descriptor_values(instance): - # Get all descriptor names which are TypedParameter of instance's type - descriptor_names = [] - descriptor_types = {} - for descriptor_name, descriptor_object in inspect.getmembers(type(instance)): - if inspect.isdatadescriptor(descriptor_object) and isinstance(descriptor_object, TypedParameter): - descriptor_names.append(descriptor_name) - descriptor_types.update({descriptor_name: descriptor_object}) + # Get all user defined attributes + member_variables = inspect.getmembers(type(instance), lambda x: not(inspect.isroutine(x))) + # Remove anything starting with a '_' or '__' - i.e. private attributes such as weak ref + member_variables = [item for item in member_variables if not item[0].startswith('_')] + + # Property manager is a fake property that wraps the serializing method (i.e. 
this) + # so trying to pack it causes inf recursion + member_variables = [item for item in member_variables if "property_manager" not in item[0]] + + # We only need names + member_variables = [item[0] for item in member_variables] # Get the descriptor values from the instance descriptor_values = {} - for key in descriptor_names: + for key in member_variables: if hasattr(instance, key): value = getattr(instance, key) if value is not None: descriptor_values.update({key: value}) - return descriptor_values, descriptor_types + + return descriptor_values def get_class_descriptor_types(instance): @@ -480,40 +452,28 @@ def convert_state_to_dict(instance): :param instance: the instance which is to be converted :return: a serialized state object in the form of a dict """ - descriptor_values, descriptor_types = get_descriptor_values(instance) + attribute = get_descriptor_values(instance) # Add the descriptors to a dict state_dict = dict() - for key, value in list(descriptor_values.items()): - # If the value is a SANSBaseState then create a dict from it - # If the value is a dict, then we need to check what the sub types are - # If the value is a ClassTypeParameter, then we need to encode it - # If the value is a list of ClassTypeParameters, then we need to encode each element in the list - if isinstance(value, StateBase): - sub_state_dict = value.property_manager - value = sub_state_dict - elif isinstance(value, dict): - # If we have a dict, then we need to watch out since a value in the dict might be a State - sub_dictionary = {} - for key_sub, val_sub in list(value.items()): - if isinstance(val_sub, StateBase): - sub_dictionary_value = val_sub.property_manager - else: - sub_dictionary_value = val_sub - sub_dictionary.update({key_sub: sub_dictionary_value}) - value = sub_dictionary - elif isinstance(descriptor_types[key], ClassTypeParameter): - value = get_serialized_class_type_parameter(value) - elif isinstance(descriptor_types[key], ClassTypeListParameter): - if value: - # 
If there are entries in the list, then convert them individually and place them into a list. - # The list will contain a sequence of serialized ClassTypeParameters - serialized_value = [] - for element in value: - serialized_element = get_serialized_class_type_parameter(element) - serialized_value.append(serialized_element) - value = serialized_value - - state_dict.update({key: value}) + + # Don't do anything if primitive type that Mantid can serialize + primative_types = (int, str, bool, float, TypedParameter) + + for attr_name, attr_val in attribute.items(): + if isinstance(attr_val, StateBase): + # If the value is a SANSBaseState then create a dict from it + attr_val = attr_val.property_manager + elif isinstance(attr_val, dict): + attr_val = serialize_dict(attr_val) + elif isinstance(attr_val, Enum) or isinstance(attr_val, list) and all(isinstance(x, Enum) for x in attr_val): + attr_val = serialize_enum(attr_val) + elif isinstance(attr_val, primative_types) \ + or isinstance(attr_val, list) and all(isinstance(x, primative_types) for x in attr_val): + pass # A primative type or list of primitives don't need anything special + else: + raise ValueError("Cannot serialize {0}".format(attr_val)) + + state_dict.update({attr_name: attr_val}) # Add information about the current state object, such as in which module it lives and what its name is module_name, class_name = get_module_and_class_name(instance) state_dict.update({STATE_MODULE: module_name}) @@ -521,6 +481,24 @@ def convert_state_to_dict(instance): return state_dict +def serialize_dict(value): + # If we have a dict, then we need to watch out since a value in the dict might be a State + sub_dictionary = {} + for key_sub, val_sub in list(value.items()): + # We have to handle the key being an enum too + if isinstance(key_sub, Enum): + key_sub = serialize_enum(key_sub) + + if isinstance(val_sub, StateBase): + val_sub = val_sub.property_manager + elif isinstance(val_sub, Enum): + val_sub = serialize_enum(val_sub) + 
+ sub_dictionary.update({key_sub: val_sub}) + + return sub_dictionary + + def set_state_from_property_manager(instance, property_manager): """ Set the State object from the information stored on a property manager object. This is the deserialization step. @@ -528,6 +506,7 @@ def set_state_from_property_manager(instance, property_manager): :param instance: the instance which is to be set with a values of the property manager :param property_manager: the property manager with the stored setting """ + def _set_element(inst, k_element, v_element): if k_element != STATE_NAME and k_element != STATE_MODULE: setattr(inst, k_element, v_element) @@ -535,46 +514,20 @@ def set_state_from_property_manager(instance, property_manager): keys = list(property_manager.keys()) for key in keys: value = property_manager.getProperty(key).value - # There are four scenarios that need to be considered + # There are some special scenarios that need to be considered # 1. ParameterManager 1: This indicates (most often) that we are dealing with a new state -> create it and # apply recursion # 2. ParameterManager 2: In some cases the ParameterManager object is actually a map rather than a state -> # populate the state - # 3. String with special meaning: Admittedly this is a hack, but we are limited by the input property types - # of Mantid algorithms, which can be string, int, float and containers of these - # types (and PropertyManagerProperties). We need a wider range of types, such - # as ClassTypeParameters. These are encoded (as good as possible) in a string - # 4. Vector of strings with special meaning: See point 3) - # 5. Vector for float: This needs to handle Mantid's float array - # 6. Vector for string: This needs to handle Mantid's string array - # 7. Vector for int: This needs to handle Mantid's integer array - # 8. 
Normal values: all is fine, just populate them + if type(value) is PropertyManager and is_state(value): sub_state = create_sub_state(value) setattr(instance, key, sub_state) elif type(value) is PropertyManager: - # We must be dealing with an actual dict descriptor - sub_dict_keys = list(value.keys()) - dict_element = {} - # We need to watch out if a value of the dictionary is a sub state - for sub_dict_key in sub_dict_keys: - sub_dict_value = value.getProperty(sub_dict_key).value - if type(sub_dict_value) == PropertyManager and is_state(sub_dict_value): - sub_state = create_sub_state(sub_dict_value) - sub_dict_value_to_insert = sub_state - else: - sub_dict_value_to_insert = sub_dict_value - dict_element.update({sub_dict_key: sub_dict_value_to_insert}) - setattr(instance, key, dict_element) - elif is_class_type_parameter(value): - class_type_parameter = get_deserialized_class_type_parameter(value) - _set_element(instance, key, class_type_parameter) - elif is_vector_with_class_type_parameter(value): - class_type_list = [] - for element in value: - class_type_parameter = get_deserialized_class_type_parameter(element) - class_type_list.append(class_type_parameter) - _set_element(instance, key, class_type_list) + deserialize_dict(instance, key, value) + elif is_enum_type_parameter(value) or is_enum_list_parameter(value): + enum_type_parameter = deserialize_enum(value) + _set_element(instance, key, enum_type_parameter) elif is_float_vector(value): float_list_value = list(value) _set_element(instance, key, float_list_value) @@ -588,24 +541,60 @@ def set_state_from_property_manager(instance, property_manager): _set_element(instance, key, value) -def get_serialized_class_type_parameter(value): - # The module will only know about the outer class name, therefore we need - # 1. The module name - # 2. The name of the outer class - # 3. 
The name of the actual class - module_name, class_name = get_module_and_class_name(value) - outer_class_name = value.outer_class_name - class_name = outer_class_name + SEPARATOR_SERIAL + class_name - return create_module_and_class_name_from_encoded_string(class_type_parameter_id, module_name, class_name) - - -def get_deserialized_class_type_parameter(value): - # We need to first get the outer class from the module - module_name, outer_class_name, class_name = \ - get_module_and_class_name_from_encoded_string(class_type_parameter_id, value) - outer_class_type_parameter = provide_class_from_module_and_class_name(module_name, outer_class_name) - # From the outer class we can then retrieve the inner class which normally defines the users selection - return getattr(outer_class_type_parameter, class_name) +def deserialize_dict(instance, key, value): + # We must be dealing with an actual dict descriptor + sub_dict_keys = list(value.keys()) + dict_element = {} + # We need to watch out if a value of the dictionary is a sub state + for sub_dict_key in sub_dict_keys: + sub_dict_value = value.getProperty(sub_dict_key).value + if type(sub_dict_value) == PropertyManager and is_state(sub_dict_value): + sub_state = create_sub_state(sub_dict_value) + sub_dict_value_to_insert = sub_state + elif is_enum_type_parameter(sub_dict_value): + sub_dict_value_to_insert = deserialize_enum(sub_dict_value) + else: + sub_dict_value_to_insert = sub_dict_value + + if is_enum_type_parameter(sub_dict_key): + sub_dict_key = deserialize_enum(sub_dict_key) + + dict_element.update({sub_dict_key: sub_dict_value_to_insert}) + setattr(instance, key, dict_element) + + +def serialize_enum(value): + to_parse = value if isinstance(value, list) else [value] + serialized = [] + for val in to_parse: + assert (isinstance(val, Enum)) + module_name, class_name = get_module_and_class_name(val) + selected_val = val.value + + # Some devs use int for enums too so handle that + if isinstance(selected_val, int): + 
selected_val = INT_TAG + str(selected_val) + + serialized.append(ENUM_TYPE_TAG + module_name + SEPARATOR_SERIAL + class_name + SEPARATOR_SERIAL + selected_val) + + return serialized[0] if len(serialized) == 1 else serialized + + +def deserialize_enum(value): + # Mantid returns a std::vec type which needs to decay to a list + to_parse = [value] if isinstance(value, string_types) else [i for i in value] + parsed = [] + + for serialized_str in to_parse: + module_name, class_name, selection = get_module_and_class_name_from_encoded_string(ENUM_TYPE_TAG, + serialized_str) + enum_class = provide_class_from_module_and_class_name(module_name, class_name) + + selection = int(selection.replace(INT_TAG, '')) if INT_TAG in selection else selection + parsed_val = enum_class(selection) + parsed.append(parsed_val) + + return parsed[0] if len(parsed) == 1 else parsed def create_deserialized_sans_state_from_property_manager(property_manager): diff --git a/scripts/SANS/sans/state/state_functions.py b/scripts/SANS/sans/state/state_functions.py index aa45b9375e5b86ad824d0117f2b67eecdfc70045..23914165bdef00fd8b196d1b1b54afe3adb0af64 100644 --- a/scripts/SANS/sans/state/state_functions.py +++ b/scripts/SANS/sans/state/state_functions.py @@ -81,8 +81,8 @@ def set_detector_names(state, ipf_path, invalid_detector_types=None): if invalid_detector_types is None: invalid_detector_types = [] - lab_keyword = DetectorType.to_string(DetectorType.LAB) - hab_keyword = DetectorType.to_string(DetectorType.HAB) + lab_keyword = DetectorType.LAB.value + hab_keyword = DetectorType.HAB.value detector_names = {lab_keyword: "low-angle-detector-name", hab_keyword: "high-angle-detector-name"} detector_names_short = {lab_keyword: "low-angle-detector-short-name", @@ -96,7 +96,7 @@ def set_detector_names(state, ipf_path, invalid_detector_types=None): for detector_type in state.detectors: try: - if DetectorType.from_string(detector_type) in invalid_detector_types: + if DetectorType(detector_type) in 
invalid_detector_types: continue detector_name_tag = detector_names[detector_type] detector_name_short_tag = detector_names_short[detector_type] diff --git a/scripts/SANS/sans/state/wavelength.py b/scripts/SANS/sans/state/wavelength.py index d691bfd8470dee06e6d2236b923ce8fe70fbc4b8..1cbda44331e73b7c0cf599960f164ff552d1e94d 100644 --- a/scripts/SANS/sans/state/wavelength.py +++ b/scripts/SANS/sans/state/wavelength.py @@ -10,7 +10,7 @@ from __future__ import (absolute_import, division, print_function) import json import copy -from sans.state.state_base import (StateBase, PositiveFloatParameter, ClassTypeParameter, rename_descriptor_names, +from sans.state.state_base import (StateBase, PositiveFloatParameter, rename_descriptor_names, PositiveFloatListParameter) from sans.common.enums import (RebinType, RangeStepType, SANSFacility) from sans.state.state_functions import (is_not_none_and_first_larger_than_second, one_is_none, validation_message) @@ -22,15 +22,15 @@ from sans.state.automatic_setters import (automatic_setters) # ---------------------------------------------------------------------------------------------------------------------- @rename_descriptor_names class StateWavelength(StateBase): - rebin_type = ClassTypeParameter(RebinType) + rebin_type = RebinType.REBIN wavelength_low = PositiveFloatListParameter() wavelength_high = PositiveFloatListParameter() wavelength_step = PositiveFloatParameter() - wavelength_step_type = ClassTypeParameter(RangeStepType) + wavelength_step_type = RangeStepType.NOT_SET def __init__(self): super(StateWavelength, self).__init__() - self.rebin_type = RebinType.Rebin + self.rebin_type = RebinType.REBIN def validate(self): is_invalid = dict() @@ -68,6 +68,12 @@ class StateWavelengthBuilder(object): self.state.validate() return copy.copy(self.state) + def set_wavelength_step_type(self, val): + self.state.wavelength_step_type = val + + def set_rebin_type(self, val): + self.state.rebin_type = val + def 
get_wavelength_builder(data_info): facility = data_info.facility diff --git a/scripts/SANS/sans/state/wavelength_and_pixel_adjustment.py b/scripts/SANS/sans/state/wavelength_and_pixel_adjustment.py index 86c29c8185b14314e18194a2d7364f64ade6e5bc..75eb298ce4d473c91994b68e0c5b7f2eb6774805 100644 --- a/scripts/SANS/sans/state/wavelength_and_pixel_adjustment.py +++ b/scripts/SANS/sans/state/wavelength_and_pixel_adjustment.py @@ -12,7 +12,7 @@ from __future__ import (absolute_import, division, print_function) import json import copy from sans.state.state_base import (StateBase, rename_descriptor_names, StringParameter, - ClassTypeParameter, PositiveFloatParameter, DictParameter, PositiveFloatListParameter) + PositiveFloatParameter, DictParameter, PositiveFloatListParameter) from sans.state.state_functions import (is_not_none_and_first_larger_than_second, one_is_none, validation_message) from sans.common.enums import (RangeStepType, DetectorType, SANSFacility) from sans.state.automatic_setters import (automatic_setters) @@ -44,7 +44,7 @@ class StateWavelengthAndPixelAdjustment(StateBase): wavelength_low = PositiveFloatListParameter() wavelength_high = PositiveFloatListParameter() wavelength_step = PositiveFloatParameter() - wavelength_step_type = ClassTypeParameter(RangeStepType) + wavelength_step_type = RangeStepType.NOT_SET adjustment_files = DictParameter() @@ -52,8 +52,8 @@ class StateWavelengthAndPixelAdjustment(StateBase): def __init__(self): super(StateWavelengthAndPixelAdjustment, self).__init__() - self.adjustment_files = {DetectorType.to_string(DetectorType.LAB): StateAdjustmentFiles(), - DetectorType.to_string(DetectorType.HAB): StateAdjustmentFiles()} + self.adjustment_files = {DetectorType.LAB.value: StateAdjustmentFiles(), + DetectorType.HAB.value: StateAdjustmentFiles()} def validate(self): is_invalid = {} @@ -67,6 +67,12 @@ class StateWavelengthAndPixelAdjustment(StateBase): "wavelength_step_type": self.wavelength_step_type}) is_invalid.update(entry) + if 
self.wavelength_step_type is RangeStepType.NOT_SET: + entry = validation_message("A wavelength entry has not been set.", + "Make sure that all entries are set.", + {"wavelength_step_type": self.wavelength_step_type}) + is_invalid.update(entry) + if is_not_none_and_first_larger_than_second([self.wavelength_low, self.wavelength_high]): entry = validation_message("Incorrect wavelength bounds.", "Make sure that lower wavelength bound is smaller then upper bound.", @@ -75,8 +81,8 @@ class StateWavelengthAndPixelAdjustment(StateBase): is_invalid.update(entry) try: - self.adjustment_files[DetectorType.to_string(DetectorType.LAB)].validate() - self.adjustment_files[DetectorType.to_string(DetectorType.HAB)].validate() + self.adjustment_files[DetectorType.LAB.value].validate() + self.adjustment_files[DetectorType.HAB.value].validate() except ValueError as e: is_invalid.update({"adjustment_files": str(e)}) @@ -100,6 +106,9 @@ class StateWavelengthAndPixelAdjustmentBuilder(object): self.state.validate() return copy.copy(self.state) + def set_wavelength_step_type(self, val): + self.state.wavelength_step_type = val + def get_wavelength_and_pixel_adjustment_builder(data_info): facility = data_info.facility diff --git a/scripts/SANS/sans/test_helper/file_information_mock.py b/scripts/SANS/sans/test_helper/file_information_mock.py index 3aea9d97e84dee34b84adfdf75c226b1b4508cb1..66b2252f8761a61935dcc855cbc066399bc9cc2f 100644 --- a/scripts/SANS/sans/test_helper/file_information_mock.py +++ b/scripts/SANS/sans/test_helper/file_information_mock.py @@ -11,7 +11,7 @@ from sans.common.enums import (SANSFacility, SANSInstrument, FileType, SampleSha class SANSFileInformationMock(SANSFileInformation): def __init__(self, instrument=SANSInstrument.LOQ, facility=SANSFacility.ISIS, run_number=00000, file_name='file_name', - height=8.0, width=8.0, thickness=1.0, shape=SampleShape.FlatPlate, date='2012-10-22T22:41:27', periods=1, + height=8.0, width=8.0, thickness=1.0, 
shape=SampleShape.FLAT_PLATE, date='2012-10-22T22:41:27', periods=1, event_mode=True, added_data=False): super(SANSFileInformationMock, self).__init__(file_name) self._instrument = instrument @@ -42,7 +42,7 @@ class SANSFileInformationMock(SANSFileInformation): return self._periods def get_type(self): - return FileType.ISISNexus + return FileType.ISIS_NEXUS def is_event_mode(self): return self._event_mode diff --git a/scripts/SANS/sans/test_helper/mock_objects.py b/scripts/SANS/sans/test_helper/mock_objects.py index 618cd25978731a8307424fd0d4d6f232c36939f4..f23babbb31d406e2107854a2e065cfd7bfefa6c5 100644 --- a/scripts/SANS/sans/test_helper/mock_objects.py +++ b/scripts/SANS/sans/test_helper/mock_objects.py @@ -204,10 +204,10 @@ def create_mock_view(user_file_path, batch_file_path=None, row_user_file_path="" _q_1d_step = mock.PropertyMock(return_value=.001) type(view).q_1d_step = _q_1d_step - _q_1d_step_type = mock.PropertyMock(return_value=RangeStepType.Lin) + _q_1d_step_type = mock.PropertyMock(return_value=RangeStepType.LIN) type(view)._q_1d_step_type = _q_1d_step_type - _output_mode = mock.PropertyMock(return_value=OutputMode.PublishToADS) + _output_mode = mock.PropertyMock(return_value=OutputMode.PUBLISH_TO_ADS) type(view).output_mode = _output_mode _wavelength_range = mock.PropertyMock(return_value='') @@ -231,7 +231,7 @@ def create_mock_view2(user_file_path, batch_file_path=None): view._on_user_file_load = mock.MagicMock(side_effect=on_load_user_file_mock) view._on_batch_file_load = mock.MagicMock(side_effect=on_load_batch_file_mock) - _output_mode = mock.PropertyMock(return_value=OutputMode.PublishToADS) + _output_mode = mock.PropertyMock(return_value=OutputMode.PUBLISH_TO_ADS) type(view).output_mode = _output_mode return view diff --git a/scripts/SANS/sans/test_helper/test_director.py b/scripts/SANS/sans/test_helper/test_director.py index 107baacb7ff8bd757c886d2e5c15367686eea81b..44266e0661f34bb9e83af83dc70999859161f31f 100644 --- 
a/scripts/SANS/sans/test_helper/test_director.py +++ b/scripts/SANS/sans/test_helper/test_director.py @@ -21,13 +21,14 @@ from sans.state.wavelength_and_pixel_adjustment import get_wavelength_and_pixel_ from sans.state.adjustment import get_adjustment_builder from sans.state.convert_to_q import get_convert_to_q_builder -from sans.common.enums import (SANSFacility, ISISReductionMode, ReductionDimensionality, +from sans.common.enums import (SANSFacility, ReductionMode, ReductionDimensionality, FitModeForMerge, RebinType, RangeStepType, SaveType, FitType, SampleShape, SANSInstrument) from sans.test_helper.file_information_mock import SANSFileInformationMock class TestDirector(object): """ The purpose of this builder is to create a valid state object for tests""" + def __init__(self): super(TestDirector, self).__init__() self.data_state = None @@ -78,12 +79,12 @@ class TestDirector(object): # Build the SANSStateReduction if self.reduction_state is None: reduction_builder = get_reduction_mode_builder(self.data_state) - reduction_builder.set_reduction_mode(ISISReductionMode.Merged) - reduction_builder.set_reduction_dimensionality(ReductionDimensionality.OneDim) - reduction_builder.set_merge_fit_mode(FitModeForMerge.Both) + reduction_builder.set_reduction_dimensionality(ReductionDimensionality.ONE_DIM) + reduction_builder.set_merge_fit_mode(FitModeForMerge.BOTH) reduction_builder.set_merge_shift(324.2) reduction_builder.set_merge_scale(3420.98) self.reduction_state = reduction_builder.build() + self.reduction_state.reduction_mode = ReductionMode.MERGED # Build the SANSStateSliceEvent if self.slice_state is None: @@ -105,21 +106,21 @@ class TestDirector(object): wavelength_builder.set_wavelength_low([1.0]) wavelength_builder.set_wavelength_high([10.0]) wavelength_builder.set_wavelength_step(2.0) - wavelength_builder.set_wavelength_step_type(RangeStepType.Lin) - wavelength_builder.set_rebin_type(RebinType.Rebin) + 
wavelength_builder.set_wavelength_step_type(RangeStepType.LIN) + wavelength_builder.set_rebin_type(RebinType.REBIN) self.wavelength_state = wavelength_builder.build() # Build the SANSStateSave if self.save_state is None: save_builder = get_save_builder(self.data_state) save_builder.set_user_specified_output_name("test_file_name") - save_builder.set_file_format([SaveType.Nexus]) + save_builder.set_file_format([SaveType.NEXUS]) self.save_state = save_builder.build() # Build the SANSStateScale if self.scale_state is None: scale_builder = get_scale_builder(self.data_state, file_information) - scale_builder.set_shape(SampleShape.FlatPlate) + scale_builder.set_shape(SampleShape.FLAT_PLATE) scale_builder.set_width(1.0) scale_builder.set_height(2.0) scale_builder.set_thickness(3.0) @@ -133,8 +134,8 @@ class TestDirector(object): normalize_to_monitor_builder.set_wavelength_low([1.0]) normalize_to_monitor_builder.set_wavelength_high([10.0]) normalize_to_monitor_builder.set_wavelength_step(2.0) - normalize_to_monitor_builder.set_wavelength_step_type(RangeStepType.Lin) - normalize_to_monitor_builder.set_rebin_type(RebinType.Rebin) + normalize_to_monitor_builder.set_wavelength_step_type(RangeStepType.LIN) + normalize_to_monitor_builder.set_rebin_type(RebinType.REBIN) normalize_to_monitor_builder.set_background_TOF_general_start(1000.) normalize_to_monitor_builder.set_background_TOF_general_stop(2000.) 
normalize_to_monitor_builder.set_incident_monitor(1) @@ -147,19 +148,19 @@ class TestDirector(object): calculate_transmission_builder.set_wavelength_low([1.0]) calculate_transmission_builder.set_wavelength_high([10.0]) calculate_transmission_builder.set_wavelength_step(2.0) - calculate_transmission_builder.set_wavelength_step_type(RangeStepType.Lin) - calculate_transmission_builder.set_rebin_type(RebinType.Rebin) + calculate_transmission_builder.set_wavelength_step_type(RangeStepType.LIN) + calculate_transmission_builder.set_rebin_type(RebinType.REBIN) calculate_transmission_builder.set_background_TOF_general_start(1000.) calculate_transmission_builder.set_background_TOF_general_stop(2000.) - calculate_transmission_builder.set_Sample_fit_type(FitType.Linear) - calculate_transmission_builder.set_Sample_polynomial_order(0) - calculate_transmission_builder.set_Sample_wavelength_low(1.0) - calculate_transmission_builder.set_Sample_wavelength_high(10.0) - calculate_transmission_builder.set_Can_fit_type(FitType.Polynomial) - calculate_transmission_builder.set_Can_polynomial_order(3) - calculate_transmission_builder.set_Can_wavelength_low(10.0) - calculate_transmission_builder.set_Can_wavelength_high(20.0) + calculate_transmission_builder.set_sample_fit_type(FitType.LINEAR) + calculate_transmission_builder.set_sample_polynomial_order(0) + calculate_transmission_builder.set_sample_wavelength_low(1.0) + calculate_transmission_builder.set_sample_wavelength_high(10.0) + calculate_transmission_builder.set_can_fit_type(FitType.POLYNOMIAL) + calculate_transmission_builder.set_can_polynomial_order(3) + calculate_transmission_builder.set_can_wavelength_low(10.0) + calculate_transmission_builder.set_can_wavelength_high(20.0) calculate_transmission = calculate_transmission_builder.build() # Wavelength and pixel adjustment @@ -167,7 +168,7 @@ class TestDirector(object): wavelength_and_pixel_builder.set_wavelength_low([1.0]) wavelength_and_pixel_builder.set_wavelength_high([10.0]) 
wavelength_and_pixel_builder.set_wavelength_step(2.0) - wavelength_and_pixel_builder.set_wavelength_step_type(RangeStepType.Lin) + wavelength_and_pixel_builder.set_wavelength_step_type(RangeStepType.LIN) wavelength_and_pixel = wavelength_and_pixel_builder.build() # Adjustment @@ -180,7 +181,7 @@ class TestDirector(object): # SANSStateConvertToQ if self.convert_to_q_state is None: convert_to_q_builder = get_convert_to_q_builder(self.data_state) - convert_to_q_builder.set_reduction_dimensionality(ReductionDimensionality.OneDim) + convert_to_q_builder.set_reduction_dimensionality(ReductionDimensionality.ONE_DIM) convert_to_q_builder.set_use_gravity(False) convert_to_q_builder.set_radius_cutoff(0.002) convert_to_q_builder.set_wavelength_cutoff(12.) diff --git a/scripts/SANS/sans/user_file/settings_tags.py b/scripts/SANS/sans/user_file/settings_tags.py index abebed1c8c998fae3d64fbdbbe952280933a9d0b..ea9fda014700bdbba49018fcc35eb0e5f8b4279a 100644 --- a/scripts/SANS/sans/user_file/settings_tags.py +++ b/scripts/SANS/sans/user_file/settings_tags.py @@ -5,9 +5,10 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + from collections import namedtuple -from sans.common.enums import serializable_enum +from mantid.py3compat import Enum # ---------------------------------------------------------------------------------------------------------------------- # Named tuples for passing around data in a structured way, a bit like a plain old c-struct. 
@@ -51,99 +52,147 @@ monitor_file = namedtuple('monitor_file', 'file_path, detector_type') det_fit_range = namedtuple('det_fit_range', 'start, stop, use_fit') -# ------------------------------------------------------------------ -# --- State director keys ------------------------------------------ -# ------------------------------------------------------------------ - - -# --- DET -@serializable_enum("reduction_mode", "rescale", "shift", "rescale_fit", "shift_fit", "correction_x", "correction_y", - "correction_z", "correction_rotation", "correction_radius", "correction_translation", - "correction_x_tilt", "correction_y_tilt", "merge_range", "instrument") -class DetectorId(object): - pass - - -# --- LIMITS -@serializable_enum("angle", "events_binning", "events_binning_range", "radius_cut", "wavelength_cut", "radius", "q", - "qxy", "wavelength") -class LimitsId(object): - pass - - -# --- MASK -@serializable_enum("line", "time", "time_detector", "clear_detector_mask", "clear_time_mask", "single_spectrum_mask", - "spectrum_range_mask", "vertical_single_strip_mask", "vertical_range_strip_mask", "file", - "horizontal_single_strip_mask", "horizontal_range_strip_mask", "block", "block_cross") -class MaskId(object): - pass - - -# --- SAMPLE -@serializable_enum("path", "offset") -class SampleId(object): - pass - - -# --- SET -@serializable_enum("scales", "centre", "centre_HAB") -class SetId(object): - pass - - -# --- TRANS -@serializable_enum("spec", "spec_4_shift", "spec_5_shift", "radius", "roi", "mask", "sample_workspace", "can_workspace") -class TransId(object): - pass - - -# --- TUBECALIBFILE -@serializable_enum("file") -class TubeCalibrationFileId(object): - pass - - -# -- QRESOLUTION -@serializable_enum("on", "delta_r", "collimation_length", "a1", "a2", "h1", "w1", "h2", "w2", "moderator") -class QResolutionId(object): - pass - +class DetectorId(Enum): + CORRECTION_X = "correction_x" + CORRECTION_X_TILT = "correction_x_tilt" + CORRECTION_Y = "correction_y" + 
CORRECTION_Y_TILT = "correction_y_tilt" + CORRECTION_Z = "correction_z" + CORRECTION_RADIUS = "correction_radius" + CORRECTION_ROTATION = "correction_rotation" + CORRECTION_TRANSLATION = "correction_translation" + MERGE_RANGE = "merge_range" + INSTRUMENT = "instrument" + REDUCTION_MODE = "reduction_mode" + RESCALE = "rescale" + RESCALE_FIT = "rescale_fit" + SHIFT = "shift" + SHIFT_FIT = "shift_fit" + + +class LimitsId(Enum): + ANGLE = "angle" + EVENTS_BINNING = "events_binning" + EVENTS_BINNING_RANGE = "events_binning_range" + RADIUS = "radius" + RADIUS_CUT = "radius_cut" + Q = "q" + QXY = "qxy" + WAVELENGTH = "wavelength" + WAVELENGTH_CUT = "wavelength_cut" + + +class MaskId(Enum): + BLOCK = "block" + BLOCK_CROSS = "block_cross" + CLEAR_DETECTOR_MASK = "clear_detector_mask" + CLEAR_TIME_MASK = "clear_time_mask" + FILE = "file" + LINE = "line" + HORIZONTAL_SINGLE_STRIP_MASK = "horizontal_single_strip_mask" + HORIZONTAL_RANGE_STRIP_MASK = "horizontal_range_strip_mask" + TIME = "time" + TIME_DETECTOR = "time_detector" + SINGLE_SPECTRUM_MASK = "single_spectrum_mask" + SPECTRUM_RANGE_MASK = "spectrum_range_mask" + VERTICAL_SINGLE_STRIP_MASK = "vertical_single_strip_mask" + VERTICAL_RANGE_STRIP_MASK = "vertical_range_strip_mask" + + +class SampleId(Enum): + PATH = "path" + OFFSET = "offset" + + +class SetId(Enum): + CENTRE = "centre" + CENTRE_HAB = "centre_HAB" + SCALES = "scales" + + +class TransId(Enum): + CAN_WORKSPACE = "can_workspace" + RADIUS = "radius" + ROI = "roi" + MASK = "mask" + SAMPLE_WORKSPACE = "sample_workspace" + SPEC = "spec" + SPEC_4_SHIFT = "spec_4_shift" + SPEC_5_SHIFT = "spec_5_shift" + + +class TubeCalibrationFileId(Enum): + FILE = "file" + + +class QResolutionId(Enum): + A1 = "a1" + A2 = "a2" + COLLIMATION_LENGTH = "collimation_length" + DELTA_R = "delta_r" + H1 = "h1" + H2 = "h2" + MODERATOR = "moderator" + ON = "on" + W1 = "w1" + W2 = "w2" + + +class FitId(Enum): + CLEAR = "clear" + GENERAL = "general" + MONITOR_TIMES = "monitor_times" + + 
+class GravityId(Enum): + EXTRA_LENGTH = "extra_length" + ON_OFF = "on_off" + + +class MonId(Enum): + DIRECT = "direct" + FLAT = "flat" + HAB = "hab" + INTERPOLATE = "interpolate" + LENGTH = "length" + SPECTRUM = "spectrum" + SPECTRUM_TRANS = "spectrum_trans" + + +class PrintId(Enum): + PRINT_LINE = "print_line" -# --- FIT -@serializable_enum("clear", "monitor_times", "general") -class FitId(object): - pass +class BackId(Enum): + ALL_MONITORS = "all_monitors" + MONITOR_OFF = "monitor_off" + SINGLE_MONITORS = "single_monitors" + TRANS = "trans" -# --- GRAVITY -@serializable_enum("on_off", "extra_length") -class GravityId(object): - pass +class OtherId(Enum): + EVENT_SLICES = "event_slices" -# --- MON -@serializable_enum("length", "direct", "flat", "hab", "spectrum", "spectrum_trans", "interpolate") -class MonId(object): - pass + MERGE_MASK = "merge_mask" + MERGE_MIN = "merge_min" + MERGE_MAX = "merge_max" + REDUCTION_DIMENSIONALITY = "reduction_dimensionality" -# --- PRINT -@serializable_enum("print_line") -class PrintId(object): - pass + SAVE_AS_ZERO_ERROR_FREE = "save_as_zero_error_free" + SAVE_TYPES = "save_types" + SAMPLE_HEIGHT = "sample_height" + SAMPLE_WIDTH = "sample_width" + SAMPLE_THICKNESS = "sample_thickness" + SAMPLE_SHAPE = "sample_shape" -# -- BACK -@serializable_enum("all_monitors", "single_monitors", "monitor_off", "trans") -class BackId(object): - pass + USE_COMPATIBILITY_MODE = "use_compatibility_mode" + USE_EVENT_SLICE_OPTIMISATION = "use_event_slice_optimisation" + USE_FULL_WAVELENGTH_RANGE = "use_full_wavelength_range" + USE_REDUCTION_MODE_AS_SUFFIX = "use_reduction_mode_as_suffix" + USER_SPECIFIED_OUTPUT_NAME = "user_specified_output_name" + USER_SPECIFIED_OUTPUT_NAME_SUFFIX = "user_specified_output_name_suffix" -# -- OTHER - not settable in user file -@serializable_enum("reduction_dimensionality", "use_full_wavelength_range", "event_slices", - "use_compatibility_mode", "save_types", "save_as_zero_error_free", "user_specified_output_name", - 
"user_specified_output_name_suffix", "use_reduction_mode_as_suffix", "sample_width", "sample_height", - "sample_thickness", "sample_shape", "merge_mask", "merge_min", "merge_max", "wavelength_range", - "use_event_slice_optimisation") -class OtherId(object): - pass + WAVELENGTH_RANGE = "wavelength_range" diff --git a/scripts/SANS/sans/user_file/state_director.py b/scripts/SANS/sans/user_file/state_director.py index 18ba38f3ac444e8ee737f722676eec5bfa103079..635189b63010618f620dfa859fec02d8b53eb573 100644 --- a/scripts/SANS/sans/user_file/state_director.py +++ b/scripts/SANS/sans/user_file/state_director.py @@ -48,9 +48,9 @@ def log_non_existing_field(field): def convert_detector(detector_type): if detector_type is DetectorType.HAB: - detector_type_as_string = DetectorType.to_string(DetectorType.HAB) + detector_type_as_string = DetectorType.HAB.value elif detector_type is DetectorType.LAB: - detector_type_as_string = DetectorType.to_string(DetectorType.LAB) + detector_type_as_string = DetectorType.LAB.value else: raise RuntimeError("UserFileStateDirector: Cannot convert detector {0}".format(detector_type)) return detector_type_as_string @@ -86,10 +86,10 @@ def convert_mm_to_m(value): def set_background_tof_general(builder, user_file_items): # The general background settings - if BackId.all_monitors in user_file_items: - back_all_monitors = user_file_items[BackId.all_monitors] + if BackId.ALL_MONITORS in user_file_items: + back_all_monitors = user_file_items[BackId.ALL_MONITORS] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(back_all_monitors, BackId.all_monitors) + check_if_contains_only_one_element(back_all_monitors, BackId.ALL_MONITORS) back_all_monitors = back_all_monitors[-1] builder.set_background_TOF_general_start(back_all_monitors.start) builder.set_background_TOF_general_stop(back_all_monitors.stop) @@ -98,16 +98,16 @@ def set_background_tof_general(builder, user_file_items): def 
set_background_tof_monitor(builder, user_file_items): # The monitor off switches. Get all monitors which should not have an individual background setting monitor_exclusion_list = [] - if BackId.monitor_off in user_file_items: - back_monitor_off = user_file_items[BackId.monitor_off] + if BackId.MONITOR_OFF in user_file_items: + back_monitor_off = user_file_items[BackId.MONITOR_OFF] monitor_exclusion_list = list(back_monitor_off.values()) # Get all individual monitor background settings. But ignore those settings where there was an explicit # off setting. Those monitors were collected in the monitor_exclusion_list collection - if BackId.single_monitors in user_file_items: + if BackId.SINGLE_MONITORS in user_file_items: background_tof_monitor_start = {} background_tof_monitor_stop = {} - back_single_monitors = user_file_items[BackId.single_monitors] + back_single_monitors = user_file_items[BackId.SINGLE_MONITORS] for element in back_single_monitors: monitor = element.monitor if monitor not in monitor_exclusion_list: @@ -119,21 +119,21 @@ def set_background_tof_monitor(builder, user_file_items): def set_wavelength_limits(builder, user_file_items): - if LimitsId.wavelength in user_file_items: - wavelength_limits = user_file_items[LimitsId.wavelength] - check_if_contains_only_one_element(wavelength_limits, LimitsId.wavelength) + if LimitsId.WAVELENGTH in user_file_items: + wavelength_limits = user_file_items[LimitsId.WAVELENGTH] + check_if_contains_only_one_element(wavelength_limits, LimitsId.WAVELENGTH) wavelength_limits = wavelength_limits[-1] - if wavelength_limits.step_type in [RangeStepType.RangeLin, RangeStepType.RangeLog]: - wavelength_range = user_file_items[OtherId.wavelength_range] - check_if_contains_only_one_element(wavelength_range, OtherId.wavelength_range) + if wavelength_limits.step_type in [RangeStepType.RANGE_LIN, RangeStepType.RANGE_LOG]: + wavelength_range = user_file_items[OtherId.WAVELENGTH_RANGE] + 
check_if_contains_only_one_element(wavelength_range, OtherId.WAVELENGTH_RANGE) wavelength_range = wavelength_range[-1] wavelength_start, wavelength_stop = get_ranges_from_event_slice_setting(wavelength_range) wavelength_start = [min(wavelength_start)] + wavelength_start wavelength_stop = [max(wavelength_stop)] + wavelength_stop - wavelength_step_type = RangeStepType.Lin if wavelength_limits.step_type is RangeStepType.RangeLin \ - else RangeStepType.Log + wavelength_step_type = RangeStepType.LIN if wavelength_limits.step_type is RangeStepType.RANGE_LIN \ + else RangeStepType.LOG builder.set_wavelength_low(wavelength_start) builder.set_wavelength_high(wavelength_stop) @@ -147,10 +147,10 @@ def set_wavelength_limits(builder, user_file_items): def set_prompt_peak_correction(builder, user_file_items): - if FitId.monitor_times in user_file_items: - fit_monitor_times = user_file_items[FitId.monitor_times] + if FitId.MONITOR_TIMES in user_file_items: + fit_monitor_times = user_file_items[FitId.MONITOR_TIMES] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(fit_monitor_times, FitId.monitor_times) + check_if_contains_only_one_element(fit_monitor_times, FitId.MONITOR_TIMES) fit_monitor_times = fit_monitor_times[-1] builder.set_prompt_peak_correction_min(fit_monitor_times.start) builder.set_prompt_peak_correction_max(fit_monitor_times.stop) @@ -369,8 +369,8 @@ class StateDirectorISIS(object): # --------------------------- # Correction for X, Y, Z # --------------------------- - if DetectorId.correction_x in user_file_items: - corrections_in_x = user_file_items[DetectorId.correction_x] + if DetectorId.CORRECTION_X in user_file_items: + corrections_in_x = user_file_items[DetectorId.CORRECTION_X] for correction_x in corrections_in_x: if correction_x.detector_type is DetectorType.HAB: self._move_builder.set_HAB_x_translation_correction(convert_mm_to_m(correction_x.entry)) @@ -380,8 +380,8 @@ class 
StateDirectorISIS(object): raise RuntimeError("UserFileStateDirector: An unknown detector {0} was used for the" " x correction.".format(correction_x.detector_type)) - if DetectorId.correction_y in user_file_items: - corrections_in_y = user_file_items[DetectorId.correction_y] + if DetectorId.CORRECTION_Y in user_file_items: + corrections_in_y = user_file_items[DetectorId.CORRECTION_Y] for correction_y in corrections_in_y: if correction_y.detector_type is DetectorType.HAB: self._move_builder.set_HAB_y_translation_correction(convert_mm_to_m(correction_y.entry)) @@ -391,8 +391,8 @@ class StateDirectorISIS(object): raise RuntimeError("UserFileStateDirector: An unknown detector {0} was used for the" " y correction.".format(correction_y.detector_type)) - if DetectorId.correction_z in user_file_items: - corrections_in_z = user_file_items[DetectorId.correction_z] + if DetectorId.CORRECTION_Z in user_file_items: + corrections_in_z = user_file_items[DetectorId.CORRECTION_Z] for correction_z in corrections_in_z: if correction_z.detector_type is DetectorType.HAB: self._move_builder.set_HAB_z_translation_correction(convert_mm_to_m(correction_z.entry)) @@ -405,10 +405,10 @@ class StateDirectorISIS(object): # --------------------------- # Correction for Rotation # --------------------------- - if DetectorId.correction_rotation in user_file_items: - rotation_correction = user_file_items[DetectorId.correction_rotation] + if DetectorId.CORRECTION_ROTATION in user_file_items: + rotation_correction = user_file_items[DetectorId.CORRECTION_ROTATION] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(rotation_correction, DetectorId.correction_rotation) + check_if_contains_only_one_element(rotation_correction, DetectorId.CORRECTION_ROTATION) rotation_correction = rotation_correction[-1] if rotation_correction.detector_type is DetectorType.HAB: self._move_builder.set_HAB_rotation_correction(rotation_correction.entry) @@ 
-421,8 +421,8 @@ class StateDirectorISIS(object): # --------------------------- # Correction for Radius # --------------------------- - if DetectorId.correction_radius in user_file_items: - radius_corrections = user_file_items[DetectorId.correction_radius] + if DetectorId.CORRECTION_RADIUS in user_file_items: + radius_corrections = user_file_items[DetectorId.CORRECTION_RADIUS] for radius_correction in radius_corrections: if radius_correction.detector_type is DetectorType.HAB: self._move_builder.set_HAB_radius_correction(convert_mm_to_m(radius_correction.entry)) @@ -435,8 +435,8 @@ class StateDirectorISIS(object): # --------------------------- # Correction for Translation # --------------------------- - if DetectorId.correction_translation in user_file_items: - side_corrections = user_file_items[DetectorId.correction_translation] + if DetectorId.CORRECTION_TRANSLATION in user_file_items: + side_corrections = user_file_items[DetectorId.CORRECTION_TRANSLATION] for side_correction in side_corrections: if side_correction.detector_type is DetectorType.HAB: self._move_builder.set_HAB_side_correction(convert_mm_to_m(side_correction.entry)) @@ -449,8 +449,8 @@ class StateDirectorISIS(object): # --------------------------- # Tilt # --------------------------- - if DetectorId.correction_x_tilt in user_file_items: - tilt_correction = user_file_items[DetectorId.correction_x_tilt] + if DetectorId.CORRECTION_X_TILT in user_file_items: + tilt_correction = user_file_items[DetectorId.CORRECTION_X_TILT] tilt_correction = tilt_correction[-1] if tilt_correction.detector_type is DetectorType.HAB: self._move_builder.set_HAB_x_tilt_correction(tilt_correction.entry) @@ -460,8 +460,8 @@ class StateDirectorISIS(object): raise RuntimeError("UserFileStateDirector: An unknown detector {0} was used for the" " titlt correction.".format(tilt_correction.detector_type)) - if DetectorId.correction_y_tilt in user_file_items: - tilt_correction = user_file_items[DetectorId.correction_y_tilt] + if 
DetectorId.CORRECTION_Y_TILT in user_file_items: + tilt_correction = user_file_items[DetectorId.CORRECTION_Y_TILT] tilt_correction = tilt_correction[-1] if tilt_correction.detector_type is DetectorType.HAB: self._move_builder.set_HAB_y_tilt_correction(tilt_correction.entry) @@ -474,7 +474,7 @@ class StateDirectorISIS(object): # --------------------------- # Sample offset # --------------------------- - set_single_entry(self._move_builder, "set_sample_offset", SampleId.offset, + set_single_entry(self._move_builder, "set_sample_offset", SampleId.OFFSET, user_file_items, apply_to_value=convert_mm_to_m) # --------------------------- @@ -498,17 +498,17 @@ class StateDirectorISIS(object): else: log_non_existing_field("set_monitor_{0}_offset".format(spec_num)) - if TransId.spec_4_shift in user_file_items: - parse_shift(key_to_parse=TransId.spec_4_shift, spec_num=4) + if TransId.SPEC_4_SHIFT in user_file_items: + parse_shift(key_to_parse=TransId.SPEC_4_SHIFT, spec_num=4) - if TransId.spec_5_shift in user_file_items: - parse_shift(key_to_parse=TransId.spec_5_shift, spec_num=5) + if TransId.SPEC_5_SHIFT in user_file_items: + parse_shift(key_to_parse=TransId.SPEC_5_SHIFT, spec_num=5) # --------------------------- # Beam Centre, this can be for HAB and LAB # --------------------------- - if SetId.centre in user_file_items: - beam_centres = user_file_items[SetId.centre] + if SetId.CENTRE in user_file_items: + beam_centres = user_file_items[SetId.CENTRE] beam_centres_for_hab = [beam_centre for beam_centre in beam_centres if beam_centre.detector_type is DetectorType.HAB] beam_centres_for_lab = [beam_centre for beam_centre in beam_centres if beam_centre.detector_type @@ -538,13 +538,13 @@ class StateDirectorISIS(object): # ------------------------ # Reduction mode # ------------------------ - set_single_entry(self._reduction_builder, "set_reduction_mode", DetectorId.reduction_mode, user_file_items) + set_single_entry(self._reduction_builder, "set_reduction_mode", 
DetectorId.REDUCTION_MODE, user_file_items) # ------------------------------- # Shift and rescale # ------------------------------- - set_single_entry(self._reduction_builder, "set_merge_scale", DetectorId.rescale, user_file_items) - set_single_entry(self._reduction_builder, "set_merge_shift", DetectorId.shift, user_file_items) + set_single_entry(self._reduction_builder, "set_merge_scale", DetectorId.RESCALE, user_file_items) + set_single_entry(self._reduction_builder, "set_merge_shift", DetectorId.SHIFT, user_file_items) # ------------------------------- # User masking @@ -552,10 +552,10 @@ class StateDirectorISIS(object): merge_min = None merge_max = None merge_mask = False - if DetectorId.merge_range in user_file_items: - merge_range = user_file_items[DetectorId.merge_range] + if DetectorId.MERGE_RANGE in user_file_items: + merge_range = user_file_items[DetectorId.MERGE_RANGE] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(merge_range, DetectorId.rescale_fit) + check_if_contains_only_one_element(merge_range, DetectorId.RESCALE_FIT) merge_range = merge_range[-1] merge_min = merge_range.start merge_max = merge_range.stop @@ -571,10 +571,10 @@ class StateDirectorISIS(object): q_range_min_scale = None q_range_max_scale = None has_rescale_fit = False - if DetectorId.rescale_fit in user_file_items: - rescale_fits = user_file_items[DetectorId.rescale_fit] + if DetectorId.RESCALE_FIT in user_file_items: + rescale_fits = user_file_items[DetectorId.RESCALE_FIT] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(rescale_fits, DetectorId.rescale_fit) + check_if_contains_only_one_element(rescale_fits, DetectorId.RESCALE_FIT) rescale_fit = rescale_fits[-1] q_range_min_scale = rescale_fit.start q_range_max_scale = rescale_fit.stop @@ -583,17 +583,17 @@ class StateDirectorISIS(object): q_range_min_shift = None q_range_max_shift = None 
has_shift_fit = False - if DetectorId.shift_fit in user_file_items: - shift_fits = user_file_items[DetectorId.shift_fit] + if DetectorId.SHIFT_FIT in user_file_items: + shift_fits = user_file_items[DetectorId.SHIFT_FIT] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(shift_fits, DetectorId.shift_fit) + check_if_contains_only_one_element(shift_fits, DetectorId.SHIFT_FIT) shift_fit = shift_fits[-1] q_range_min_shift = shift_fit.start q_range_max_shift = shift_fit.stop has_shift_fit = shift_fit.use_fit if has_rescale_fit and has_shift_fit: - self._reduction_builder.set_merge_fit_mode(FitModeForMerge.Both) + self._reduction_builder.set_merge_fit_mode(FitModeForMerge.BOTH) min_q = get_min_q_boundary(q_range_min_scale, q_range_min_shift) max_q = get_max_q_boundary(q_range_max_scale, q_range_max_shift) if min_q: @@ -601,24 +601,24 @@ class StateDirectorISIS(object): if max_q: self._reduction_builder.set_merge_range_max(max_q) elif has_rescale_fit and not has_shift_fit: - self._reduction_builder.set_merge_fit_mode(FitModeForMerge.ScaleOnly) + self._reduction_builder.set_merge_fit_mode(FitModeForMerge.SCALE_ONLY) if q_range_min_scale: self._reduction_builder.set_merge_range_min(q_range_min_scale) if q_range_max_scale: self._reduction_builder.set_merge_range_max(q_range_max_scale) elif not has_rescale_fit and has_shift_fit: - self._reduction_builder.set_merge_fit_mode(FitModeForMerge.ShiftOnly) + self._reduction_builder.set_merge_fit_mode(FitModeForMerge.SHIFT_ONLY) if q_range_min_shift: self._reduction_builder.set_merge_range_min(q_range_min_shift) if q_range_max_shift: self._reduction_builder.set_merge_range_max(q_range_max_shift) else: - self._reduction_builder.set_merge_fit_mode(FitModeForMerge.NoFit) + self._reduction_builder.set_merge_fit_mode(FitModeForMerge.NO_FIT) # ------------------------ # Reduction Dimensionality # ------------------------ - set_single_entry(self._reduction_builder, 
"set_reduction_dimensionality", OtherId.reduction_dimensionality, + set_single_entry(self._reduction_builder, "set_reduction_dimensionality", OtherId.REDUCTION_DIMENSIONALITY, user_file_items) def _set_up_mask_state(self, user_file_items): # noqa @@ -643,10 +643,10 @@ class StateDirectorISIS(object): # --------------------------------- # 1. Line Mask # --------------------------------- - if MaskId.line in user_file_items: - mask_lines = user_file_items[MaskId.line] + if MaskId.LINE in user_file_items: + mask_lines = user_file_items[MaskId.LINE] # If there were several arms specified then we take only the last - check_if_contains_only_one_element(mask_lines, MaskId.line) + check_if_contains_only_one_element(mask_lines, MaskId.LINE) mask_line = mask_lines[-1] # We need the width and the angle angle = mask_line.angle @@ -669,8 +669,8 @@ class StateDirectorISIS(object): # --------------------------------- # 2. General time mask # --------------------------------- - if MaskId.time in user_file_items: - mask_time_general = user_file_items[MaskId.time] + if MaskId.TIME in user_file_items: + mask_time_general = user_file_items[MaskId.TIME] start_time = [] stop_time = [] for times in mask_time_general: @@ -686,8 +686,8 @@ class StateDirectorISIS(object): # --------------------------------- # 3. Detector-bound time mask # --------------------------------- - if MaskId.time_detector in user_file_items: - mask_times = user_file_items[MaskId.time_detector] + if MaskId.TIME_DETECTOR in user_file_items: + mask_times = user_file_items[MaskId.TIME_DETECTOR] start_times_hab = [] stop_times_hab = [] start_times_lab = [] @@ -718,9 +718,9 @@ class StateDirectorISIS(object): # --------------------------------- # 4. 
Clear detector # --------------------------------- - if MaskId.clear_detector_mask in user_file_items: - clear_detector_mask = user_file_items[MaskId.clear_detector_mask] - check_if_contains_only_one_element(clear_detector_mask, MaskId.clear_detector_mask) + if MaskId.CLEAR_DETECTOR_MASK in user_file_items: + clear_detector_mask = user_file_items[MaskId.CLEAR_DETECTOR_MASK] + check_if_contains_only_one_element(clear_detector_mask, MaskId.CLEAR_DETECTOR_MASK) # We select the entry which was added last. clear_detector_mask = clear_detector_mask[-1] self._mask_builder.set_clear(clear_detector_mask) @@ -728,9 +728,9 @@ class StateDirectorISIS(object): # --------------------------------- # 5. Clear time # --------------------------------- - if MaskId.clear_time_mask in user_file_items: - clear_time_mask = user_file_items[MaskId.clear_time_mask] - check_if_contains_only_one_element(clear_time_mask, MaskId.clear_time_mask) + if MaskId.CLEAR_TIME_MASK in user_file_items: + clear_time_mask = user_file_items[MaskId.CLEAR_TIME_MASK] + check_if_contains_only_one_element(clear_time_mask, MaskId.CLEAR_TIME_MASK) # We select the entry which was added last. clear_time_mask = clear_time_mask[-1] self._mask_builder.set_clear_time(clear_time_mask) @@ -738,16 +738,16 @@ class StateDirectorISIS(object): # --------------------------------- # 6. Single Spectrum # --------------------------------- - if MaskId.single_spectrum_mask in user_file_items: - single_spectra = user_file_items[MaskId.single_spectrum_mask] + if MaskId.SINGLE_SPECTRUM_MASK in user_file_items: + single_spectra = user_file_items[MaskId.SINGLE_SPECTRUM_MASK] # Note that we are using an unusual setter here. Check mask.py for why we are doing this. self._mask_builder.set_single_spectra_on_detector(single_spectra) # --------------------------------- # 7. 
Spectrum Range # --------------------------------- - if MaskId.spectrum_range_mask in user_file_items: - spectrum_ranges = user_file_items[MaskId.spectrum_range_mask] + if MaskId.SPECTRUM_RANGE_MASK in user_file_items: + spectrum_ranges = user_file_items[MaskId.SPECTRUM_RANGE_MASK] start_range = [] stop_range = [] for spectrum_range in spectrum_ranges: @@ -763,8 +763,8 @@ class StateDirectorISIS(object): # --------------------------------- # 8. Vertical single strip # --------------------------------- - if MaskId.vertical_single_strip_mask in user_file_items: - single_vertical_strip_masks = user_file_items[MaskId.vertical_single_strip_mask] + if MaskId.VERTICAL_SINGLE_STRIP_MASK in user_file_items: + single_vertical_strip_masks = user_file_items[MaskId.VERTICAL_SINGLE_STRIP_MASK] entry_hab = [] entry_lab = [] for single_vertical_strip_mask in single_vertical_strip_masks: @@ -785,8 +785,8 @@ class StateDirectorISIS(object): # --------------------------------- # 9. Vertical range strip # --------------------------------- - if MaskId.vertical_range_strip_mask in user_file_items: - range_vertical_strip_masks = user_file_items[MaskId.vertical_range_strip_mask] + if MaskId.VERTICAL_RANGE_STRIP_MASK in user_file_items: + range_vertical_strip_masks = user_file_items[MaskId.VERTICAL_RANGE_STRIP_MASK] start_hab = [] stop_hab = [] start_lab = [] @@ -815,8 +815,8 @@ class StateDirectorISIS(object): # --------------------------------- # 10. Horizontal single strip # --------------------------------- - if MaskId.horizontal_single_strip_mask in user_file_items: - single_horizontal_strip_masks = user_file_items[MaskId.horizontal_single_strip_mask] + if MaskId.HORIZONTAL_SINGLE_STRIP_MASK in user_file_items: + single_horizontal_strip_masks = user_file_items[MaskId.HORIZONTAL_SINGLE_STRIP_MASK] entry_hab = [] entry_lab = [] for single_horizontal_strip_mask in single_horizontal_strip_masks: @@ -837,8 +837,8 @@ class StateDirectorISIS(object): # --------------------------------- # 11. 
Horizontal range strip # --------------------------------- - if MaskId.horizontal_range_strip_mask in user_file_items: - range_horizontal_strip_masks = user_file_items[MaskId.horizontal_range_strip_mask] + if MaskId.HORIZONTAL_RANGE_STRIP_MASK in user_file_items: + range_horizontal_strip_masks = user_file_items[MaskId.HORIZONTAL_RANGE_STRIP_MASK] start_hab = [] stop_hab = [] start_lab = [] @@ -867,8 +867,8 @@ class StateDirectorISIS(object): # --------------------------------- # 12. Block # --------------------------------- - if MaskId.block in user_file_items: - blocks = user_file_items[MaskId.block] + if MaskId.BLOCK in user_file_items: + blocks = user_file_items[MaskId.BLOCK] horizontal_start_hab = [] horizontal_stop_hab = [] vertical_start_hab = [] @@ -910,8 +910,8 @@ class StateDirectorISIS(object): # --------------------------------- # 13. Block cross # --------------------------------- - if MaskId.block_cross in user_file_items: - block_crosses = user_file_items[MaskId.block_cross] + if MaskId.BLOCK_CROSS in user_file_items: + block_crosses = user_file_items[MaskId.BLOCK_CROSS] horizontal_hab = [] vertical_hab = [] horizontal_lab = [] @@ -939,10 +939,10 @@ class StateDirectorISIS(object): # ------------------------------------------------------------ # 14. Angles --> they are specified in L/Phi # ----------------------------------------------------------- - if LimitsId.angle in user_file_items: - angles = user_file_items[LimitsId.angle] + if LimitsId.ANGLE in user_file_items: + angles = user_file_items[LimitsId.ANGLE] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(angles, LimitsId.angle) + check_if_contains_only_one_element(angles, LimitsId.ANGLE) angle = angles[-1] self._mask_builder.set_phi_min(angle.min) self._mask_builder.set_phi_max(angle.max) @@ -951,17 +951,17 @@ class StateDirectorISIS(object): # ------------------------------------------------------------ # 15. 
Maskfiles # ----------------------------------------------------------- - if MaskId.file in user_file_items: - mask_files = user_file_items[MaskId.file] + if MaskId.FILE in user_file_items: + mask_files = user_file_items[MaskId.FILE] self._mask_builder.set_mask_files(mask_files) # ------------------------------------------------------------ # 16. Radius masks # ----------------------------------------------------------- - if LimitsId.radius in user_file_items: - radii = user_file_items[LimitsId.radius] + if LimitsId.RADIUS in user_file_items: + radii = user_file_items[LimitsId.RADIUS] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(radii, LimitsId.radius) + check_if_contains_only_one_element(radii, LimitsId.RADIUS) radius = radii[-1] if radius.start > radius.stop > 0: raise RuntimeError("UserFileStateDirector: The inner radius {0} appears to be larger that the outer" @@ -976,9 +976,9 @@ class StateDirectorISIS(object): def _set_up_slice_event_state(self, user_file_items): # Setting up the slice limits is current - if OtherId.event_slices in user_file_items: - event_slices = user_file_items[OtherId.event_slices] - check_if_contains_only_one_element(event_slices, OtherId.event_slices) + if OtherId.EVENT_SLICES in user_file_items: + event_slices = user_file_items[OtherId.EVENT_SLICES] + check_if_contains_only_one_element(event_slices, OtherId.EVENT_SLICES) event_slices = event_slices[-1] # The events binning can come in three forms. # 1. As a simple range object @@ -997,60 +997,60 @@ class StateDirectorISIS(object): def _set_up_scale_state(self, user_file_items): # We only extract the first entry here, ie the s entry. 
Although there are other entries which a user can # specify such as a, b, c, d they seem to be - if SetId.scales in user_file_items: - scales = user_file_items[SetId.scales] - check_if_contains_only_one_element(scales, SetId.scales) + if SetId.SCALES in user_file_items: + scales = user_file_items[SetId.SCALES] + check_if_contains_only_one_element(scales, SetId.SCALES) scales = scales[-1] self._scale_builder.set_scale(scales.s) # We can also have settings for the sample geometry (Note that at the moment this is not settable via the # user file nor the command line interface - if OtherId.sample_shape in user_file_items: - sample_shape = user_file_items[OtherId.sample_shape] - check_if_contains_only_one_element(sample_shape, OtherId.sample_shape) + if OtherId.SAMPLE_SHAPE in user_file_items: + sample_shape = user_file_items[OtherId.SAMPLE_SHAPE] + check_if_contains_only_one_element(sample_shape, OtherId.SAMPLE_SHAPE) sample_shape = sample_shape[-1] self._scale_builder.set_shape(sample_shape) - if OtherId.sample_width in user_file_items: - sample_width = user_file_items[OtherId.sample_width] - check_if_contains_only_one_element(sample_width, OtherId.sample_width) + if OtherId.SAMPLE_WIDTH in user_file_items: + sample_width = user_file_items[OtherId.SAMPLE_WIDTH] + check_if_contains_only_one_element(sample_width, OtherId.SAMPLE_WIDTH) sample_width = sample_width[-1] self._scale_builder.set_width(sample_width) - if OtherId.sample_height in user_file_items: - sample_height = user_file_items[OtherId.sample_height] - check_if_contains_only_one_element(sample_height, OtherId.sample_height) + if OtherId.SAMPLE_HEIGHT in user_file_items: + sample_height = user_file_items[OtherId.SAMPLE_HEIGHT] + check_if_contains_only_one_element(sample_height, OtherId.SAMPLE_HEIGHT) sample_height = sample_height[-1] self._scale_builder.set_height(sample_height) - if OtherId.sample_thickness in user_file_items: - sample_thickness = user_file_items[OtherId.sample_thickness] - 
check_if_contains_only_one_element(sample_thickness, OtherId.sample_thickness) + if OtherId.SAMPLE_THICKNESS in user_file_items: + sample_thickness = user_file_items[OtherId.SAMPLE_THICKNESS] + check_if_contains_only_one_element(sample_thickness, OtherId.SAMPLE_THICKNESS) sample_thickness = sample_thickness[-1] self._scale_builder.set_thickness(sample_thickness) def _set_up_convert_to_q_state(self, user_file_items): # Get the radius cut off if any is present - set_single_entry(self._convert_to_q_builder, "set_radius_cutoff", LimitsId.radius_cut, user_file_items, + set_single_entry(self._convert_to_q_builder, "set_radius_cutoff", LimitsId.RADIUS_CUT, user_file_items, apply_to_value=convert_mm_to_m) # Get the wavelength cut off if any is present - set_single_entry(self._convert_to_q_builder, "set_wavelength_cutoff", LimitsId.wavelength_cut, + set_single_entry(self._convert_to_q_builder, "set_wavelength_cutoff", LimitsId.WAVELENGTH_CUT, user_file_items) # Get the 1D q values - if LimitsId.q in user_file_items: - limits_q = user_file_items[LimitsId.q] - check_if_contains_only_one_element(limits_q, LimitsId.q) + if LimitsId.Q in user_file_items: + limits_q = user_file_items[LimitsId.Q] + check_if_contains_only_one_element(limits_q, LimitsId.Q) limits_q = limits_q[-1] self._convert_to_q_builder.set_q_min(limits_q.min) self._convert_to_q_builder.set_q_max(limits_q.max) self._convert_to_q_builder.set_q_1d_rebin_string(limits_q.rebin_string) # Get the 2D q values - if LimitsId.qxy in user_file_items: - limits_qxy = user_file_items[LimitsId.qxy] - check_if_contains_only_one_element(limits_qxy, LimitsId.qxy) + if LimitsId.QXY in user_file_items: + limits_qxy = user_file_items[LimitsId.QXY] + check_if_contains_only_one_element(limits_qxy, LimitsId.QXY) limits_qxy = limits_qxy[-1] # Now we have to check if we have a simple pattern or a more complex pattern at hand is_complex = isinstance(limits_qxy, complex_range) @@ -1064,50 +1064,50 @@ class StateDirectorISIS(object): 
self._convert_to_q_builder.set_q_xy_step_type(limits_qxy.step_type) # Get the Gravity settings - set_single_entry(self._convert_to_q_builder, "set_use_gravity", GravityId.on_off, user_file_items) - set_single_entry(self._convert_to_q_builder, "set_gravity_extra_length", GravityId.extra_length, + set_single_entry(self._convert_to_q_builder, "set_use_gravity", GravityId.ON_OFF, user_file_items) + set_single_entry(self._convert_to_q_builder, "set_gravity_extra_length", GravityId.EXTRA_LENGTH, user_file_items) # Get the QResolution settings set_q_resolution_delta_r - set_single_entry(self._convert_to_q_builder, "set_use_q_resolution", QResolutionId.on, user_file_items) - set_single_entry(self._convert_to_q_builder, "set_q_resolution_delta_r", QResolutionId.delta_r, + set_single_entry(self._convert_to_q_builder, "set_use_q_resolution", QResolutionId.ON, user_file_items) + set_single_entry(self._convert_to_q_builder, "set_q_resolution_delta_r", QResolutionId.DELTA_R, user_file_items, apply_to_value=convert_mm_to_m) set_single_entry(self._convert_to_q_builder, "set_q_resolution_collimation_length", - QResolutionId.collimation_length, user_file_items) - set_single_entry(self._convert_to_q_builder, "set_q_resolution_a1", QResolutionId.a1, user_file_items, + QResolutionId.COLLIMATION_LENGTH, user_file_items) + set_single_entry(self._convert_to_q_builder, "set_q_resolution_a1", QResolutionId.A1, user_file_items, apply_to_value=convert_mm_to_m) - set_single_entry(self._convert_to_q_builder, "set_q_resolution_a2", QResolutionId.a2, user_file_items, + set_single_entry(self._convert_to_q_builder, "set_q_resolution_a2", QResolutionId.A2, user_file_items, apply_to_value=convert_mm_to_m) - set_single_entry(self._convert_to_q_builder, "set_moderator_file", QResolutionId.moderator, + set_single_entry(self._convert_to_q_builder, "set_moderator_file", QResolutionId.MODERATOR, user_file_items) - set_single_entry(self._convert_to_q_builder, "set_q_resolution_h1", QResolutionId.h1, 
user_file_items, + set_single_entry(self._convert_to_q_builder, "set_q_resolution_h1", QResolutionId.H1, user_file_items, apply_to_value=convert_mm_to_m) - set_single_entry(self._convert_to_q_builder, "set_q_resolution_h2", QResolutionId.h2, user_file_items, + set_single_entry(self._convert_to_q_builder, "set_q_resolution_h2", QResolutionId.H2, user_file_items, apply_to_value=convert_mm_to_m) - set_single_entry(self._convert_to_q_builder, "set_q_resolution_w1", QResolutionId.w1, user_file_items, + set_single_entry(self._convert_to_q_builder, "set_q_resolution_w1", QResolutionId.W1, user_file_items, apply_to_value=convert_mm_to_m) - set_single_entry(self._convert_to_q_builder, "set_q_resolution_w2", QResolutionId.w2, user_file_items, + set_single_entry(self._convert_to_q_builder, "set_q_resolution_w2", QResolutionId.W2, user_file_items, apply_to_value=convert_mm_to_m) # ------------------------ # Reduction Dimensionality # ------------------------ - set_single_entry(self._convert_to_q_builder, "set_reduction_dimensionality", OtherId.reduction_dimensionality, + set_single_entry(self._convert_to_q_builder, "set_reduction_dimensionality", OtherId.REDUCTION_DIMENSIONALITY, user_file_items) def _set_up_adjustment_state(self, user_file_items): # Get the wide angle correction setting - set_single_entry(self._adjustment_builder, "set_wide_angle_correction", SampleId.path, user_file_items) + set_single_entry(self._adjustment_builder, "set_wide_angle_correction", SampleId.PATH, user_file_items) def _set_up_normalize_to_monitor_state(self, user_file_items): # Extract the incident monitor and which type of rebinning to use (interpolating or normal) - if MonId.spectrum in user_file_items: - mon_spectrum = user_file_items[MonId.spectrum] + if MonId.SPECTRUM in user_file_items: + mon_spectrum = user_file_items[MonId.SPECTRUM] mon_spec = [spec for spec in mon_spectrum if not spec.is_trans] if mon_spec: mon_spec = mon_spec[-1] - rebin_type = RebinType.InterpolatingRebin if 
mon_spec.interpolate else RebinType.Rebin + rebin_type = RebinType.INTERPOLATING_REBIN if mon_spec.interpolate else RebinType.REBIN self._normalize_to_monitor_builder.set_rebin_type(rebin_type) # We have to check if the spectrum is None, this can be the case when the user wants to use the @@ -1129,37 +1129,37 @@ class StateDirectorISIS(object): def _set_up_calculate_transmission(self, user_file_items): # Transmission radius - set_single_entry(self._calculate_transmission_builder, "set_transmission_radius_on_detector", TransId.radius, + set_single_entry(self._calculate_transmission_builder, "set_transmission_radius_on_detector", TransId.RADIUS, user_file_items, apply_to_value=convert_mm_to_m) # List of transmission roi files - if TransId.roi in user_file_items: - trans_roi = user_file_items[TransId.roi] + if TransId.ROI in user_file_items: + trans_roi = user_file_items[TransId.ROI] self._calculate_transmission_builder.set_transmission_roi_files(trans_roi) # List of transmission mask files - if TransId.mask in user_file_items: - trans_mask = user_file_items[TransId.mask] + if TransId.MASK in user_file_items: + trans_mask = user_file_items[TransId.MASK] self._calculate_transmission_builder.set_transmission_mask_files(trans_mask) # The prompt peak correction values set_prompt_peak_correction(self._calculate_transmission_builder, user_file_items) # The transmission spectrum - if TransId.spec in user_file_items: - trans_spec = user_file_items[TransId.spec] + if TransId.SPEC in user_file_items: + trans_spec = user_file_items[TransId.SPEC] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(trans_spec, TransId.spec) + check_if_contains_only_one_element(trans_spec, TransId.SPEC) trans_spec = trans_spec[-1] self._calculate_transmission_builder.set_transmission_monitor(trans_spec) # The incident monitor spectrum for transmission calculation - if MonId.spectrum in user_file_items: - mon_spectrum = 
user_file_items[MonId.spectrum] + if MonId.SPECTRUM in user_file_items: + mon_spectrum = user_file_items[MonId.SPECTRUM] mon_spec = [spec for spec in mon_spectrum if spec.is_trans] if mon_spec: mon_spec = mon_spec[-1] - rebin_type = RebinType.InterpolatingRebin if mon_spec.interpolate else RebinType.Rebin + rebin_type = RebinType.INTERPOLATING_REBIN if mon_spec.interpolate else RebinType.REBIN self._calculate_transmission_builder.set_rebin_type(rebin_type) # We have to check if the spectrum is None, this can be the case when the user wants to use the @@ -1174,17 +1174,17 @@ class StateDirectorISIS(object): set_background_tof_monitor(self._calculate_transmission_builder, user_file_items) # The roi-specific background settings - if BackId.trans in user_file_items: - back_trans = user_file_items[BackId.trans] + if BackId.TRANS in user_file_items: + back_trans = user_file_items[BackId.TRANS] # Should the user have chosen several values, then the last element is selected - check_if_contains_only_one_element(back_trans, BackId.trans) + check_if_contains_only_one_element(back_trans, BackId.TRANS) back_trans = back_trans[-1] self._calculate_transmission_builder.set_background_TOF_roi_start(back_trans.start) self._calculate_transmission_builder.set_background_TOF_roi_stop(back_trans.stop) # Set the fit settings - if FitId.general in user_file_items: - fit_general = user_file_items[FitId.general] + if FitId.GENERAL in user_file_items: + fit_general = user_file_items[FitId.GENERAL] # We can have settings for both the sample or the can or individually # There can be three types of settings: # 1. Clearing the fit setting @@ -1195,64 +1195,64 @@ class StateDirectorISIS(object): # As usual if there are multiple settings for a specific case, then the last in the list is used. 
# 1 Fit type settings - clear_settings = [item for item in fit_general if item.data_type is None and item.fit_type is FitType.NoFit] + clear_settings = [item for item in fit_general if item.data_type is None and item.fit_type is FitType.NO_FIT] if clear_settings: - check_if_contains_only_one_element(clear_settings, FitId.general) + check_if_contains_only_one_element(clear_settings, FitId.GENERAL) clear_settings = clear_settings[-1] # Will set the fitting to NoFit - self._calculate_transmission_builder.set_Sample_fit_type(clear_settings.fit_type) - self._calculate_transmission_builder.set_Can_fit_type(clear_settings.fit_type) + self._calculate_transmission_builder.set_sample_fit_type(clear_settings.fit_type) + self._calculate_transmission_builder.set_can_fit_type(clear_settings.fit_type) # 2. General settings general_settings = [item for item in fit_general if item.data_type is None and - item.fit_type is not FitType.NoFit] + item.fit_type is not FitType.NO_FIT] if general_settings: - check_if_contains_only_one_element(general_settings, FitId.general) + check_if_contains_only_one_element(general_settings, FitId.GENERAL) general_settings = general_settings[-1] - self._calculate_transmission_builder.set_Sample_fit_type(general_settings.fit_type) - self._calculate_transmission_builder.set_Sample_polynomial_order(general_settings.polynomial_order) - self._calculate_transmission_builder.set_Sample_wavelength_low(general_settings.start) - self._calculate_transmission_builder.set_Sample_wavelength_high(general_settings.stop) - self._calculate_transmission_builder.set_Can_fit_type(general_settings.fit_type) - self._calculate_transmission_builder.set_Can_polynomial_order(general_settings.polynomial_order) - self._calculate_transmission_builder.set_Can_wavelength_low(general_settings.start) - self._calculate_transmission_builder.set_Can_wavelength_high(general_settings.stop) + self._calculate_transmission_builder.set_sample_fit_type(general_settings.fit_type) + 
self._calculate_transmission_builder.set_sample_polynomial_order(general_settings.polynomial_order) + self._calculate_transmission_builder.set_sample_wavelength_low(general_settings.start) + self._calculate_transmission_builder.set_sample_wavelength_high(general_settings.stop) + self._calculate_transmission_builder.set_can_fit_type(general_settings.fit_type) + self._calculate_transmission_builder.set_can_polynomial_order(general_settings.polynomial_order) + self._calculate_transmission_builder.set_can_wavelength_low(general_settings.start) + self._calculate_transmission_builder.set_can_wavelength_high(general_settings.stop) # 3. Sample settings - sample_settings = [item for item in fit_general if item.data_type is DataType.Sample] + sample_settings = [item for item in fit_general if item.data_type is DataType.SAMPLE] if sample_settings: - check_if_contains_only_one_element(sample_settings, FitId.general) + check_if_contains_only_one_element(sample_settings, FitId.GENERAL) sample_settings = sample_settings[-1] - self._calculate_transmission_builder.set_Sample_fit_type(sample_settings.fit_type) - self._calculate_transmission_builder.set_Sample_polynomial_order(sample_settings.polynomial_order) - self._calculate_transmission_builder.set_Sample_wavelength_low(sample_settings.start) - self._calculate_transmission_builder.set_Sample_wavelength_high(sample_settings.stop) + self._calculate_transmission_builder.set_sample_fit_type(sample_settings.fit_type) + self._calculate_transmission_builder.set_sample_polynomial_order(sample_settings.polynomial_order) + self._calculate_transmission_builder.set_sample_wavelength_low(sample_settings.start) + self._calculate_transmission_builder.set_sample_wavelength_high(sample_settings.stop) # 4. 
Can settings - can_settings = [item for item in fit_general if item.data_type is DataType.Can] + can_settings = [item for item in fit_general if item.data_type is DataType.CAN] if can_settings: - check_if_contains_only_one_element(can_settings, FitId.general) + check_if_contains_only_one_element(can_settings, FitId.GENERAL) can_settings = can_settings[-1] - self._calculate_transmission_builder.set_Can_fit_type(can_settings.fit_type) - self._calculate_transmission_builder.set_Can_polynomial_order(can_settings.polynomial_order) - self._calculate_transmission_builder.set_Can_wavelength_low(can_settings.start) - self._calculate_transmission_builder.set_Can_wavelength_high(can_settings.stop) + self._calculate_transmission_builder.set_can_fit_type(can_settings.fit_type) + self._calculate_transmission_builder.set_can_polynomial_order(can_settings.polynomial_order) + self._calculate_transmission_builder.set_can_wavelength_low(can_settings.start) + self._calculate_transmission_builder.set_can_wavelength_high(can_settings.stop) # Set the wavelength default configuration set_wavelength_limits(self._calculate_transmission_builder, user_file_items) # Set the full wavelength range. Note that this can currently only be set from the ISISCommandInterface - if OtherId.use_full_wavelength_range in user_file_items: - use_full_wavelength_range = user_file_items[OtherId.use_full_wavelength_range] - check_if_contains_only_one_element(use_full_wavelength_range, OtherId.use_full_wavelength_range) + if OtherId.USE_FULL_WAVELENGTH_RANGE in user_file_items: + use_full_wavelength_range = user_file_items[OtherId.USE_FULL_WAVELENGTH_RANGE] + check_if_contains_only_one_element(use_full_wavelength_range, OtherId.USE_FULL_WAVELENGTH_RANGE) use_full_wavelength_range = use_full_wavelength_range[-1] self._calculate_transmission_builder.set_use_full_wavelength_range(use_full_wavelength_range) def _set_up_wavelength_and_pixel_adjustment(self, user_file_items): # Get the flat/flood files. 
There can be entries for LAB and HAB. - if MonId.flat in user_file_items: - mon_flat = user_file_items[MonId.flat] + if MonId.FLAT in user_file_items: + mon_flat = user_file_items[MonId.FLAT] hab_flat_entries = [item for item in mon_flat if item.detector_type is DetectorType.HAB] lab_flat_entries = [item for item in mon_flat if item.detector_type is DetectorType.LAB] if hab_flat_entries: @@ -1264,8 +1264,8 @@ class StateDirectorISIS(object): self._wavelength_and_pixel_adjustment_builder.set_LAB_pixel_adjustment_file(lab_flat_entry.file_path) # Get the direct files. There can be entries for LAB and HAB. - if MonId.direct in user_file_items: - mon_direct = user_file_items[MonId.direct] + if MonId.DIRECT in user_file_items: + mon_direct = user_file_items[MonId.DIRECT] hab_direct_entries = [item for item in mon_direct if item.detector_type is DetectorType.HAB] lab_direct_entries = [item for item in mon_direct if item.detector_type is DetectorType.LAB] if hab_direct_entries: @@ -1282,63 +1282,63 @@ class StateDirectorISIS(object): set_wavelength_limits(self._wavelength_and_pixel_adjustment_builder, user_file_items) def _set_up_compatibility(self, user_file_items): - if LimitsId.events_binning in user_file_items: - events_binning = user_file_items[LimitsId.events_binning] - check_if_contains_only_one_element(events_binning, LimitsId.events_binning) + if LimitsId.EVENTS_BINNING in user_file_items: + events_binning = user_file_items[LimitsId.EVENTS_BINNING] + check_if_contains_only_one_element(events_binning, LimitsId.EVENTS_BINNING) events_binning = events_binning[-1] self._compatibility_builder.set_time_rebin_string(events_binning) - if OtherId.use_compatibility_mode in user_file_items: - use_compatibility_mode = user_file_items[OtherId.use_compatibility_mode] - check_if_contains_only_one_element(use_compatibility_mode, OtherId.use_compatibility_mode) + if OtherId.USE_COMPATIBILITY_MODE in user_file_items: + use_compatibility_mode = 
user_file_items[OtherId.USE_COMPATIBILITY_MODE] + check_if_contains_only_one_element(use_compatibility_mode, OtherId.USE_COMPATIBILITY_MODE) use_compatibility_mode = use_compatibility_mode[-1] self._compatibility_builder.set_use_compatibility_mode(use_compatibility_mode) - if OtherId.use_event_slice_optimisation in user_file_items: - use_event_slice_optimisation = user_file_items[OtherId.use_event_slice_optimisation] - check_if_contains_only_one_element(use_event_slice_optimisation, OtherId.use_event_slice_optimisation) + if OtherId.USE_EVENT_SLICE_OPTIMISATION in user_file_items: + use_event_slice_optimisation = user_file_items[OtherId.USE_EVENT_SLICE_OPTIMISATION] + check_if_contains_only_one_element(use_event_slice_optimisation, OtherId.USE_EVENT_SLICE_OPTIMISATION) use_event_slice_optimisation = use_event_slice_optimisation[-1] self._compatibility_builder.set_use_event_slice_optimisation(use_event_slice_optimisation) def _set_up_save(self, user_file_items): - if OtherId.save_types in user_file_items: - save_types = user_file_items[OtherId.save_types] - check_if_contains_only_one_element(save_types, OtherId.save_types) + if OtherId.SAVE_TYPES in user_file_items: + save_types = user_file_items[OtherId.SAVE_TYPES] + check_if_contains_only_one_element(save_types, OtherId.SAVE_TYPES) save_types = save_types[-1] self._save_builder.set_file_format(save_types) - if OtherId.save_as_zero_error_free in user_file_items: - save_as_zero_error_free = user_file_items[OtherId.save_as_zero_error_free] - check_if_contains_only_one_element(save_as_zero_error_free, OtherId.save_as_zero_error_free) + if OtherId.SAVE_AS_ZERO_ERROR_FREE in user_file_items: + save_as_zero_error_free = user_file_items[OtherId.SAVE_AS_ZERO_ERROR_FREE] + check_if_contains_only_one_element(save_as_zero_error_free, OtherId.SAVE_AS_ZERO_ERROR_FREE) save_as_zero_error_free = save_as_zero_error_free[-1] self._save_builder.set_zero_free_correction(save_as_zero_error_free) - if OtherId.user_specified_output_name 
in user_file_items: - user_specified_output_name = user_file_items[OtherId.user_specified_output_name] - check_if_contains_only_one_element(user_specified_output_name, OtherId.user_specified_output_name) + if OtherId.USER_SPECIFIED_OUTPUT_NAME in user_file_items: + user_specified_output_name = user_file_items[OtherId.USER_SPECIFIED_OUTPUT_NAME] + check_if_contains_only_one_element(user_specified_output_name, OtherId.USER_SPECIFIED_OUTPUT_NAME) user_specified_output_name = user_specified_output_name[-1] self._save_builder.set_user_specified_output_name(user_specified_output_name) - if OtherId.user_specified_output_name_suffix in user_file_items: - user_specified_output_name_suffix = user_file_items[OtherId.user_specified_output_name_suffix] + if OtherId.USER_SPECIFIED_OUTPUT_NAME_SUFFIX in user_file_items: + user_specified_output_name_suffix = user_file_items[OtherId.USER_SPECIFIED_OUTPUT_NAME_SUFFIX] check_if_contains_only_one_element(user_specified_output_name_suffix, - OtherId.user_specified_output_name_suffix) + OtherId.USER_SPECIFIED_OUTPUT_NAME_SUFFIX) user_specified_output_name_suffix = user_specified_output_name_suffix[-1] self._save_builder.set_user_specified_output_name_suffix(user_specified_output_name_suffix) - if OtherId.use_reduction_mode_as_suffix in user_file_items: - use_reduction_mode_as_suffix = user_file_items[OtherId.use_reduction_mode_as_suffix] + if OtherId.USE_REDUCTION_MODE_AS_SUFFIX in user_file_items: + use_reduction_mode_as_suffix = user_file_items[OtherId.USE_REDUCTION_MODE_AS_SUFFIX] check_if_contains_only_one_element(use_reduction_mode_as_suffix, - OtherId.use_reduction_mode_as_suffix) + OtherId.USE_REDUCTION_MODE_AS_SUFFIX) use_reduction_mode_as_suffix = use_reduction_mode_as_suffix[-1] self._save_builder.set_use_reduction_mode_as_suffix(use_reduction_mode_as_suffix) def _add_information_to_data_state(self, user_file_items): # The only thing that should be set on the data is the tube calibration file which is specified in # the user 
file. - if TubeCalibrationFileId.file in user_file_items: - tube_calibration = user_file_items[TubeCalibrationFileId.file] - check_if_contains_only_one_element(tube_calibration, TubeCalibrationFileId.file) + if TubeCalibrationFileId.FILE in user_file_items: + tube_calibration = user_file_items[TubeCalibrationFileId.FILE] + check_if_contains_only_one_element(tube_calibration, TubeCalibrationFileId.FILE) tube_calibration = tube_calibration[-1] self._data.calibration = tube_calibration diff --git a/scripts/SANS/sans/user_file/user_file_parser.py b/scripts/SANS/sans/user_file/user_file_parser.py index bce1e1da1c9edf74f30d2e85fe761e7c98c22f60..c63efd281b21d27bede8fcf7ae5b4e5d7bd1bf05 100644 --- a/scripts/SANS/sans/user_file/user_file_parser.py +++ b/scripts/SANS/sans/user_file/user_file_parser.py @@ -12,7 +12,7 @@ import re from math import copysign -from sans.common.enums import (ISISReductionMode, DetectorType, RangeStepType, FitType, DataType, SANSInstrument) +from sans.common.enums import (ReductionMode, DetectorType, RangeStepType, FitType, DataType, SANSInstrument) from sans.user_file.settings_tags import (DetectorId, BackId, range_entry, back_single_monitor_entry, single_entry_with_detector, mask_angle_entry, LimitsId, simple_range, complex_range, MaskId, mask_block, mask_block_cross, @@ -219,24 +219,24 @@ class BackParser(UserFileComponentParser): def _extract_all_mon(self, line): all_mons_string = re.sub(self._all_mons, "", line) time_range = extract_float_range(all_mons_string) - return {BackId.all_monitors: range_entry(start=time_range[0], stop=time_range[1])} + return {BackId.ALL_MONITORS: range_entry(start=time_range[0], stop=time_range[1])} def _extract_single_mon(self, line): monitor_number = self._get_monitor_number(line) single_string = re.sub(self._times, "", line) all_mons_string = re.sub(self._single_monitor, "", single_string) time_range = extract_float_range(all_mons_string) - return {BackId.single_monitors: 
back_single_monitor_entry(monitor=monitor_number, start=time_range[0], + return {BackId.SINGLE_MONITORS: back_single_monitor_entry(monitor=monitor_number, start=time_range[0], stop=time_range[1])} def _extract_off(self, line): monitor_number = self._get_monitor_number(line) - return {BackId.monitor_off: monitor_number} + return {BackId.MONITOR_OFF: monitor_number} def _extract_trans(self, line): trans_string = re.sub(self._trans, "", line) time_range = extract_float_range(trans_string) - return {BackId.trans: range_entry(start=time_range[0], stop=time_range[1])} + return {BackId.TRANS: range_entry(start=time_range[0], stop=time_range[1])} def _get_monitor_number(self, line): monitor_selection = re.search(self._single_monitor, line).group(0) @@ -278,7 +278,7 @@ class InstrParser(object): raise RuntimeError("InstrParser: Unknown command for INSTR: {0}".format(line)) else: # If no exception raised - return {DetectorId.instrument: ret_val} + return {DetectorId.INSTRUMENT: ret_val} @staticmethod def get_type(): @@ -404,13 +404,13 @@ class DetParser(UserFileComponentParser): def _extract_reduction_mode(self, line): line_capital = line.upper() if line_capital in self._HAB: - return {DetectorId.reduction_mode: ISISReductionMode.HAB} + return {DetectorId.REDUCTION_MODE: ReductionMode.HAB} elif line_capital in self._LAB: - return {DetectorId.reduction_mode: ISISReductionMode.LAB} + return {DetectorId.REDUCTION_MODE: ReductionMode.LAB} elif line_capital in self._BOTH: - return {DetectorId.reduction_mode: ISISReductionMode.All} + return {DetectorId.REDUCTION_MODE: ReductionMode.ALL} elif line_capital in self._MERGE: - return {DetectorId.reduction_mode: ISISReductionMode.Merged} + return {DetectorId.REDUCTION_MODE: ReductionMode.MERGED} else: raise RuntimeError("DetParser: Could not extract line: {0}".format(line)) @@ -429,28 +429,28 @@ class DetParser(UserFileComponentParser): def _extract_detector_setting(self, qualifier, detector_type): if self._x_pattern.match(qualifier): 
value_string = re.sub(self._x, "", qualifier) - key = DetectorId.correction_x + key = DetectorId.CORRECTION_X elif self._y_pattern.match(qualifier): value_string = re.sub(self._y, "", qualifier) - key = DetectorId.correction_y + key = DetectorId.CORRECTION_Y elif self._z_pattern.match(qualifier): value_string = re.sub(self._z, "", qualifier) - key = DetectorId.correction_z + key = DetectorId.CORRECTION_Z elif self._rotation_pattern.match(qualifier): value_string = re.sub(self._rotation, "", qualifier) - key = DetectorId.correction_rotation + key = DetectorId.CORRECTION_ROTATION elif self._translation_pattern.match(qualifier): value_string = re.sub(self._translation, "", qualifier) - key = DetectorId.correction_translation + key = DetectorId.CORRECTION_TRANSLATION elif self._radius_pattern.match(qualifier): value_string = re.sub(self._radius, "", qualifier) - key = DetectorId.correction_radius + key = DetectorId.CORRECTION_RADIUS elif self._x_tilt_pattern.match(qualifier): value_string = re.sub(self._x_tilt, "", qualifier) - key = DetectorId.correction_x_tilt + key = DetectorId.CORRECTION_X_TILT elif self._y_tilt_pattern.match(qualifier): value_string = re.sub(self._y_tilt, "", qualifier) - key = DetectorId.correction_y_tilt + key = DetectorId.CORRECTION_Y_TILT else: raise RuntimeError("DetParser: Unknown qualifier encountered: {0}".format(qualifier)) @@ -463,11 +463,11 @@ class DetParser(UserFileComponentParser): if self._rescale_pattern.match(line) is not None: rescale_string = re.sub(self._rescale, "", line) rescale = convert_string_to_float(rescale_string) - return {DetectorId.rescale: rescale} + return {DetectorId.RESCALE: rescale} elif self._shift_pattern.match(line) is not None: shift_string = re.sub(self._shift, "", line) shift = convert_string_to_float(shift_string) - return {DetectorId.shift: shift} + return {DetectorId.SHIFT: shift} elif self._rescale_fit_pattern.match(line) is not None: rescale_fit_string = re.sub(self._rescale_fit, "", line) if 
rescale_fit_string: @@ -475,7 +475,7 @@ class DetParser(UserFileComponentParser): value = det_fit_range(start=rescale_fit[0], stop=rescale_fit[1], use_fit=True) else: value = det_fit_range(start=None, stop=None, use_fit=True) - return {DetectorId.rescale_fit: value} + return {DetectorId.RESCALE_FIT: value} elif self._shift_fit_pattern.match(line) is not None: shift_fit_string = re.sub(self._shift_fit, "", line) if shift_fit_string: @@ -483,7 +483,7 @@ class DetParser(UserFileComponentParser): value = det_fit_range(start=shift_fit[0], stop=shift_fit[1], use_fit=True) else: value = det_fit_range(start=None, stop=None, use_fit=True) - return {DetectorId.shift_fit: value} + return {DetectorId.SHIFT_FIT: value} elif self._merge_range_pattern.match(line) is not None: merge_range_string = re.sub(self._merge_range, "", line) if merge_range_string: @@ -491,7 +491,7 @@ class DetParser(UserFileComponentParser): value = det_fit_range(start=merge_range[0], stop=merge_range[1], use_fit=True) else: raise RuntimeError("DetParser: Could not extract line: {0}".format(line)) - return {DetectorId.merge_range: value} + return {DetectorId.MERGE_RANGE: value} else: raise RuntimeError("DetParser: Could not extract line: {0}".format(line)) @@ -667,7 +667,7 @@ class LimitParser(UserFileComponentParser): use_mirror = re.search(self._phi_no_mirror, line) is None angles_string = re.sub(self._phi, "", line) angles = extract_float_range(angles_string) - return {LimitsId.angle: mask_angle_entry(min=angles[0], max=angles[1], use_mirror=use_mirror)} + return {LimitsId.ANGLE: mask_angle_entry(min=angles[0], max=angles[1], use_mirror=use_mirror)} def _extract_event_binning(self, line): event_binning = re.sub(self._events_time, "", line) @@ -675,44 +675,44 @@ class LimitParser(UserFileComponentParser): rebin_values = extract_float_list(event_binning, separator=" ") binning_string = ",".join([str(val) for val in rebin_values]) else: - simple_pattern = self._extract_simple_pattern(event_binning, 
LimitsId.events_binning) - rebin_values = simple_pattern[LimitsId.events_binning] - prefix = -1. if rebin_values.step_type is RangeStepType.Log else 1. + simple_pattern = self._extract_simple_pattern(event_binning, LimitsId.EVENTS_BINNING) + rebin_values = simple_pattern[LimitsId.EVENTS_BINNING] + prefix = -1. if rebin_values.step_type is RangeStepType.LOG else 1. binning_string = str(rebin_values.start) + "," + str(prefix*rebin_values.step) + "," + \ str(rebin_values.stop) # noqa - output = {LimitsId.events_binning: binning_string} + output = {LimitsId.EVENTS_BINNING: binning_string} return output def _extract_cut_limit(self, line): if self._radius_cut_pattern.match(line) is not None: - key = LimitsId.radius_cut + key = LimitsId.RADIUS_CUT limit_value = re.sub(self._radius_cut, "", line) else: - key = LimitsId.wavelength_cut + key = LimitsId.WAVELENGTH_CUT limit_value = re.sub(self._wavelength_cut, "", line) return {key: convert_string_to_float(limit_value)} def _extract_radius_limit(self, line): radius_range_string = re.sub(self._radius, "", line) radius_range = extract_float_list(radius_range_string, separator=" ") - return {LimitsId.radius: range_entry(start=radius_range[0], stop=radius_range[1])} + return {LimitsId.RADIUS: range_entry(start=radius_range[0], stop=radius_range[1])} def _extract_q_limit(self, line): q_range = re.sub(self._q, "", line) if does_pattern_match(self._q_simple_pattern, line): - simple_output = self._extract_simple_pattern(q_range, LimitsId.q) - simple_output = simple_output[LimitsId.q] - prefix = -1.0 if simple_output.step_type is RangeStepType.Log else 1.0 + simple_output = self._extract_simple_pattern(q_range, LimitsId.Q) + simple_output = simple_output[LimitsId.Q] + prefix = -1.0 if simple_output.step_type is RangeStepType.LOG else 1.0 q_limit_output = [simple_output.start] if simple_output.step: q_limit_output.append(prefix*simple_output.step) q_limit_output.append(simple_output.stop) elif 
does_pattern_match(self._q_complex_pattern, line): - complex_output = self._extract_complex_pattern(q_range, LimitsId.q) - complex_output = complex_output[LimitsId.q] - prefix1 = -1.0 if complex_output.step_type1 is RangeStepType.Log else 1.0 - prefix2 = -1.0 if complex_output.step_type2 is RangeStepType.Log else 1.0 + complex_output = self._extract_complex_pattern(q_range, LimitsId.Q) + complex_output = complex_output[LimitsId.Q] + prefix1 = -1.0 if complex_output.step_type1 is RangeStepType.LOG else 1.0 + prefix2 = -1.0 if complex_output.step_type2 is RangeStepType.LOG else 1.0 q_limit_output = [complex_output.start, prefix1*complex_output.step1, complex_output.mid, prefix2*complex_output.step2, complex_output.stop] else: @@ -721,13 +721,13 @@ class LimitParser(UserFileComponentParser): # The output is a q_rebin_values object with q_min, q_max and the rebin string. rebinning_string = ",".join([str(element) for element in q_limit_output]) q_rebin = q_rebin_values(min=q_limit_output[0], max=q_limit_output[-1], rebin_string=rebinning_string) - output = {LimitsId.q: q_rebin} + output = {LimitsId.Q: q_rebin} return output def _extract_qxy_limit(self, line): qxy_range = re.sub(self._qxy, "", line) if does_pattern_match(self._qxy_simple_pattern, line): - output = self._extract_simple_pattern(qxy_range, LimitsId.qxy) + output = self._extract_simple_pattern(qxy_range, LimitsId.QXY) else: # v2 GUI cannot currently support complex QXY ranges #output = self._extract_complex_pattern(qxy_range, LimitsId.qxy) @@ -738,7 +738,7 @@ class LimitParser(UserFileComponentParser): def _extract_wavelength_limit(self, line): wavelength_range = re.sub(self._wavelength, "", line) if does_pattern_match(self._wavelength_simple_pattern, line): - output = self._extract_simple_pattern(wavelength_range, LimitsId.wavelength) + output = self._extract_simple_pattern(wavelength_range, LimitsId.WAVELENGTH) else: # This is not implemented in the old parser, hence disable here # output = 
self._extract_complex_pattern(wavelength_range, LimitsId.wavelength) @@ -784,8 +784,8 @@ class LimitParser(UserFileComponentParser): # Check if there is a sign on the individual steps, this shows if something had been marked as linear or log. # If there is an explicit LOG/LIN command, then this overwrites the sign - step_type1 = RangeStepType.Log if copysign(1, range_with_steps[1]) == -1 else RangeStepType.Lin - step_type2 = RangeStepType.Log if copysign(1, range_with_steps[3]) == -1 else RangeStepType.Lin + step_type1 = RangeStepType.LOG if copysign(1, range_with_steps[1]) == -1 else RangeStepType.LIN + step_type2 = RangeStepType.LOG if copysign(1, range_with_steps[3]) == -1 else RangeStepType.LIN if step_type is not None: step_type1 = step_type step_type2 = step_type @@ -807,17 +807,17 @@ class LimitParser(UserFileComponentParser): range_with_steps = extract_float_list(range_with_steps_string, " ") if step_type is not None: - prefix = -1.0 if step_type is RangeStepType.Log else 1.0 + prefix = -1.0 if step_type is RangeStepType.LOG else 1.0 for index in range(1, len(range_with_steps), 2): range_with_steps[index] *= prefix return range_with_steps - def _get_step_type(self, range_string, default=RangeStepType.Lin): + def _get_step_type(self, range_string, default=RangeStepType.LIN): range_type = default if re.search(self._log, range_string): - range_type = RangeStepType.Log + range_type = RangeStepType.LOG elif re.search(self._lin, range_string): - range_type = RangeStepType.Lin + range_type = RangeStepType.LIN return range_type @staticmethod @@ -990,15 +990,15 @@ class MaskParser(UserFileComponentParser): if self._is_vertical_range_strip_mask(block): prelim_range = self._extract_vertical_range_strip_mask(block) # Note we use the lab key word since the extraction defaults to lab - vertical_part = prelim_range[MaskId.vertical_range_strip_mask] + vertical_part = prelim_range[MaskId.VERTICAL_RANGE_STRIP_MASK] elif self._is_horizontal_range_strip_mask(block): 
prelim_range = self._extract_horizontal_range_strip_mask(block) # Note we use the lab key word since the extraction defaults to lab - horizontal_part = prelim_range[MaskId.horizontal_range_strip_mask] + horizontal_part = prelim_range[MaskId.HORIZONTAL_RANGE_STRIP_MASK] else: raise RuntimeError("MaskParser: Cannot handle part of block mask: {0}".format(block)) # Now that we have both parts we can assemble the output - output = {MaskId.block: mask_block(horizontal1=horizontal_part.start, horizontal2=horizontal_part.stop, + output = {MaskId.BLOCK: mask_block(horizontal1=horizontal_part.start, horizontal2=horizontal_part.stop, vertical1=vertical_part.start, vertical2=vertical_part.stop, detector_type=detector_type)} else: @@ -1006,14 +1006,14 @@ class MaskParser(UserFileComponentParser): if self._is_vertical_single_strip_mask(block): prelim_single = self._extract_vertical_single_strip_mask(block) # Note we use the lab key word since the extraction defaults to lab - vertical_part = prelim_single[MaskId.vertical_single_strip_mask] + vertical_part = prelim_single[MaskId.VERTICAL_SINGLE_STRIP_MASK] elif self._is_horizontal_single_strip_mask(block): prelim_single = self._extract_horizontal_single_strip_mask(block) # Note we use the lab key word since the extraction defaults to lab - horizontal_part = prelim_single[MaskId.horizontal_single_strip_mask] + horizontal_part = prelim_single[MaskId.HORIZONTAL_SINGLE_STRIP_MASK] else: raise RuntimeError("MaskParser: Cannot handle part of block cross mask: {0}".format(block)) - output = {MaskId.block_cross: mask_block_cross(horizontal=horizontal_part.entry, + output = {MaskId.BLOCK_CROSS: mask_block_cross(horizontal=horizontal_part.entry, vertical=vertical_part.entry, detector_type=detector_type)} return output @@ -1023,10 +1023,10 @@ class MaskParser(UserFileComponentParser): line_values = extract_float_list(line_string, " ") length_values = len(line_values) if length_values == 2: - output = {MaskId.line: 
mask_line(width=line_values[0], angle=line_values[1], + output = {MaskId.LINE: mask_line(width=line_values[0], angle=line_values[1], x=None, y=None)} elif length_values == 4: - output = {MaskId.line: mask_line(width=line_values[0], angle=line_values[1], + output = {MaskId.LINE: mask_line(width=line_values[0], angle=line_values[1], x=line_values[2], y=line_values[3])} else: raise ValueError("MaskParser: Line mask accepts wither 2 or 4 parameters," @@ -1038,12 +1038,12 @@ class MaskParser(UserFileComponentParser): has_hab = re.search(self._hab, line) has_lab = re.search(self._lab, line) if has_hab is not None or has_lab is not None: - key = MaskId.time_detector + key = MaskId.TIME_DETECTOR detector_type = DetectorType.HAB if has_hab is not None else DetectorType.LAB regex_string = "\s*(" + self._hab + ")\s*" if has_hab else "\s*(" + self._lab + ")\s*" min_and_max_time_range = re.sub(regex_string, "", line) else: - key = MaskId.time + key = MaskId.TIME detector_type = None min_and_max_time_range = line min_and_max_time_range = re.sub("\s*/\s*", "", min_and_max_time_range) @@ -1054,26 +1054,26 @@ class MaskParser(UserFileComponentParser): def _extract_clear_mask(self, line): clear_removed = re.sub(self._clear, "", line) - return {MaskId.clear_detector_mask: True} if clear_removed == "" else \ - {MaskId.clear_time_mask: True} + return {MaskId.CLEAR_DETECTOR_MASK: True} if clear_removed == "" else \ + {MaskId.CLEAR_TIME_MASK: True} def _extract_single_spectrum_mask(self, line): single_spectrum_string = re.sub(self._spectrum, "", line) single_spectrum = convert_string_to_integer(single_spectrum_string) - return {MaskId.single_spectrum_mask: single_spectrum} + return {MaskId.SINGLE_SPECTRUM_MASK: single_spectrum} def _extract_spectrum_range_mask(self, line): spectrum_range_string = re.sub(self._spectrum, "", line) spectrum_range_string = re.sub(self._range, " ", spectrum_range_string) spectrum_range = extract_int_range(spectrum_range_string) - return 
{MaskId.spectrum_range_mask: range_entry(start=spectrum_range[0], stop=spectrum_range[1])} + return {MaskId.SPECTRUM_RANGE_MASK: range_entry(start=spectrum_range[0], stop=spectrum_range[1])} def _extract_vertical_single_strip_mask(self, line): detector_type = DetectorType.HAB if re.search(self._hab, line) is not None else DetectorType.LAB single_vertical_strip_string = re.sub(self._detector, "", line) single_vertical_strip_string = re.sub(self._v, "", single_vertical_strip_string) single_vertical_strip = convert_string_to_integer(single_vertical_strip_string) - return {MaskId.vertical_single_strip_mask: single_entry_with_detector(entry=single_vertical_strip, + return {MaskId.VERTICAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=single_vertical_strip, detector_type=detector_type)} def _extract_vertical_range_strip_mask(self, line): @@ -1082,7 +1082,7 @@ class MaskParser(UserFileComponentParser): range_vertical_strip_string = re.sub(self._v, "", range_vertical_strip_string) range_vertical_strip_string = re.sub(self._range, " ", range_vertical_strip_string) range_vertical_strip = extract_int_range(range_vertical_strip_string) - return {MaskId.vertical_range_strip_mask: range_entry_with_detector(start=range_vertical_strip[0], + return {MaskId.VERTICAL_RANGE_STRIP_MASK: range_entry_with_detector(start=range_vertical_strip[0], stop=range_vertical_strip[1], detector_type=detector_type)} @@ -1091,7 +1091,7 @@ class MaskParser(UserFileComponentParser): single_horizontal_strip_string = re.sub(self._detector, "", line) single_horizontal_strip_string = re.sub(self._h, "", single_horizontal_strip_string) single_horizontal_strip = convert_string_to_integer(single_horizontal_strip_string) - return {MaskId.horizontal_single_strip_mask: single_entry_with_detector(entry=single_horizontal_strip, + return {MaskId.HORIZONTAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=single_horizontal_strip, detector_type=detector_type)} def _extract_horizontal_range_strip_mask(self, 
line): @@ -1100,7 +1100,7 @@ class MaskParser(UserFileComponentParser): range_horizontal_strip_string = re.sub(self._h, "", range_horizontal_strip_string) range_horizontal_strip_string = re.sub(self._range, " ", range_horizontal_strip_string) range_horizontal_strip = extract_int_range(range_horizontal_strip_string) - return {MaskId.horizontal_range_strip_mask: range_entry_with_detector(start=range_horizontal_strip[0], + return {MaskId.HORIZONTAL_RANGE_STRIP_MASK: range_entry_with_detector(start=range_horizontal_strip[0], stop=range_horizontal_strip[1], detector_type=detector_type)} @@ -1157,12 +1157,12 @@ class SampleParser(UserFileComponentParser): def _extract_sample_path(self, line): value = False if re.search(self._off, line) is not None else True - return {SampleId.path: value} + return {SampleId.PATH: value} def _extract_sample_offset(self, line): offset_string = re.sub(self._offset, "", line) offset = convert_string_to_float(offset_string) - return {SampleId.offset: offset} + return {SampleId.OFFSET: offset} @staticmethod def get_type(): @@ -1240,7 +1240,7 @@ class SetParser(UserFileComponentParser): scales = extract_float_list(scales_string, separator=" ") if len(scales) != 5: raise ValueError("SetParser: Expected 5 entries for the SCALES setting, but got {0}.".format(len(scales))) - return {SetId.scales: set_scales_entry(s=scales[0], a=scales[1], b=scales[2], c=scales[3], d=scales[4])} + return {SetId.SCALES: set_scales_entry(s=scales[0], a=scales[1], b=scales[2], c=scales[3], d=scales[4])} def _extract_centre(self, line): detector_type = DetectorType.HAB if re.search(self._hab, line) is not None else DetectorType.LAB @@ -1248,7 +1248,7 @@ class SetParser(UserFileComponentParser): centre_string = re.sub("/" + self._lab, "", centre_string) centre_string = ' '.join(centre_string.split()) centre = extract_float_list(centre_string, separator=" ") - return {SetId.centre: position_entry(pos1=centre[0], pos2=centre[1], detector_type=detector_type)} + return 
{SetId.CENTRE: position_entry(pos1=centre[0], pos2=centre[1], detector_type=detector_type)} def _extract_centre_HAB(self, line): detector_type = DetectorType.HAB if re.search(self._hab, line) is not None else DetectorType.LAB @@ -1256,7 +1256,7 @@ class SetParser(UserFileComponentParser): centre_string = re.sub("/" + self._hab, "", centre_string) centre_string = ' '.join(centre_string.split()) centre = extract_float_list(centre_string, separator=" ") - return {SetId.centre_HAB: position_entry(pos1=centre[0], pos2=centre[1], detector_type=detector_type)} + return {SetId.CENTRE_HAB: position_entry(pos1=centre[0], pos2=centre[1], detector_type=detector_type)} @staticmethod def get_type(): @@ -1392,20 +1392,20 @@ class TransParser(UserFileComponentParser): dist, monitor = int(split_vars[0]), int(split_vars[1]) if monitor == 5: - return {TransId.spec_5_shift: dist} + return {TransId.SPEC_5_SHIFT: dist} elif monitor >= 0: # Some instruments (i.e. LOQ) do not have monitor 4 on spectrum 4, as ZOOM # is currently the only one with monitor 5 at spectrum 5 we can make it an edge case # If a future instrument wants to use monitor 5 at a different spectrum number or # uses monitor 4 at spectrum 5 this should be updated - return {TransId.spec_4_shift: dist} + return {TransId.SPEC_4_SHIFT: dist} else: raise RuntimeError("Monitor {0} cannot be shifted".format(monitor)) def _extract_trans_spec(self, line): trans_spec_string = re.sub(self._trans_spec, "", line) trans_spec = convert_string_to_integer(trans_spec_string) - return {TransId.spec: trans_spec} + return {TransId.SPEC: trans_spec} def _extract_trans_spec_shift(self, line): # Get the transpec @@ -1423,36 +1423,36 @@ class TransParser(UserFileComponentParser): trans_spec_shift = convert_string_to_float(trans_spec_shift_string) if trans_spec == 5: - return {TransId.spec_5_shift: trans_spec_shift, TransId.spec: trans_spec} + return {TransId.SPEC_5_SHIFT: trans_spec_shift, TransId.SPEC: trans_spec} elif trans_spec >= 0: # Some 
instruments (i.e. LOQ) do not have monitor 4 on spectrum 4, as ZOOM # is currently the only one with monitor 5 at spectrum 5 we can make it an edge case # If a future instrument wants to use monitor 5 at a different spectrum number or # uses monitor 4 at spectrum 5 this should be updated - return {TransId.spec_4_shift: trans_spec_shift, TransId.spec: trans_spec} + return {TransId.SPEC_4_SHIFT: trans_spec_shift, TransId.SPEC: trans_spec} else: raise RuntimeError("Monitor {0} cannot be shifted".format(trans_spec)) def _extract_radius(self, line): radius_string = re.sub(self._radius, "", line) radius = convert_string_to_float(radius_string) - return {TransId.radius: radius} + return {TransId.RADIUS: radius} def _extract_roi(self, line, original_line): file_names = TransParser.extract_file_names(line, original_line, self._roi) - return {TransId.roi: file_names} + return {TransId.ROI: file_names} def _extract_mask(self, line, original_line): file_names = TransParser.extract_file_names(line, original_line, self._mask) - return {TransId.mask: file_names} + return {TransId.MASK: file_names} def _extract_sample_workspace(self, line, original_line): sample_workspace = TransParser.extract_workspace(line, original_line, self._sample_workspace) - return {TransId.sample_workspace: sample_workspace} + return {TransId.SAMPLE_WORKSPACE: sample_workspace} def _extract_can_workspace(self, line, original_line): can_workspace = TransParser.extract_workspace(line, original_line, self._can_workspace) - return {TransId.can_workspace: can_workspace} + return {TransId.CAN_WORKSPACE: can_workspace} @staticmethod def extract_workspace(line, original_line, to_remove): @@ -1507,7 +1507,7 @@ class TubeCalibFileParser(UserFileComponentParser): def _extract_tube_calib_file(line, original_line): file_name_capital = line.strip() file_name = re.search(file_name_capital, original_line, re.IGNORECASE).group(0) - return {TubeCalibrationFileId.file: file_name} + return {TubeCalibrationFileId.FILE: 
file_name} @staticmethod def get_type(): @@ -1644,38 +1644,38 @@ class QResolutionParser(UserFileComponentParser): def _extract_on_off(self, line): value = False if re.search(self._off, line) is not None else True - return {QResolutionId.on: value} + return {QResolutionId.ON: value} def _extract_delta_r(self, line): - return {QResolutionId.delta_r: QResolutionParser.extract_float(line, self._delta_r)} + return {QResolutionId.DELTA_R: QResolutionParser.extract_float(line, self._delta_r)} def _extract_collimation_length(self, line): - return {QResolutionId.collimation_length: QResolutionParser.extract_float(line, self._collimation_length)} + return {QResolutionId.COLLIMATION_LENGTH: QResolutionParser.extract_float(line, self._collimation_length)} def _extract_a1(self, line): - return {QResolutionId.a1: QResolutionParser.extract_float(line, self._a1)} + return {QResolutionId.A1: QResolutionParser.extract_float(line, self._a1)} def _extract_a2(self, line): - return {QResolutionId.a2: QResolutionParser.extract_float(line, self._a2)} + return {QResolutionId.A2: QResolutionParser.extract_float(line, self._a2)} def _extract_h1(self, line): - return {QResolutionId.h1: QResolutionParser.extract_float(line, self._h1)} + return {QResolutionId.H1: QResolutionParser.extract_float(line, self._h1)} def _extract_w1(self, line): - return {QResolutionId.w1: QResolutionParser.extract_float(line, self._w1)} + return {QResolutionId.W1: QResolutionParser.extract_float(line, self._w1)} def _extract_h2(self, line): - return {QResolutionId.h2: QResolutionParser.extract_float(line, self._h2)} + return {QResolutionId.H2: QResolutionParser.extract_float(line, self._h2)} def _extract_w2(self, line): - return {QResolutionId.w2: QResolutionParser.extract_float(line, self._w2)} + return {QResolutionId.W2: QResolutionParser.extract_float(line, self._w2)} def _extract_moderator(self, line, original_line): moderator_capital = re.sub(self._moderator, "", line) moderator_capital = re.sub("\"", "", 
moderator_capital) moderator = re.search(moderator_capital, original_line, re.IGNORECASE).group(0) # Remove quotation marks - return {QResolutionId.moderator: moderator} + return {QResolutionId.MODERATOR: moderator} @staticmethod def extract_float(line, to_remove): @@ -1766,14 +1766,14 @@ class FitParser(UserFileComponentParser): def _extract_monitor(self, line): values_string = re.sub(self._monitor, "", line) values = extract_float_range(values_string) - return {FitId.monitor_times: range_entry(start=values[0], stop=values[1])} + return {FitId.MONITOR_TIMES: range_entry(start=values[0], stop=values[1])} def _extract_general_fit(self, line): fit_type = self._get_fit_type(line) ws_type = self._get_workspace_type(line) wavelength_min, wavelength_max = self._get_wavelength(line) polynomial_order = self._get_polynomial_order(fit_type, line) - return {FitId.general: fit_general(start=wavelength_min, stop=wavelength_max, fit_type=fit_type, + return {FitId.GENERAL: fit_general(start=wavelength_min, stop=wavelength_max, fit_type=fit_type, data_type=ws_type, polynomial_order=polynomial_order)} def _get_wavelength(self, line): @@ -1781,7 +1781,7 @@ class FitParser(UserFileComponentParser): return wavelength_min, wavelength_max def _get_polynomial_order(self, fit_type, line): - if fit_type != FitType.Polynomial: + if fit_type != FitType.POLYNOMIAL: poly_order = 0 else: poly_order, _, _ = self._get_wavelength_and_polynomial(line) @@ -1824,20 +1824,20 @@ class FitParser(UserFileComponentParser): def _get_fit_type(self, line): if re.search(self._log, line) is not None: - fit_type = FitType.Logarithmic + fit_type = FitType.LOGARITHMIC elif re.search(self._lin, line) is not None: - fit_type = FitType.Linear + fit_type = FitType.LINEAR elif re.search(self._polynomial, line) is not None: - fit_type = FitType.Polynomial + fit_type = FitType.POLYNOMIAL else: raise RuntimeError("FitParser: Encountered unknown fit function: {0}".format(line)) return fit_type def 
_get_workspace_type(self, line): if re.search(self._sample, line) is not None: - ws_type = DataType.Sample + ws_type = DataType.SAMPLE elif re.search(self._can, line) is not None: - ws_type = DataType.Can + ws_type = DataType.CAN else: ws_type = None return ws_type @@ -1847,7 +1847,7 @@ class FitParser(UserFileComponentParser): """ With this we want to clear the fit type settings. """ - return {FitId.general: fit_general(start=None, stop=None, fit_type=FitType.NoFit, + return {FitId.GENERAL: fit_general(start=None, stop=None, fit_type=FitType.NO_FIT, data_type=None, polynomial_order=None)} @staticmethod @@ -1902,12 +1902,12 @@ class GravityParser(UserFileComponentParser): def _extract_on_off(self, line): value = re.sub(self._on, "", line).strip() == "" - return {GravityId.on_off: value} + return {GravityId.ON_OFF: value} def _extract_extra_length(self, line): extra_length_string = re.sub(self._extra_length, "", line) extra_length = convert_string_to_float(extra_length_string) - return {GravityId.extra_length: extra_length} + return {GravityId.EXTRA_LENGTH: extra_length} @staticmethod def get_type(): @@ -1951,7 +1951,7 @@ class CompatibilityParser(UserFileComponentParser): def _extract_on_off(self, line): value = re.sub(self._on, "", line).strip() == "" - return {OtherId.use_compatibility_mode: value} + return {OtherId.USE_COMPATIBILITY_MODE: value} @staticmethod def get_type(): @@ -1996,7 +1996,7 @@ class MaskFileParser(UserFileComponentParser): def extract_mask_file(line, original_line): elements_capital = extract_string_list(line) elements = [re.search(element, original_line, re.IGNORECASE).group(0) for element in elements_capital] - return {MaskId.file: elements} + return {MaskId.FILE: elements} @staticmethod def get_type(): @@ -2106,7 +2106,7 @@ class MonParser(UserFileComponentParser): if len(length_entries) != 2: raise RuntimeError("MonParser: Length setting needs 2 numeric parameters, " "but received {0}.".format(len(length_entries))) - return {MonId.length: 
monitor_length(length=length_entries[0], spectrum=length_entries[1], + return {MonId.LENGTH: monitor_length(length=length_entries[0], spectrum=length_entries[1], interpolate=interpolate)} def _extract_direct(self, line, original_line): @@ -2125,7 +2125,7 @@ class MonParser(UserFileComponentParser): output.append(monitor_file(file_path=file_path, detector_type=DetectorType.HAB)) if is_lab: output.append(monitor_file(file_path=file_path, detector_type=DetectorType.LAB)) - return {MonId.direct: output} + return {MonId.DIRECT: output} def _extract_flat(self, line, original_line): # If we have a HAB specified then select HAB @@ -2133,12 +2133,12 @@ class MonParser(UserFileComponentParser): # If nothing is specified then select LAB detector_type = DetectorType.HAB if re.search(self._hab, line, re.IGNORECASE) else DetectorType.LAB file_path = self._extract_file_path(line, original_line, self._flat) - return {MonId.flat: monitor_file(file_path=file_path, detector_type=detector_type)} + return {MonId.FLAT: monitor_file(file_path=file_path, detector_type=detector_type)} def _extract_hab(self, line, original_line): # This is the same as direct/front file_path = self._extract_file_path(line, original_line, self._hab_file) - return {MonId.direct: [monitor_file(file_path=file_path, detector_type=DetectorType.HAB)]} + return {MonId.DIRECT: [monitor_file(file_path=file_path, detector_type=DetectorType.HAB)]} def _extract_file_path(self, line, original_line, to_remove): direct = re.sub(self._detector, "", line) @@ -2177,7 +2177,7 @@ class MonParser(UserFileComponentParser): line = re.sub(self._spectrum, "", line) line = re.sub(self._equal, "", line) spectrum = convert_string_to_integer(line) - return {MonId.spectrum: monitor_spectrum(spectrum=spectrum, is_trans=is_trans, interpolate=interpolate)} + return {MonId.SPECTRUM: monitor_spectrum(spectrum=spectrum, is_trans=is_trans, interpolate=interpolate)} @staticmethod def get_type(): @@ -2210,7 +2210,7 @@ class 
PrintParser(UserFileComponentParser): else: raise RuntimeError("PrintParser: Failed to extract line {} it does not start with {}".format(line, PrintParser.Type)) - return {PrintId.print_line: setting} + return {PrintId.PRINT_LINE: setting} @staticmethod def get_type(): diff --git a/scripts/test/SANS/algorithm_detail/batch_execution_test.py b/scripts/test/SANS/algorithm_detail/batch_execution_test.py index 502a3c24c5b6395f6c49b84597de4a96a0a3fe09..ec40b9a965ac7e2e761c403db9a290f1cabcca97 100644 --- a/scripts/test/SANS/algorithm_detail/batch_execution_test.py +++ b/scripts/test/SANS/algorithm_detail/batch_execution_test.py @@ -8,8 +8,8 @@ from __future__ import (absolute_import, division, print_function) import unittest -from mantid.simpleapi import CreateSampleWorkspace from mantid.py3compat import mock +from mantid.simpleapi import CreateSampleWorkspace from sans.algorithm_detail.batch_execution import (get_all_names_to_save, get_transmission_names_to_save, ReductionPackage, select_reduction_alg, save_workspace_to_file) from sans.common.enums import SaveType @@ -268,7 +268,7 @@ class GetAllNamesToSaveTest(unittest.TestCase): ws_name = "wsName" filename = "fileName" additional_run_numbers = {} - file_types = [SaveType.Nexus, SaveType.CanSAS, SaveType.NXcanSAS, SaveType.NistQxy, SaveType.RKH, SaveType.CSV] + file_types = [SaveType.NEXUS, SaveType.CAN_SAS, SaveType.NX_CAN_SAS, SaveType.NIST_QXY, SaveType.RKH, SaveType.CSV] save_workspace_to_file(ws_name, file_types, filename, additional_run_numbers) diff --git a/scripts/test/SANS/algorithm_detail/calculate_sans_transmission_test.py b/scripts/test/SANS/algorithm_detail/calculate_sans_transmission_test.py index aabb91e4879e84c5d4ad3a6f7bf0f91fcdf66566..ca0d53e1fbeed41015f134f4b756edfc6d66dbd2 100644 --- a/scripts/test/SANS/algorithm_detail/calculate_sans_transmission_test.py +++ b/scripts/test/SANS/algorithm_detail/calculate_sans_transmission_test.py @@ -149,22 +149,22 @@ class 
CalculateSansTransmissionTest(unittest.TestCase): calculate_transmission_builder.set_background_TOF_roi_stop(background_TOF_roi_stop) if sample_fit_type: - calculate_transmission_builder.set_Sample_fit_type(sample_fit_type) + calculate_transmission_builder.set_sample_fit_type(sample_fit_type) if sample_polynomial_order: - calculate_transmission_builder.set_Sample_polynomial_order(sample_polynomial_order) + calculate_transmission_builder.set_sample_polynomial_order(sample_polynomial_order) if sample_wavelength_low: - calculate_transmission_builder.set_Sample_wavelength_low(sample_wavelength_low) + calculate_transmission_builder.set_sample_wavelength_low(sample_wavelength_low) if sample_wavelength_high: - calculate_transmission_builder.set_Sample_wavelength_high(sample_wavelength_high) + calculate_transmission_builder.set_sample_wavelength_high(sample_wavelength_high) if can_fit_type: - calculate_transmission_builder.set_Can_fit_type(can_fit_type) + calculate_transmission_builder.set_can_fit_type(can_fit_type) if can_polynomial_order: - calculate_transmission_builder.set_Can_polynomial_order(can_polynomial_order) + calculate_transmission_builder.set_can_polynomial_order(can_polynomial_order) if can_wavelength_low: - calculate_transmission_builder.set_Can_wavelength_low(can_wavelength_low) + calculate_transmission_builder.set_can_wavelength_low(can_wavelength_low) if can_wavelength_high: - calculate_transmission_builder.set_Can_wavelength_high(can_wavelength_high) + calculate_transmission_builder.set_can_wavelength_high(can_wavelength_high) calculate_transmission = calculate_transmission_builder.build() state.adjustment.calculate_transmission = calculate_transmission return state @@ -234,12 +234,12 @@ class CalculateSansTransmissionTest(unittest.TestCase): def test_that_calculates_transmission_for_general_background_and_no_prompt_peak(self): # Arrange - state = CalculateSansTransmissionTest._get_state(rebin_type=RebinType.Rebin, wavelength_low=2., + state = 
CalculateSansTransmissionTest._get_state(rebin_type=RebinType.REBIN, wavelength_low=2., wavelength_high=8., wavelength_step=2., - wavelength_step_type=RangeStepType.Lin, + wavelength_step_type=RangeStepType.LIN, background_TOF_general_start=5000., background_TOF_general_stop=10000., incident_monitor=1, - transmission_monitor=3, sample_fit_type=FitType.Linear, + transmission_monitor=3, sample_fit_type=FitType.LINEAR, sample_polynomial_order=0, sample_wavelength_low=2., sample_wavelength_high=8.) # Get a test monitor workspace with 4 bins where the first bin is the back ground @@ -270,16 +270,16 @@ class CalculateSansTransmissionTest(unittest.TestCase): fix_for_remove_bins = 1e-6 background_TOF_monitor_start = {str(incident_spectrum): 5000., str(transmission_spectrum): 5000.} background_TOF_monitor_stop = {str(incident_spectrum): 10000., str(transmission_spectrum): 10000.} - state = CalculateSansTransmissionTest._get_state(rebin_type=RebinType.Rebin, wavelength_low=2., + state = CalculateSansTransmissionTest._get_state(rebin_type=RebinType.REBIN, wavelength_low=2., wavelength_high=8., wavelength_step=2., - wavelength_step_type=RangeStepType.Lin, + wavelength_step_type=RangeStepType.LIN, prompt_peak_correction_min=15000. + fix_for_remove_bins, prompt_peak_correction_max=20000., background_TOF_monitor_start=background_TOF_monitor_start, background_TOF_monitor_stop=background_TOF_monitor_stop, incident_monitor=incident_spectrum, transmission_monitor=transmission_spectrum, - can_fit_type=FitType.Linear, + can_fit_type=FitType.LINEAR, can_polynomial_order=0, can_wavelength_low=2., can_wavelength_high=8.) # Get a test monitor workspace with 4 bins where the first bin is the back ground @@ -312,13 +312,13 @@ class CalculateSansTransmissionTest(unittest.TestCase): # This test picks the monitor detector ids based on a radius around the centre of the detector. 
This is much # more tricky to test here and in principle the main tests should be happening in the actual # CalculateTransmission algorithm. - state = CalculateSansTransmissionTest._get_state(rebin_type=RebinType.Rebin, wavelength_low=2., + state = CalculateSansTransmissionTest._get_state(rebin_type=RebinType.REBIN, wavelength_low=2., wavelength_high=8., wavelength_step=2., - wavelength_step_type=RangeStepType.Lin, + wavelength_step_type=RangeStepType.LIN, background_TOF_general_start=5000., background_TOF_general_stop=10000., incident_monitor=1, transmission_radius_on_detector=0.01, - sample_fit_type=FitType.Linear, + sample_fit_type=FitType.LINEAR, sample_polynomial_order=0, sample_wavelength_low=2., sample_wavelength_high=8.) # Gets the full workspace diff --git a/scripts/test/SANS/algorithm_detail/calculate_transmission_helper_test.py b/scripts/test/SANS/algorithm_detail/calculate_transmission_helper_test.py index 585e49389b4679d311e451f653ff008f66b88ec1..cbee6e15527faab9c07fb9cb96c5984d0ff0f57a 100644 --- a/scripts/test/SANS/algorithm_detail/calculate_transmission_helper_test.py +++ b/scripts/test/SANS/algorithm_detail/calculate_transmission_helper_test.py @@ -5,12 +5,13 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) -import unittest -import mantid + import os +import unittest + +from mantid.api import AnalysisDataService from mantid.kernel import config from mantid.simpleapi import (CreateSampleWorkspace, MaskDetectors, DeleteWorkspace, LoadNexusProcessed, Load, Rebin) -from mantid.api import AnalysisDataService from sans.algorithm_detail.calculate_transmission_helper import (get_masked_det_ids, get_idf_path_from_workspace, get_workspace_indices_for_monitors, diff --git a/scripts/test/SANS/algorithm_detail/centre_finder_new_test.py b/scripts/test/SANS/algorithm_detail/centre_finder_new_test.py index 
9c08696fa7c9dc2d38e7f845e97ec0dc1a9f3341..796a9db5d1c48b6e2d996c0905a4c5f76c7b2a62 100644 --- a/scripts/test/SANS/algorithm_detail/centre_finder_new_test.py +++ b/scripts/test/SANS/algorithm_detail/centre_finder_new_test.py @@ -25,17 +25,17 @@ class CentreFinderNewTest(unittest.TestCase): position_1_start = 300 position_2_start = -300 tolerance = 0.001 - find_direction = FindDirectionEnum.All + find_direction = FindDirectionEnum.ALL iterations = 10 verbose = False - load_data_mock.return_value = {SANSDataType.SampleScatter: [mock.MagicMock()]}, { - SANSDataType.SampleScatter: [mock.MagicMock()]} + load_data_mock.return_value = {SANSDataType.SAMPLE_SCATTER: [mock.MagicMock()]}, { + SANSDataType.SAMPLE_SCATTER: [mock.MagicMock()]} beam_centre_finder = "SANSBeamCentreFinder" beam_centre_finder_options = {"Component":'LAB', "Iterations": iterations, "RMin": r_min / 1000, "RMax": r_max / 1000, "Position1Start": position_1_start, "Position2Start": position_2_start, - "Tolerance": tolerance, "Direction": FindDirectionEnum.to_string(find_direction), + "Tolerance": tolerance, "Direction": find_direction.value, "Verbose": verbose} centre_finder_new(self.state, r_min=r_min, r_max=r_max, iterations=iterations, position_1_start=position_1_start @@ -53,8 +53,8 @@ class CentreFinderNewTest(unittest.TestCase): tolerance = 0.001 iterations = 10 - load_data_mock.return_value = {SANSDataType.SampleScatter: [mock.MagicMock()]}, { - SANSDataType.SampleScatter: [mock.MagicMock()]} + load_data_mock.return_value = {SANSDataType.SAMPLE_SCATTER: [mock.MagicMock()]}, { + SANSDataType.SAMPLE_SCATTER: [mock.MagicMock()]} beam_centre_finder = "SANSBeamCentreFinderMassMethod" beam_centre_finder_options = {"RMin": r_min / 1000, diff --git a/scripts/test/SANS/algorithm_detail/convert_to_q_test.py b/scripts/test/SANS/algorithm_detail/convert_to_q_test.py index 6cb113767f39d9931d222173e783d7fd7fc60120..9a02ac4a89f43795beac8025af0d03534cb99b4c 100644 --- 
a/scripts/test/SANS/algorithm_detail/convert_to_q_test.py +++ b/scripts/test/SANS/algorithm_detail/convert_to_q_test.py @@ -38,9 +38,9 @@ class ConvertToQTest(unittest.TestCase): return workspace @staticmethod - def _get_sample_state(q_min=1., q_max=2., q_step=0.1, q_step_type=RangeStepType.Lin, + def _get_sample_state(q_min=1., q_max=2., q_step=0.1, q_step_type=RangeStepType.LIN, q_xy_max=None, q_xy_step=None, q_xy_step_type=None, - use_gravity=False, dim=ReductionDimensionality.OneDim): + use_gravity=False, dim=ReductionDimensionality.ONE_DIM): facility = SANSFacility.ISIS file_information = SANSFileInformationMock(instrument=SANSInstrument.LOQ, run_number=74044) data_builder = get_data_builder(facility, file_information) @@ -54,7 +54,7 @@ class ConvertToQTest(unittest.TestCase): convert_to_q_builder.set_wavelength_cutoff(2.) convert_to_q_builder.set_q_min(q_min) convert_to_q_builder.set_q_max(q_max) - prefix = 1. if q_step_type is RangeStepType.Lin else -1. + prefix = 1. if q_step_type is RangeStepType.LIN else -1. 
q_step *= prefix rebin_string = str(q_min) + "," + str(q_step) + "," + str(q_max) convert_to_q_builder.set_q_1d_rebin_string(rebin_string) @@ -77,7 +77,7 @@ class ConvertToQTest(unittest.TestCase): workspace = self._get_workspace(is_adjustment=False) adj_workspace = self._get_workspace(is_adjustment=True) - state = self._get_sample_state(q_min=1., q_max=2., q_step=0.1, q_step_type=RangeStepType.Lin) + state = self._get_sample_state(q_min=1., q_max=2., q_step=0.1, q_step_type=RangeStepType.LIN) # Act output_dict = convert_workspace(workspace=workspace, output_summed_parts=True, @@ -108,8 +108,8 @@ class ConvertToQTest(unittest.TestCase): workspace = self._get_workspace(is_adjustment=False) adj_workspace = self._get_workspace(is_adjustment=True) - state = self._get_sample_state(q_xy_max=2., q_xy_step=0.5, q_xy_step_type=RangeStepType.Lin, - dim=ReductionDimensionality.TwoDim) + state = self._get_sample_state(q_xy_max=2., q_xy_step=0.5, q_xy_step_type=RangeStepType.LIN, + dim=ReductionDimensionality.TWO_DIM) output_dict = convert_workspace(workspace=workspace, output_summed_parts=True, state_convert_to_q=state.convert_to_q, wavelength_adj_workspace=adj_workspace) diff --git a/scripts/test/SANS/algorithm_detail/create_sans_adjustment_workspaces_test.py b/scripts/test/SANS/algorithm_detail/create_sans_adjustment_workspaces_test.py index 935a608724c51dda699f452a5cbf2618f06e8d39..2530fc5dd939b8590c921e486f096880b5baff35 100644 --- a/scripts/test/SANS/algorithm_detail/create_sans_adjustment_workspaces_test.py +++ b/scripts/test/SANS/algorithm_detail/create_sans_adjustment_workspaces_test.py @@ -80,8 +80,8 @@ class CreateSANSAdjustmentWorkspacesTest(unittest.TestCase): @staticmethod def _run_test(state, sample_data, sample_monitor_data, transmission_data, direct_data, is_lab=True, is_sample=True): - data_type = DataType.to_string(DataType.Sample) if is_sample else DataType.to_string(DataType.Can) - component = DetectorType.to_string(DetectorType.LAB) if is_lab else 
DetectorType.to_string(DetectorType.HAB) + data_type = DataType.SAMPLE.value if is_sample else DataType.CAN.value + component = DetectorType.LAB.value if is_lab else DetectorType.HAB.value alg = CreateSANSAdjustmentWorkspaces(state_adjustment=state.adjustment, component=component, data_type=data_type) diff --git a/scripts/test/SANS/algorithm_detail/create_sans_wavelength_pixel_adjustment_test.py b/scripts/test/SANS/algorithm_detail/create_sans_wavelength_pixel_adjustment_test.py index db2646d4426262abfd9c26ee5c987e5e862e0b7a..a7118c9f349f19be65a91b37cd9d5beb44afea58 100644 --- a/scripts/test/SANS/algorithm_detail/create_sans_wavelength_pixel_adjustment_test.py +++ b/scripts/test/SANS/algorithm_detail/create_sans_wavelength_pixel_adjustment_test.py @@ -92,7 +92,7 @@ class CreateSANSWavelengthPixelAdjustmentTest(unittest.TestCase): def _run_test(transmission_workspace, norm_workspace, state, is_lab=True): state_to_send = state.adjustment.wavelength_and_pixel_adjustment - component = DetectorType.to_string(DetectorType.LAB) if is_lab else DetectorType.to_string(DetectorType.HAB) + component = DetectorType.LAB.value if is_lab else DetectorType.HAB.value alg = CreateSANSWavelengthPixelAdjustment(state_adjustment_wavelength_and_pixel=state_to_send, component=component) @@ -111,7 +111,7 @@ class CreateSANSWavelengthPixelAdjustmentTest(unittest.TestCase): state = self._get_state(wavelength_low=1., wavelength_high=11., wavelength_step=2., - wavelength_step_type=RangeStepType.Lin) + wavelength_step_type=RangeStepType.LIN) # Act wavelength_adjustment, pixel_adjustment = self._run_test(transmission_workspace, @@ -141,7 +141,7 @@ class CreateSANSWavelengthPixelAdjustmentTest(unittest.TestCase): state = CreateSANSWavelengthPixelAdjustmentTest._get_state(hab_wavelength_file=direct_file_name, wavelength_low=1., wavelength_high=11., wavelength_step=2., - wavelength_step_type=RangeStepType.Lin) + wavelength_step_type=RangeStepType.LIN) # Act wavelength_adjustment, pixel_adjustment = 
CreateSANSWavelengthPixelAdjustmentTest._run_test( transmission_workspace, diff --git a/scripts/test/SANS/algorithm_detail/crop_helper_test.py b/scripts/test/SANS/algorithm_detail/crop_helper_test.py index f0bc06419790b3f66cd3daed33b6d612aaf538d2..d38f63917813222ce7fc1aca3e69481740b971d7 100644 --- a/scripts/test/SANS/algorithm_detail/crop_helper_test.py +++ b/scripts/test/SANS/algorithm_detail/crop_helper_test.py @@ -5,13 +5,14 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid + +from mantid.api import FileFinder from sans.algorithm_detail.crop_helper import get_component_name -from sans.common.enums import DetectorType from sans.common.constants import EMPTY_NAME +from sans.common.enums import DetectorType from sans.common.general_functions import create_unmanaged_algorithm -from mantid.api import FileFinder class CropHelperTest(unittest.TestCase): diff --git a/scripts/test/SANS/algorithm_detail/merge_reductions_test.py b/scripts/test/SANS/algorithm_detail/merge_reductions_test.py index 7517da03e0ae72bcf6696f10fc38e70cce48c8d4..53385fbed48c247122f149e3dfd061b1dee25310 100644 --- a/scripts/test/SANS/algorithm_detail/merge_reductions_test.py +++ b/scripts/test/SANS/algorithm_detail/merge_reductions_test.py @@ -5,19 +5,18 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid -from sans.algorithm_detail.merge_reductions import (MergeFactory, ISIS1DMerger) -from sans.algorithm_detail.bundles import OutputPartsBundle +from sans.algorithm_detail.bundles import OutputPartsBundle +from sans.algorithm_detail.merge_reductions import (MergeFactory, ISIS1DMerger) +from sans.common.constants import EMPTY_NAME +from sans.common.enums import (DataType, ReductionMode) +from sans.common.enums import (ReductionDimensionality, 
FitModeForMerge) +from sans.common.general_functions import create_unmanaged_algorithm from sans.state.reduction_mode import StateReductionMode from sans.test_helper.test_director import TestDirector -from sans.common.enums import (ISISReductionMode, ReductionDimensionality, FitModeForMerge) -from sans.common.general_functions import create_unmanaged_algorithm -from sans.common.constants import EMPTY_NAME -from sans.common.enums import (DataType, ISISReductionMode) - class MergeReductionsTest(unittest.TestCase): @staticmethod @@ -33,11 +32,11 @@ class MergeReductionsTest(unittest.TestCase): return create_alg.getProperty('OutputWorkspace').value @staticmethod - def _get_simple_state(fit_type=FitModeForMerge.NoFit, scale=1.0, shift=0.0): + def _get_simple_state(fit_type=FitModeForMerge.NO_FIT, scale=1.0, shift=0.0): # Set the reduction parameters reduction_info = StateReductionMode() - reduction_info.reduction_mode = ISISReductionMode.Merged - reduction_info.dimensionality = ReductionDimensionality.TwoDim + reduction_info.reduction_mode = ReductionMode.MERGED + reduction_info.dimensionality = ReductionDimensionality.TWO_DIM reduction_info.merge_shift = shift reduction_info.merge_scale = scale reduction_info.merge_fit_mode = fit_type @@ -52,12 +51,12 @@ class MergeReductionsTest(unittest.TestCase): data_x_hab, data_y_hab_count, data_y_hab_norm): lab_count = MergeReductionsTest.create_1D_workspace(data_x_lab, data_y_lab_count) lab_norm = MergeReductionsTest.create_1D_workspace(data_x_lab, data_y_lab_norm) - lab_bundle = OutputPartsBundle(state=state, data_type=data_type, reduction_mode=ISISReductionMode.LAB, + lab_bundle = OutputPartsBundle(state=state, data_type=data_type, reduction_mode=ReductionMode.LAB, output_workspace_count=lab_count, output_workspace_norm=lab_norm) hab_count = MergeReductionsTest.create_1D_workspace(data_x_hab, data_y_hab_count) hab_norm = MergeReductionsTest.create_1D_workspace(data_x_hab, data_y_hab_norm) - hab_bundle = 
OutputPartsBundle(state=state, data_type=data_type, reduction_mode=ISISReductionMode.HAB, + hab_bundle = OutputPartsBundle(state=state, data_type=data_type, reduction_mode=ReductionMode.HAB, output_workspace_count=hab_count, output_workspace_norm=hab_norm) return lab_bundle, hab_bundle @@ -71,7 +70,7 @@ class MergeReductionsTest(unittest.TestCase): data_x_hab = list(range(0, 10)) data_y_hab_count = [3.] * 10 data_y_hab_norm = [4.] * 10 - sample_lab, sample_hab = MergeReductionsTest._create_workspaces(state, DataType.Sample, data_x_lab, + sample_lab, sample_hab = MergeReductionsTest._create_workspaces(state, DataType.SAMPLE, data_x_lab, data_y_lab_count, data_y_lab_norm, data_x_hab, data_y_hab_count, data_y_hab_norm) @@ -83,7 +82,7 @@ class MergeReductionsTest(unittest.TestCase): data_x_hab = list(range(0, 10)) data_y_hab_count = [7.] * 10 data_y_hab_norm = [8.] * 10 - can_lab, can_hab = MergeReductionsTest._create_workspaces(state, DataType.Can, data_x_lab, + can_lab, can_hab = MergeReductionsTest._create_workspaces(state, DataType.CAN, data_x_lab, data_y_lab_count, data_y_lab_norm, data_x_hab, data_y_hab_count, data_y_hab_norm) return sample_lab, sample_hab, can_lab, can_hab @@ -101,7 +100,7 @@ class MergeReductionsTest(unittest.TestCase): def test_that_can_merge_without_fitting(self): # Arrange - fit_type = FitModeForMerge.NoFit + fit_type = FitModeForMerge.NO_FIT scale_input = 32.0 shift_input = 12.65 state = self._get_simple_state(fit_type, scale_input, shift_input) @@ -110,8 +109,8 @@ class MergeReductionsTest(unittest.TestCase): sample_lab, sample_hab, can_lab, can_hab = self._provide_data(state) - bundles = {ISISReductionMode.LAB: [sample_lab, can_lab], - ISISReductionMode.HAB: [sample_hab, can_hab]} + bundles = {ReductionMode.LAB: [sample_lab, can_lab], + ReductionMode.HAB: [sample_hab, can_hab]} # Act result = merger.merge(bundles) @@ -127,7 +126,7 @@ class MergeReductionsTest(unittest.TestCase): def test_that_can_merge_fitting(self): # Arrange - fit_type 
= FitModeForMerge.Both + fit_type = FitModeForMerge.BOTH scale_input = 1.67 shift_input = 2.7 state = self._get_simple_state(fit_type, scale_input, shift_input) @@ -135,8 +134,8 @@ class MergeReductionsTest(unittest.TestCase): merger = merge_factory.create_merger(state) sample_lab, sample_hab, can_lab, can_hab = self._provide_data(state) - bundles = {ISISReductionMode.LAB: [sample_lab, can_lab], - ISISReductionMode.HAB: [sample_hab, can_hab]} + bundles = {ReductionMode.LAB: [sample_lab, can_lab], + ReductionMode.HAB: [sample_hab, can_hab]} # Act result = merger.merge(bundles) @@ -153,7 +152,7 @@ class MergeReductionsTest(unittest.TestCase): def test_that_can_merge_with_shift_only_fitting(self): # Arrange - fit_type = FitModeForMerge.ShiftOnly + fit_type = FitModeForMerge.SHIFT_ONLY scale_input = 1.67 shift_input = 2.7 state = self._get_simple_state(fit_type, scale_input, shift_input) @@ -161,8 +160,8 @@ class MergeReductionsTest(unittest.TestCase): merger = merge_factory.create_merger(state) sample_lab, sample_hab, can_lab, can_hab = self._provide_data(state) - bundles = {ISISReductionMode.LAB: [sample_lab, can_lab], - ISISReductionMode.HAB: [sample_hab, can_hab]} + bundles = {ReductionMode.LAB: [sample_lab, can_lab], + ReductionMode.HAB: [sample_hab, can_hab]} # Act result = merger.merge(bundles) @@ -179,7 +178,7 @@ class MergeReductionsTest(unittest.TestCase): def test_that_can_merge_with_scale_only_fitting(self): # Arrange - fit_type = FitModeForMerge.ScaleOnly + fit_type = FitModeForMerge.SCALE_ONLY scale_input = 1.67 shift_input = 2.7 state = self._get_simple_state(fit_type, scale_input, shift_input) @@ -187,8 +186,8 @@ class MergeReductionsTest(unittest.TestCase): merger = merge_factory.create_merger(state) sample_lab, sample_hab, can_lab, can_hab = self._provide_data(state) - bundles = {ISISReductionMode.LAB: [sample_lab, can_lab], - ISISReductionMode.HAB: [sample_hab, can_hab]} + bundles = {ReductionMode.LAB: [sample_lab, can_lab], + ReductionMode.HAB: 
[sample_hab, can_hab]} # Act result = merger.merge(bundles) diff --git a/scripts/test/SANS/algorithm_detail/move_sans_instrument_component_test.py b/scripts/test/SANS/algorithm_detail/move_sans_instrument_component_test.py index 7a7dbd24a30a728a3a81af370f94f9676b2cdfc8..2995d13af920ce55c521a323c8b53be6f7e97922 100644 --- a/scripts/test/SANS/algorithm_detail/move_sans_instrument_component_test.py +++ b/scripts/test/SANS/algorithm_detail/move_sans_instrument_component_test.py @@ -98,8 +98,8 @@ def check_that_sets_to_zero(instance, workspace, move_info, comp_name=None): # Get the components to compare if comp_name is None: component_names = list(move_info.monitor_names.values()) - hab_name = DetectorType.to_string(DetectorType.HAB) - lab_name = DetectorType.to_string(DetectorType.LAB), + hab_name = DetectorType.HAB.value + lab_name = DetectorType.LAB.value, _get_components_to_compare(hab_name, move_info, component_names) _get_components_to_compare(lab_name, move_info, component_names) component_names.append("some-sample-holder") @@ -160,7 +160,7 @@ class LOQMoveTest(unittest.TestCase): lab_x_translation_correction = 123. 
beam_coordinates = [45, 25] component = "main-detector-bank" - component_key = DetectorType.to_string(DetectorType.LAB) + component_key = DetectorType.LAB.value workspace = load_empty_instrument("LOQ") move_info = _get_state_move_obj(SANSInstrument.LOQ, lab_x_translation_correction) @@ -179,7 +179,7 @@ class LOQMoveTest(unittest.TestCase): # Elementary Move component_elementary_move = "HAB" - component_elementary_move_key = DetectorType.to_string(DetectorType.HAB) + component_elementary_move_key = DetectorType.HAB.value beam_coordinates_elementary_move = [120, 135] check_elementry_displacement_with_translation(self, workspace, move_info, @@ -221,7 +221,7 @@ class SANS2DMoveTest(unittest.TestCase): # Assert for initial move for low angle bank # These values are on the workspace and in the sample logs, - component_to_investigate = DetectorType.to_string(DetectorType.LAB) + component_to_investigate = DetectorType.LAB.value initial_z_position = 23.281 rear_det_z = 11.9989755859 offset = 4. @@ -235,7 +235,7 @@ class SANS2DMoveTest(unittest.TestCase): # Assert for initial move for high angle bank # These values are on the workspace and in the sample logs - component_to_investigate = DetectorType.to_string(DetectorType.HAB) + component_to_investigate = DetectorType.HAB.value initial_x_position = 1.1 x_correction = -0.187987540973 initial_z_position = 23.281 @@ -250,7 +250,7 @@ class SANS2DMoveTest(unittest.TestCase): # Act + Assert for elementary move component_elementary_move = "rear-detector" - component_elementary_move_key = DetectorType.to_string(DetectorType.LAB) + component_elementary_move_key = DetectorType.LAB.value beam_coordinates_elementary_move = [120, 135] check_elementry_displacement_with_translation(self, workspace, move_info, beam_coordinates_elementary_move, @@ -268,8 +268,8 @@ class SANS2DMoveTest(unittest.TestCase): z_translation=lab_z_translation_correction) # These values should be used instead of an explicitly specified beam centre - 
move_info.detectors[DetectorType.to_string(DetectorType.HAB)].sample_centre_pos1 = 26. - move_info.detectors[DetectorType.to_string(DetectorType.HAB)].sample_centre_pos2 = 98. + move_info.detectors[DetectorType.HAB.value].sample_centre_pos1 = 26. + move_info.detectors[DetectorType.HAB.value].sample_centre_pos2 = 98. # The component input is not relevant for SANS2D's initial move. All detectors are moved component = "front-detector" @@ -277,7 +277,7 @@ class SANS2DMoveTest(unittest.TestCase): move_type=MoveTypes.INITIAL_MOVE) # These values are on the workspace and in the sample logs - component_to_investigate = DetectorType.to_string(DetectorType.HAB) + component_to_investigate = DetectorType.HAB.value initial_x_position = 1.1 x_correction = -0.187987540973 initial_z_position = 23.281 @@ -300,8 +300,8 @@ class SANS2DMoveTest(unittest.TestCase): # These values should be used instead of an explicitly specified beam centre x_offset = 26. y_offset = 98. - move_info.detectors[DetectorType.to_string(DetectorType.LAB)].sample_centre_pos1 = x_offset - move_info.detectors[DetectorType.to_string(DetectorType.LAB)].sample_centre_pos2 = y_offset + move_info.detectors[DetectorType.LAB.value].sample_centre_pos1 = x_offset + move_info.detectors[DetectorType.LAB.value].sample_centre_pos2 = y_offset # The component input is not relevant for SANS2D's initial move. All detectors are moved component = None @@ -310,7 +310,7 @@ class SANS2DMoveTest(unittest.TestCase): # Assert for initial move for low angle bank # These values are on the workspace and in the sample logs, - component_to_investigate = DetectorType.to_string(DetectorType.LAB) + component_to_investigate = DetectorType.LAB.value initial_z_position = 23.281 rear_det_z = 11.9989755859 offset = 4. 
@@ -344,7 +344,7 @@ class LARMORMoveTest(unittest.TestCase): # Assert low angle bank for initial move # These values are on the workspace and in the sample logs - component_to_investigate = DetectorType.to_string(DetectorType.LAB) + component_to_investigate = DetectorType.LAB.value # The rotation couples the movements, hence we just insert absolute value, to have a type of regression test. expected_position = V3D(0, -38, 25.3) @@ -359,7 +359,7 @@ class ZOOMMoveTest(unittest.TestCase): def test_that_ZOOM_can_perform_move(self): beam_coordinates = [45, 25] component = "main-detector-bank" - component_key = DetectorType.to_string(DetectorType.LAB) + component_key = DetectorType.LAB.value workspace = load_empty_instrument("ZOOM") move_info = _get_state_move_obj(SANSInstrument.ZOOM) @@ -377,7 +377,7 @@ class ZOOMMoveTest(unittest.TestCase): # Elementary Move component_elementary_move = "LAB" - component_elementary_move_key = DetectorType.to_string(DetectorType.LAB) + component_elementary_move_key = DetectorType.LAB.value beam_coordinates_elementary_move = [120, 135] check_elementry_displacement_with_translation(self, workspace, move_info, diff --git a/scripts/test/SANS/algorithm_detail/move_workspaces_test.py b/scripts/test/SANS/algorithm_detail/move_workspaces_test.py index f09b6b1df6090812ec742d41f0c35b422e73b9ac..553921a04cb61f3261a252fe264b8e9f099c0214 100644 --- a/scripts/test/SANS/algorithm_detail/move_workspaces_test.py +++ b/scripts/test/SANS/algorithm_detail/move_workspaces_test.py @@ -38,7 +38,7 @@ def calculate_new_pos_rel_to_rear(ws, move_info, offset): def get_rear_detector_pos(move_info, ws): - lab_detector = move_info.detectors[DetectorType.to_string(DetectorType.LAB)] + lab_detector = move_info.detectors[DetectorType.LAB.value] detector_name = lab_detector.detector_name comp_info = ws.componentInfo() lab_detector_index = comp_info.indexOfAny(detector_name) diff --git a/scripts/test/SANS/algorithm_detail/normalize_to_sans_monitor_test.py 
b/scripts/test/SANS/algorithm_detail/normalize_to_sans_monitor_test.py index 8fe28ce134f1c86f0242454df9c030835366aff1..da6afc0231baa6b730ecadc703cbd0e52c15384f 100644 --- a/scripts/test/SANS/algorithm_detail/normalize_to_sans_monitor_test.py +++ b/scripts/test/SANS/algorithm_detail/normalize_to_sans_monitor_test.py @@ -68,11 +68,11 @@ class SANSNormalizeToMonitorTest(unittest.TestCase): data_state = state.data normalize_to_monitor_builder = get_normalize_to_monitor_builder(data_state) - normalize_to_monitor_builder.set_rebin_type(RebinType.Rebin) + normalize_to_monitor_builder.set_rebin_type(RebinType.REBIN) normalize_to_monitor_builder.set_wavelength_low([2.]) normalize_to_monitor_builder.set_wavelength_high([8.]) normalize_to_monitor_builder.set_wavelength_step(2.) - normalize_to_monitor_builder.set_wavelength_step_type(RangeStepType.Lin) + normalize_to_monitor_builder.set_wavelength_step_type(RangeStepType.LIN) if background_TOF_general_start: normalize_to_monitor_builder.set_background_TOF_general_start(background_TOF_general_start) if background_TOF_general_stop: diff --git a/scripts/test/SANS/algorithm_detail/scale_sans_workspace_test.py b/scripts/test/SANS/algorithm_detail/scale_sans_workspace_test.py index a1c31fc3b47f2c8bd523b10b8505ae296ac21eb8..3ab61611cb93fe01784ce0f4e377a838c9fcc0c6 100644 --- a/scripts/test/SANS/algorithm_detail/scale_sans_workspace_test.py +++ b/scripts/test/SANS/algorithm_detail/scale_sans_workspace_test.py @@ -60,7 +60,7 @@ class SANSScaleTest(unittest.TestCase): height = 2.0 scale = 7.2 state = self._get_sample_state(width=width, height=height, thickness=3.0, scale=scale, - shape=SampleShape.Cylinder) + shape=SampleShape.CYLINDER) output_workspace = scale_workspace(workspace=workspace, instrument=SANSInstrument.LOQ, state_scale=state.scale) @@ -75,7 +75,7 @@ class SANSScaleTest(unittest.TestCase): # Arrange facility = SANSFacility.ISIS file_information = SANSFileInformationMock(instrument=SANSInstrument.SANS2D, run_number=22024, 
height=8.0, - width=8.0, thickness=1.0, shape=SampleShape.Disc) + width=8.0, thickness=1.0, shape=SampleShape.DISC) data_builder = get_data_builder(facility, file_information) data_builder.set_sample_scatter("SANS2D00022024") data_state = data_builder.build() @@ -116,7 +116,7 @@ class SANSScaleTest(unittest.TestCase): width = 10. height = 5. thickness = 2. - scale_builder.set_shape(SampleShape.Disc) + scale_builder.set_shape(SampleShape.DISC) scale_builder.set_thickness(thickness) scale_builder.set_width(width) scale_builder.set_height(height) diff --git a/scripts/test/SANS/algorithm_detail/strip_end_nans_test.py b/scripts/test/SANS/algorithm_detail/strip_end_nans_test.py index 3d5909b0e340b712dbc835f94e85bd91dda3322b..2d091ad472c5bf57d6b22b9772d6e01f5761a83c 100644 --- a/scripts/test/SANS/algorithm_detail/strip_end_nans_test.py +++ b/scripts/test/SANS/algorithm_detail/strip_end_nans_test.py @@ -5,6 +5,7 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest from mantid.api import AlgorithmManager, FrameworkManager diff --git a/scripts/test/SANS/command_interface/batch_csv_file_parser_test.py b/scripts/test/SANS/command_interface/batch_csv_file_parser_test.py index a5cc309d09201ee653d38e637310027c7ecfc7f1..f4d59b8032f0a46e59060abf543778b218e1ea5d 100644 --- a/scripts/test/SANS/command_interface/batch_csv_file_parser_test.py +++ b/scripts/test/SANS/command_interface/batch_csv_file_parser_test.py @@ -5,12 +5,14 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) -import unittest + import os +import unittest + import mantid -from sans.common.enums import BatchReductionEntry -from sans.common.constants import ALL_PERIODS from sans.command_interface.batch_csv_file_parser import BatchCsvParser +from sans.common.constants import ALL_PERIODS +from sans.common.enums import 
BatchReductionEntry class BatchCsvParserTest(unittest.TestCase): @@ -107,23 +109,23 @@ class BatchCsvParserTest(unittest.TestCase): first_line = output[0] # Should have 5 user specified entries and 3 period entries self.assertEqual(len(first_line), 8) - self.assertEqual(first_line[BatchReductionEntry.SampleScatter], "1") - self.assertEqual(first_line[BatchReductionEntry.SampleScatterPeriod], ALL_PERIODS) - self.assertEqual(first_line[BatchReductionEntry.SampleTransmission], "2") - self.assertEqual(first_line[BatchReductionEntry.SampleTransmissionPeriod], ALL_PERIODS) - self.assertEqual(first_line[BatchReductionEntry.SampleDirect], "3") - self.assertEqual(first_line[BatchReductionEntry.SampleDirectPeriod], ALL_PERIODS) - self.assertEqual(first_line[BatchReductionEntry.Output], "test_file") - self.assertEqual(first_line[BatchReductionEntry.UserFile], "user_test_file") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_SCATTER], "1") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_SCATTER_PERIOD], ALL_PERIODS) + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_TRANSMISSION], "2") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_TRANSMISSION_PERIOD], ALL_PERIODS) + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_DIRECT], "3") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_DIRECT_PERIOD], ALL_PERIODS) + self.assertEqual(first_line[BatchReductionEntry.OUTPUT], "test_file") + self.assertEqual(first_line[BatchReductionEntry.USER_FILE], "user_test_file") second_line = output[1] # Should have 3 user specified entries and 2 period entries self.assertEqual(len(second_line), 5) - self.assertEqual(second_line[BatchReductionEntry.SampleScatter], "1") - self.assertEqual(second_line[BatchReductionEntry.SampleScatterPeriod], ALL_PERIODS) - self.assertEqual(second_line[BatchReductionEntry.CanScatter], "2") - self.assertEqual(second_line[BatchReductionEntry.CanScatterPeriod], ALL_PERIODS) - 
self.assertEqual(second_line[BatchReductionEntry.Output], "test_file2") + self.assertEqual(second_line[BatchReductionEntry.SAMPLE_SCATTER], "1") + self.assertEqual(second_line[BatchReductionEntry.SAMPLE_SCATTER_PERIOD], ALL_PERIODS) + self.assertEqual(second_line[BatchReductionEntry.CAN_SCATTER], "2") + self.assertEqual(second_line[BatchReductionEntry.CAN_SCATTER_PERIOD], ALL_PERIODS) + self.assertEqual(second_line[BatchReductionEntry.OUTPUT], "test_file2") BatchCsvParserTest._remove_csv(batch_file_path) @@ -142,11 +144,11 @@ class BatchCsvParserTest(unittest.TestCase): first_line = output[0] # Should have 5 user specified entries and 3 period entries self.assertEqual(len(first_line), 5) - self.assertEqual(first_line[BatchReductionEntry.SampleScatter], "1") - self.assertEqual(first_line[BatchReductionEntry.SampleScatterPeriod], 7) - self.assertEqual(first_line[BatchReductionEntry.CanScatter], "2") - self.assertEqual(first_line[BatchReductionEntry.CanScatterPeriod], 3) - self.assertEqual(first_line[BatchReductionEntry.Output], "test_file2") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_SCATTER], "1") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_SCATTER_PERIOD], 7) + self.assertEqual(first_line[BatchReductionEntry.CAN_SCATTER], "2") + self.assertEqual(first_line[BatchReductionEntry.CAN_SCATTER_PERIOD], 3) + self.assertEqual(first_line[BatchReductionEntry.OUTPUT], "test_file2") BatchCsvParserTest._remove_csv(batch_file_path) @@ -166,23 +168,23 @@ class BatchCsvParserTest(unittest.TestCase): first_line = output[0] # Should have 5 user specified entries and 3 period entries self.assertEqual(len(first_line), 8) - self.assertEqual(first_line[BatchReductionEntry.SampleScatter], "1") - self.assertEqual(first_line[BatchReductionEntry.SampleScatterPeriod], ALL_PERIODS) - self.assertEqual(first_line[BatchReductionEntry.SampleTransmission], "2") - self.assertEqual(first_line[BatchReductionEntry.SampleTransmissionPeriod], ALL_PERIODS) - 
self.assertEqual(first_line[BatchReductionEntry.SampleDirect], "3") - self.assertEqual(first_line[BatchReductionEntry.SampleDirectPeriod], ALL_PERIODS) - self.assertEqual(first_line[BatchReductionEntry.Output], "test_file") - self.assertEqual(first_line[BatchReductionEntry.UserFile], "user_test_file") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_SCATTER], "1") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_SCATTER_PERIOD], ALL_PERIODS) + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_TRANSMISSION], "2") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_TRANSMISSION_PERIOD], ALL_PERIODS) + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_DIRECT], "3") + self.assertEqual(first_line[BatchReductionEntry.SAMPLE_DIRECT_PERIOD], ALL_PERIODS) + self.assertEqual(first_line[BatchReductionEntry.OUTPUT], "test_file") + self.assertEqual(first_line[BatchReductionEntry.USER_FILE], "user_test_file") second_line = output[1] # Should have 3 user specified entries and 2 period entries self.assertEqual(len(second_line), 5) - self.assertEqual(second_line[BatchReductionEntry.SampleScatter], "1") - self.assertEqual(second_line[BatchReductionEntry.SampleScatterPeriod], ALL_PERIODS) - self.assertEqual(second_line[BatchReductionEntry.CanScatter], "2") - self.assertEqual(second_line[BatchReductionEntry.CanScatterPeriod], ALL_PERIODS) - self.assertEqual(second_line[BatchReductionEntry.Output], "test_file2") + self.assertEqual(second_line[BatchReductionEntry.SAMPLE_SCATTER], "1") + self.assertEqual(second_line[BatchReductionEntry.SAMPLE_SCATTER_PERIOD], ALL_PERIODS) + self.assertEqual(second_line[BatchReductionEntry.CAN_SCATTER], "2") + self.assertEqual(second_line[BatchReductionEntry.CAN_SCATTER_PERIOD], ALL_PERIODS) + self.assertEqual(second_line[BatchReductionEntry.OUTPUT], "test_file2") BatchCsvParserTest._remove_csv(batch_file_path) diff --git a/scripts/test/SANS/command_interface/command_interface_state_director_test.py 
b/scripts/test/SANS/command_interface/command_interface_state_director_test.py index 6cf7e41674a691082f0d82764bb2811ac7fe1844..7704e3b2419d501c0d4e8cdc50a5041129e8a3f3 100644 --- a/scripts/test/SANS/command_interface/command_interface_state_director_test.py +++ b/scripts/test/SANS/command_interface/command_interface_state_director_test.py @@ -5,13 +5,14 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid + from sans.command_interface.command_interface_state_director import (NParameterCommand, NParameterCommandId, CommandInterfaceStateDirector, DataCommand, DataCommandId, FitData) from sans.common.enums import (SANSFacility, RebinType, DetectorType, ReductionDimensionality, - FitType, RangeStepType, ISISReductionMode, FitModeForMerge, DataType) + FitType, RangeStepType, ReductionMode, FitModeForMerge, DataType) class CommandInterfaceStateDirectorTest(unittest.TestCase): @@ -25,94 +26,94 @@ class CommandInterfaceStateDirectorTest(unittest.TestCase): command_interface = CommandInterfaceStateDirector(SANSFacility.ISIS) # User file - command = NParameterCommand(command_id=NParameterCommandId.user_file, + command = NParameterCommand(command_id=NParameterCommandId.USER_FILE, values=["test_user_file_sans2d.txt"]) self._assert_raises_nothing(command_interface.add_command, command) # Mask - command = NParameterCommand(command_id=NParameterCommandId.mask, + command = NParameterCommand(command_id=NParameterCommandId.MASK, values=["MASK/ FRONT H197>H199"]) self._assert_raises_nothing(command_interface.add_command, command) # Monitor spectrum (incident monitor for monitor normalization) - command = NParameterCommand(command_id=NParameterCommandId.incident_spectrum, + command = NParameterCommand(command_id=NParameterCommandId.INCIDENT_SPECTRUM, values=[1, True, False]) self._assert_raises_nothing(command_interface.add_command, command) # Transmission spectrum 
(incident monitor for transmission calculation) - command = NParameterCommand(command_id=NParameterCommandId.incident_spectrum, values=[7, False, True]) + command = NParameterCommand(command_id=NParameterCommandId.INCIDENT_SPECTRUM, values=[7, False, True]) self._assert_raises_nothing(command_interface.add_command, command) # Reduction Dimensionality One Dim - command = NParameterCommand(command_id=NParameterCommandId.reduction_dimensionality, - values=[ReductionDimensionality.OneDim]) + command = NParameterCommand(command_id=NParameterCommandId.REDUCTION_DIMENSIONALITY, + values=[ReductionDimensionality.ONE_DIM]) self._assert_raises_nothing(command_interface.add_command, command) # Reduction Dimensionality Two Dim - command = NParameterCommand(command_id=NParameterCommandId.reduction_dimensionality, - values=[ReductionDimensionality.TwoDim]) + command = NParameterCommand(command_id=NParameterCommandId.REDUCTION_DIMENSIONALITY, + values=[ReductionDimensionality.TWO_DIM]) self._assert_raises_nothing(command_interface.add_command, command) # Sample offset - command = NParameterCommand(command_id=NParameterCommandId.sample_offset, values=[23.6]) + command = NParameterCommand(command_id=NParameterCommandId.SAMPLE_OFFSET, values=[23.6]) self._assert_raises_nothing(command_interface.add_command, command) # Sample scatter data - command = DataCommand(command_id=DataCommandId.sample_scatter, file_name="SANS2D00022024", period=3) + command = DataCommand(command_id=DataCommandId.SAMPLE_SCATTER, file_name="SANS2D00022024", period=3) self._assert_raises_nothing(command_interface.add_command, command) # Detector - command = NParameterCommand(command_id=NParameterCommandId.detector, values=[ISISReductionMode.HAB]) + command = NParameterCommand(command_id=NParameterCommandId.DETECTOR, values=[ReductionMode.HAB]) self._assert_raises_nothing(command_interface.add_command, command) # Gravity - command = NParameterCommand(command_id=NParameterCommandId.gravity, values=[True, 12.4]) + 
command = NParameterCommand(command_id=NParameterCommandId.GRAVITY, values=[True, 12.4]) self._assert_raises_nothing(command_interface.add_command, command) # Set centre - command = NParameterCommand(command_id=NParameterCommandId.centre, values=[12.4, 23.54, DetectorType.HAB]) + command = NParameterCommand(command_id=NParameterCommandId.CENTRE, values=[12.4, 23.54, DetectorType.HAB]) self._assert_raises_nothing(command_interface.add_command, command) # # Trans fit - command = NParameterCommand(command_id=NParameterCommandId.trans_fit, values=[FitData.Can, 10.4, 12.54, - FitType.Logarithmic, 0]) + command = NParameterCommand(command_id=NParameterCommandId.TRANS_FIT, values=[FitData.Can, 10.4, 12.54, + FitType.LOGARITHMIC, 0]) self._assert_raises_nothing(command_interface.add_command, command) # Front detector rescale - command = NParameterCommand(command_id=NParameterCommandId.front_detector_rescale, values=[1.2, 2.4, True, + command = NParameterCommand(command_id=NParameterCommandId.FRONT_DETECTOR_RESCALE, values=[1.2, 2.4, True, False, None, 7.2]) self._assert_raises_nothing(command_interface.add_command, command) # Event slices - command = NParameterCommand(command_id=NParameterCommandId.event_slices, values="1-23,55:3:65") + command = NParameterCommand(command_id=NParameterCommandId.EVENT_SLICES, values="1-23,55:3:65") self._assert_raises_nothing(command_interface.add_command, command) # Flood file - command = NParameterCommand(command_id=NParameterCommandId.flood_file, values=["test", DetectorType.LAB]) + command = NParameterCommand(command_id=NParameterCommandId.FLOOD_FILE, values=["test", DetectorType.LAB]) self._assert_raises_nothing(command_interface.add_command, command) # Phi limits - command = NParameterCommand(command_id=NParameterCommandId.phi_limit, values=[12.5, 123.6, False]) + command = NParameterCommand(command_id=NParameterCommandId.PHI_LIMIT, values=[12.5, 123.6, False]) self._assert_raises_nothing(command_interface.add_command, command) # 
Wavelength correction file - command = NParameterCommand(command_id=NParameterCommandId.wavelength_correction_file, + command = NParameterCommand(command_id=NParameterCommandId.WAVELENGTH_CORRECTION_FILE, values=["test", DetectorType.HAB]) self._assert_raises_nothing(command_interface.add_command, command) # Radius mask - command = NParameterCommand(command_id=NParameterCommandId.mask_radius, + command = NParameterCommand(command_id=NParameterCommandId.MASK_RADIUS, values=[23.5, 234.7]) self._assert_raises_nothing(command_interface.add_command, command) # Wavelength limits - command = NParameterCommand(command_id=NParameterCommandId.wavelength_limit, - values=[1.23, 23., 1.1, RangeStepType.Lin]) + command = NParameterCommand(command_id=NParameterCommandId.WAVELENGTH_LIMIT, + values=[1.23, 23., 1.1, RangeStepType.LIN]) self._assert_raises_nothing(command_interface.add_command, command) # QXY Limits - command = NParameterCommand(command_id=NParameterCommandId.qxy_limit, - values=[1.23, 23., 1.1, RangeStepType.Lin]) + command = NParameterCommand(command_id=NParameterCommandId.QXY_LIMIT, + values=[1.23, 23., 1.1, RangeStepType.LIN]) self._assert_raises_nothing(command_interface.add_command, command) # Process all commands @@ -121,38 +122,38 @@ class CommandInterfaceStateDirectorTest(unittest.TestCase): # Assert # We check here that the elements we set up above (except for from the user file) are being applied self.assertNotEqual(state, None) - self.assertTrue(state.mask.detectors[DetectorType.to_string(DetectorType.HAB)].range_horizontal_strip_start[-1] + self.assertTrue(state.mask.detectors[DetectorType.HAB.value].range_horizontal_strip_start[-1] == 197) - self.assertTrue(state.mask.detectors[DetectorType.to_string(DetectorType.HAB)].range_horizontal_strip_stop[-1] + self.assertTrue(state.mask.detectors[DetectorType.HAB.value].range_horizontal_strip_stop[-1] == 199) self.assertEqual(state.adjustment.normalize_to_monitor.incident_monitor, 1) - 
self.assertEqual(state.adjustment.normalize_to_monitor.rebin_type, RebinType.InterpolatingRebin) + self.assertEqual(state.adjustment.normalize_to_monitor.rebin_type, RebinType.INTERPOLATING_REBIN) self.assertEqual(state.adjustment.calculate_transmission.incident_monitor, 7) - self.assertEqual(state.adjustment.calculate_transmission.rebin_type, RebinType.Rebin) - self.assertEqual(state.reduction.reduction_dimensionality, ReductionDimensionality.TwoDim) - self.assertEqual(state.convert_to_q.reduction_dimensionality, ReductionDimensionality.TwoDim) + self.assertEqual(state.adjustment.calculate_transmission.rebin_type, RebinType.REBIN) + self.assertEqual(state.reduction.reduction_dimensionality, ReductionDimensionality.TWO_DIM) + self.assertEqual(state.convert_to_q.reduction_dimensionality, ReductionDimensionality.TWO_DIM) self.assertEqual(state.move.sample_offset, 23.6/1000.) self.assertEqual(state.data.sample_scatter, "SANS2D00022024") self.assertEqual(state.data.sample_scatter_period, 3) - self.assertEqual(state.reduction.reduction_mode, ISISReductionMode.HAB) + self.assertEqual(state.reduction.reduction_mode, ReductionMode.HAB) self.assertTrue(state.convert_to_q.use_gravity) self.assertEqual(state.convert_to_q.gravity_extra_length, 12.4) - self.assertEqual(state.move.detectors[DetectorType.to_string(DetectorType.HAB)].sample_centre_pos1, 12.4/1000.) - self.assertTrue(state.move.detectors[DetectorType.to_string(DetectorType.HAB)].sample_centre_pos2 + self.assertEqual(state.move.detectors[DetectorType.HAB.value].sample_centre_pos1, 12.4/1000.) + self.assertTrue(state.move.detectors[DetectorType.HAB.value].sample_centre_pos2 == 23.54/1000.) 
- self.assertTrue(state.adjustment.calculate_transmission.fit[DataType.to_string(DataType.Can)].fit_type - is FitType.Logarithmic) - self.assertTrue(state.adjustment.calculate_transmission.fit[DataType.to_string(DataType.Can)].polynomial_order + self.assertTrue(state.adjustment.calculate_transmission.fit[DataType.CAN].fit_type + is FitType.LOGARITHMIC) + self.assertTrue(state.adjustment.calculate_transmission.fit[DataType.CAN].polynomial_order == 0) - self.assertTrue(state.adjustment.calculate_transmission.fit[DataType.to_string(DataType.Can)].wavelength_low + self.assertTrue(state.adjustment.calculate_transmission.fit[DataType.CAN].wavelength_low == 10.4) - self.assertTrue(state.adjustment.calculate_transmission.fit[DataType.to_string(DataType.Can)].wavelength_high + self.assertTrue(state.adjustment.calculate_transmission.fit[DataType.CAN].wavelength_high == 12.54) self.assertEqual(state.reduction.merge_scale, 1.2) self.assertEqual(state.reduction.merge_shift, 2.4) - self.assertEqual(state.reduction.merge_fit_mode, FitModeForMerge.ScaleOnly) + self.assertEqual(state.reduction.merge_fit_mode, FitModeForMerge.SCALE_ONLY) self.assertEqual(state.reduction.merge_range_min, None) self.assertEqual(state.reduction.merge_range_max, 7.2) @@ -166,30 +167,30 @@ class CommandInterfaceStateDirectorTest(unittest.TestCase): self.assertEqual(e1, e2) self.assertTrue(state.adjustment.wavelength_and_pixel_adjustment.adjustment_files[ - DetectorType.to_string(DetectorType.LAB)].pixel_adjustment_file == "test") + DetectorType.LAB.value].pixel_adjustment_file == "test") self.assertEqual(state.mask.phi_min, 12.5) self.assertEqual(state.mask.phi_max, 123.6) self.assertFalse(state.mask.use_mask_phi_mirror) self.assertTrue(state.adjustment.wavelength_and_pixel_adjustment.adjustment_files[ - DetectorType.to_string(DetectorType.HAB)].wavelength_adjustment_file == "test") + DetectorType.HAB.value].wavelength_adjustment_file == "test") self.assertEqual(state.mask.radius_min, 23.5 / 1000.) 
self.assertEqual(state.mask.radius_max, 234.7 / 1000.) self.assertEqual(state.wavelength.wavelength_low, [1.23]) self.assertEqual(state.adjustment.normalize_to_monitor.wavelength_high, [23.]) self.assertEqual(state.adjustment.wavelength_and_pixel_adjustment.wavelength_step, 1.1) - self.assertEqual(state.adjustment.calculate_transmission.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(state.adjustment.calculate_transmission.wavelength_step_type, RangeStepType.LIN) self.assertEqual(state.convert_to_q.q_xy_max, 23.) self.assertEqual(state.convert_to_q.q_xy_step, 1.1) - self.assertEqual(state.convert_to_q.q_xy_step_type, RangeStepType.Lin) + self.assertEqual(state.convert_to_q.q_xy_step_type, RangeStepType.LIN) def test_that_can_remove_last_command(self): # Arrange command_interface = CommandInterfaceStateDirector(SANSFacility.ISIS) - command_interface.add_command(NParameterCommand(command_id=NParameterCommandId.user_file, + command_interface.add_command(NParameterCommand(command_id=NParameterCommandId.USER_FILE, values=["file_1.txt"])) - command_interface.add_command(NParameterCommand(command_id=NParameterCommandId.user_file, + command_interface.add_command(NParameterCommand(command_id=NParameterCommandId.USER_FILE, values=["file_2.txt"])) - command_interface.add_command(NParameterCommand(command_id=NParameterCommandId.user_file, + command_interface.add_command(NParameterCommand(command_id=NParameterCommandId.USER_FILE, values=["file_3.txt"])) # Act commands = command_interface.get_commands() diff --git a/scripts/test/SANS/common/CMakeLists.txt b/scripts/test/SANS/common/CMakeLists.txt index 89eb28deed9a0c0352dec43a4d600b73bfe10eb9..c9cb40ba1b0df7e83878048819b095ca8d33e7cf 100644 --- a/scripts/test/SANS/common/CMakeLists.txt +++ b/scripts/test/SANS/common/CMakeLists.txt @@ -1,7 +1,6 @@ # Tests for SANS set(TEST_PY_FILES - enums_test.py file_information_test.py log_tagger_test.py general_functions_test.py diff --git a/scripts/test/SANS/common/enums_test.py 
b/scripts/test/SANS/common/enums_test.py deleted file mode 100644 index 4ec8518b4e3b5fa54bc4b431450dc496e07fdaab..0000000000000000000000000000000000000000 --- a/scripts/test/SANS/common/enums_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# Mantid Repository : https://github.com/mantidproject/mantid -# -# Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI, -# NScD Oak Ridge National Laboratory, European Spallation Source -# & Institut Laue - Langevin -# SPDX - License - Identifier: GPL - 3.0 + -from __future__ import (absolute_import, division, print_function) -import unittest - -from sans.common.enums import serializable_enum, string_convertible - - -# ----Create a test class -@string_convertible -@serializable_enum("TypeA", "TypeB", "TypeC") -class DummyClass(object): - pass - - -@string_convertible -@serializable_enum("TypeA", "TypeB", "TypeC") -class IncorrectClass(object): - pass - - -class SANSFileInformationTest(unittest.TestCase): - def test_that_can_create_enum_value_and_is_sub_class_of_base_type(self): - type_a = DummyClass.TypeA - self.assertTrue(issubclass(type_a, DummyClass)) - - def test_that_can_convert_to_string(self): - type_b = DummyClass.TypeB - self.assertEqual(DummyClass.to_string(type_b), "TypeB") - - def test_that_raises_run_time_error_if_enum_value_is_not_known(self): - self.assertRaises(RuntimeError, DummyClass.to_string, DummyClass) - - def test_that_can_convert_from_string(self): - self.assertEqual(DummyClass.from_string("TypeC"), DummyClass.TypeC) - - def test_that_raises_run_time_error_if_string_is_not_known(self): - self.assertRaises(RuntimeError, DummyClass.from_string, "TypeD") - - def test_that_has_member_handles_strings(self): - self.assertTrue(DummyClass.has_member("TypeA")) - self.assertFalse(DummyClass.has_member("TypeD")) - - def test_that_has_member_handles_enums(self): - a_variable = DummyClass.TypeA - incorrect_variable = IncorrectClass.TypeA - - self.assertTrue(DummyClass.has_member(a_variable)) - 
self.assertFalse(DummyClass.has_member(incorrect_variable)) - - -if __name__ == '__main__': - unittest.main() diff --git a/scripts/test/SANS/common/file_information_test.py b/scripts/test/SANS/common/file_information_test.py index 389fd54115a8ad88dd63551e81214aa8455f0b8a..5d876de5c0f5bf8dfdac060ba82e0b141cd58bac 100644 --- a/scripts/test/SANS/common/file_information_test.py +++ b/scripts/test/SANS/common/file_information_test.py @@ -30,14 +30,14 @@ class SANSFileInformationTest(unittest.TestCase): self.assertEqual(file_information.get_number_of_periods(), 1) self.assertEqual(file_information.get_date(), DateAndTime("2013-10-25T14:21:19")) self.assertEqual(file_information.get_instrument(), SANSInstrument.SANS2D) - self.assertEqual(file_information.get_type(), FileType.ISISNexus) + self.assertEqual(file_information.get_type(), FileType.ISIS_NEXUS) self.assertEqual(file_information.get_run_number(), 22024) self.assertFalse(file_information.is_event_mode()) self.assertFalse(file_information.is_added_data()) self.assertEqual(file_information.get_width(), 8.0) self.assertEqual(file_information.get_height(), 8.0) self.assertEqual(file_information.get_thickness(), 1.0) - self.assertEqual(file_information.get_shape(), SampleShape.Disc) + self.assertEqual(file_information.get_shape(), SampleShape.DISC) def test_that_can_extract_information_from_file_for_LOQ_single_period_and_raw_format(self): # Arrange @@ -52,13 +52,13 @@ class SANSFileInformationTest(unittest.TestCase): self.assertEqual(file_information.get_number_of_periods(), 1) self.assertEqual(file_information.get_date(), DateAndTime("2008-12-18T11:20:58")) self.assertEqual(file_information.get_instrument(), SANSInstrument.LOQ) - self.assertEqual(file_information.get_type(), FileType.ISISRaw) + self.assertEqual(file_information.get_type(), FileType.ISIS_RAW) self.assertEqual(file_information.get_run_number(), 48094) self.assertFalse(file_information.is_added_data()) self.assertEqual(file_information.get_width(), 8.0) 
self.assertEqual(file_information.get_height(), 8.0) self.assertEqual(file_information.get_thickness(), 1.0) - self.assertEqual(file_information.get_shape(), SampleShape.Disc) + self.assertEqual(file_information.get_shape(), SampleShape.DISC) def test_that_can_extract_information_from_file_for_SANS2D_multi_period_event_and_nexus_format(self): # Arrange @@ -73,14 +73,14 @@ class SANSFileInformationTest(unittest.TestCase): self.assertEqual(file_information.get_number_of_periods(), 4) self.assertEqual(file_information.get_date(), DateAndTime("2015-06-05T14:43:49")) self.assertEqual(file_information.get_instrument(), SANSInstrument.LARMOR) - self.assertEqual(file_information.get_type(), FileType.ISISNexus) + self.assertEqual(file_information.get_type(), FileType.ISIS_NEXUS) self.assertEqual(file_information.get_run_number(), 3368) self.assertTrue(file_information.is_event_mode()) self.assertFalse(file_information.is_added_data()) self.assertEqual(file_information.get_width(), 8.0) self.assertEqual(file_information.get_height(), 8.0) self.assertEqual(file_information.get_thickness(), 2.0) - self.assertEqual(file_information.get_shape(), SampleShape.FlatPlate) + self.assertEqual(file_information.get_shape(), SampleShape.FLAT_PLATE) def test_that_can_extract_information_for_added_histogram_data_and_nexus_format(self): # Arrange @@ -95,14 +95,14 @@ class SANSFileInformationTest(unittest.TestCase): self.assertEqual(file_information.get_number_of_periods(), 1) self.assertEqual(file_information.get_date(), DateAndTime("2013-10-25T14:21:19")) self.assertEqual(file_information.get_instrument(), SANSInstrument.SANS2D) - self.assertEqual(file_information.get_type(), FileType.ISISNexusAdded) + self.assertEqual(file_information.get_type(), FileType.ISIS_NEXUS_ADDED) self.assertEqual(file_information.get_run_number(), 22024) self.assertFalse(file_information.is_event_mode()) self.assertTrue(file_information.is_added_data()) self.assertEqual(file_information.get_width(), 8.0) 
self.assertEqual(file_information.get_height(), 8.0) self.assertEqual(file_information.get_thickness(), 1.0) - self.assertEqual(file_information.get_shape(), SampleShape.Disc) + self.assertEqual(file_information.get_shape(), SampleShape.DISC) def test_that_can_extract_information_for_LARMOR_added_event_data_and_multi_period_and_nexus_format(self): # Arrange @@ -117,14 +117,14 @@ class SANSFileInformationTest(unittest.TestCase): self.assertEqual(file_information.get_number_of_periods(), 4) self.assertEqual(file_information.get_date(), DateAndTime("2016-10-12T04:33:47")) self.assertEqual(file_information.get_instrument(), SANSInstrument.LARMOR) - self.assertEqual(file_information.get_type(), FileType.ISISNexusAdded) + self.assertEqual(file_information.get_type(), FileType.ISIS_NEXUS_ADDED) self.assertEqual(file_information.get_run_number(), 13065) self.assertTrue(file_information.is_event_mode()) self.assertTrue(file_information.is_added_data()) self.assertEqual(file_information.get_width(), 6.0) self.assertEqual(file_information.get_height(), 8.0) self.assertEqual(file_information.get_thickness(), 1.0) - self.assertEqual(file_information.get_shape(), SampleShape.FlatPlate) + self.assertEqual(file_information.get_shape(), SampleShape.FLAT_PLATE) def test_that_can_find_data_with_numbers_but_no_instrument(self): # Arrange diff --git a/scripts/test/SANS/common/general_functions_test.py b/scripts/test/SANS/common/general_functions_test.py index e530ef1a9c48b102cec7ff6ab14983d671c4a0ec..2ad5183eff77930dd29cb22587c878d8539f4f89 100644 --- a/scripts/test/SANS/common/general_functions_test.py +++ b/scripts/test/SANS/common/general_functions_test.py @@ -12,7 +12,7 @@ from mantid.api import AnalysisDataService, FrameworkManager from mantid.kernel import (V3D, Quat) from mantid.py3compat import mock from sans.common.constants import (SANS2D, LOQ, LARMOR) -from sans.common.enums import (ISISReductionMode, ReductionDimensionality, OutputParts, +from sans.common.enums import 
(ReductionMode, ReductionDimensionality, OutputParts, SANSInstrument, DetectorType, SANSFacility, DataType) from sans.common.general_functions import (quaternion_to_angle_and_axis, create_managed_non_child_algorithm, create_unmanaged_algorithm, add_to_sample_log, @@ -85,7 +85,7 @@ class SANSFunctionsTest(unittest.TestCase): state.data.sample_scatter_run_number = 12345 state.data.sample_scatter_period = StateData.ALL_PERIODS - state.reduction.dimensionality = ReductionDimensionality.OneDim + state.reduction.dimensionality = ReductionDimensionality.ONE_DIM state.wavelength.wavelength_low = [12.0] state.wavelength.wavelength_high = [34.0] @@ -169,7 +169,7 @@ class SANSFunctionsTest(unittest.TestCase): # Act + Assert try: - get_standard_output_workspace_name(state, ISISReductionMode.All) + get_standard_output_workspace_name(state, ReductionMode.ALL) did_raise = False except RuntimeError: did_raise = True @@ -179,7 +179,7 @@ class SANSFunctionsTest(unittest.TestCase): # Arrange state = SANSFunctionsTest._get_state() # Act - output_workspace, _ = get_standard_output_workspace_name(state, ISISReductionMode.LAB) + output_workspace, _ = get_standard_output_workspace_name(state, ReductionMode.LAB) # Assert self.assertEqual("12345_rear_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37", output_workspace) @@ -187,7 +187,7 @@ class SANSFunctionsTest(unittest.TestCase): # Arrange state = SANSFunctionsTest._get_state() # Act - output_workspace, _ = get_standard_output_workspace_name(state, ISISReductionMode.LAB, + output_workspace, _ = get_standard_output_workspace_name(state, ReductionMode.LAB, include_slice_limits=False) # Assert self.assertTrue("12345_rear_1D_12.0_34.0Phi12.0_56.0" == output_workspace) @@ -215,7 +215,7 @@ class SANSFunctionsTest(unittest.TestCase): def test_that_get_transmission_output_name_returns_correct_name_for_wavelength_slices_with_user_specified(self): state = self._get_state() state.save.user_specified_output_name = 'user_output_name' - 
state.reduction.reduction_mode = ISISReductionMode.Merged + state.reduction.reduction_mode = ReductionMode.MERGED multi_reduction_type = {"period": False, "event_slice": False, "wavelength_range": True} @@ -228,7 +228,7 @@ class SANSFunctionsTest(unittest.TestCase): def test_that_get_transmission_output_name_returns_correct_name_for_wavelength_slices(self): state = self._get_state() state.save.user_specified_output_name = '' - state.reduction.reduction_mode = ISISReductionMode.Merged + state.reduction.reduction_mode = ReductionMode.MERGED multi_reduction_type = {"period": False, "event_slice": False, "wavelength_range": True} @@ -243,7 +243,7 @@ class SANSFunctionsTest(unittest.TestCase): state = SANSFunctionsTest._get_state() state.save.user_specified_output_name = "test_output" # Act - output_workspace, group_output_name = get_transmission_output_name(state, data_type=DataType.Can) + output_workspace, group_output_name = get_transmission_output_name(state, data_type=DataType.CAN) # Assert self.assertEqual(output_workspace, "test_output_trans_Can") self.assertEqual(group_output_name, 'test_output_trans') @@ -253,7 +253,7 @@ class SANSFunctionsTest(unittest.TestCase): state = SANSFunctionsTest._get_state() state.save.user_specified_output_name = '' # Act - output_workspace, group_output_name = get_transmission_output_name(state, data_type=DataType.Can) + output_workspace, group_output_name = get_transmission_output_name(state, data_type=DataType.CAN) # Assert self.assertEqual(output_workspace, "12345_trans_Can_1.0_10.0") self.assertEqual(group_output_name, '12345_trans_1.0_10.0') @@ -261,14 +261,14 @@ class SANSFunctionsTest(unittest.TestCase): def test_that_get_transmission_output_name_returns_correct_name_for_wavelength_slices_for_CAN(self): state = self._get_state() state.save.user_specified_output_name = '' - state.reduction.reduction_mode = ISISReductionMode.Merged + state.reduction.reduction_mode = ReductionMode.MERGED multi_reduction_type = {"period": False, 
"event_slice": False, "wavelength_range": True} output_name, group_output_name = get_transmission_output_name(state, multi_reduction_type=multi_reduction_type, - data_type=DataType.Can) + data_type=DataType.CAN) self.assertEqual(output_name, '12345_trans_Can_1.0_10.0_12.0_34.0') self.assertEqual(group_output_name, '12345_trans_1.0_10.0') @@ -276,14 +276,14 @@ class SANSFunctionsTest(unittest.TestCase): def test_that_get_transmission_output_name_returns_correct_name_for_wavelength_slices_for_CAN_unfitted(self): state = self._get_state() state.save.user_specified_output_name = '' - state.reduction.reduction_mode = ISISReductionMode.Merged + state.reduction.reduction_mode = ReductionMode.MERGED multi_reduction_type = {"period": False, "event_slice": False, "wavelength_range": True} output_name, group_output_name = get_transmission_output_name(state, multi_reduction_type=multi_reduction_type, - data_type=DataType.Can, fitted=False) + data_type=DataType.CAN, fitted=False) self.assertEqual(output_name, '12345_trans_Can_unfitted_1.0_10.0') self.assertEqual(group_output_name, '12345_trans_1.0_10.0') @@ -306,15 +306,15 @@ class SANSFunctionsTest(unittest.TestCase): test_director = TestDirector() state = test_director.construct() tagged_workspace_names = {None: "test_ws", - OutputParts.Count: "test_ws_count", - OutputParts.Norm: "test_ws_norm"} + OutputParts.COUNT: "test_ws_count", + OutputParts.NORM: "test_ws_norm"} SANSFunctionsTest._prepare_workspaces(number_of_workspaces=4, tagged_workspace_names=tagged_workspace_names, state=state, - reduction_mode=ISISReductionMode.LAB) + reduction_mode=ReductionMode.LAB) # Act workspace, workspace_count, workspace_norm = get_reduced_can_workspace_from_ads(state, output_parts=True, - reduction_mode=ISISReductionMode.LAB) # noqa + reduction_mode=ReductionMode.LAB) # noqa # Assert self.assertNotEqual(workspace, None) @@ -337,11 +337,11 @@ class SANSFunctionsTest(unittest.TestCase): test_director = TestDirector() state = 
test_director.construct() SANSFunctionsTest._prepare_workspaces(number_of_workspaces=4, tagged_workspace_names=None, - state=state, reduction_mode=ISISReductionMode.LAB) + state=state, reduction_mode=ReductionMode.LAB) # Act workspace, workspace_count, workspace_norm = \ - get_reduced_can_workspace_from_ads(state, output_parts=False, reduction_mode=ISISReductionMode.LAB) + get_reduced_can_workspace_from_ads(state, output_parts=False, reduction_mode=ReductionMode.LAB) # Assert self.assertEqual(workspace, None) @@ -381,7 +381,7 @@ class SANSFunctionsTest(unittest.TestCase): self.assertEqual(get_facility(SANSInstrument.LOQ), SANSFacility.ISIS) self.assertEqual(get_facility(SANSInstrument.LARMOR), SANSFacility.ISIS) self.assertEqual(get_facility(SANSInstrument.ZOOM), SANSFacility.ISIS) - self.assertEqual(get_facility(SANSInstrument.NoInstrument), SANSFacility.NoFacility) + self.assertEqual(get_facility(SANSInstrument.NO_INSTRUMENT), SANSFacility.NO_FACILITY) def test_that_diagnostic_parser_produces_correct_list(self): string_to_parse = '8-11, 12:15, 5, 7:9' @@ -395,7 +395,7 @@ class SANSFunctionsTest(unittest.TestCase): state = self._get_state() state.save.user_specified_output_name = '' - output_name, group_output_name = get_output_name(state, ISISReductionMode.LAB, False) + output_name, group_output_name = get_output_name(state, ReductionMode.LAB, False) expected = "12345_rear_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37" @@ -406,9 +406,9 @@ class SANSFunctionsTest(unittest.TestCase): state = self._get_state() custom_user_name = 'user_output_name' state.save.user_specified_output_name = custom_user_name - state.reduction.reduction_mode = ISISReductionMode.LAB + state.reduction.reduction_mode = ReductionMode.LAB - output_name, group_output_name = get_output_name(state, ISISReductionMode.LAB, False) + output_name, group_output_name = get_output_name(state, ReductionMode.LAB, False) reduction_settings = "_rear_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37" @@ -421,9 +421,9 @@ class 
SANSFunctionsTest(unittest.TestCase): state = self._get_state() custom_user_name = 'user_output_name' state.save.user_specified_output_name = custom_user_name - state.reduction.reduction_mode = ISISReductionMode.HAB + state.reduction.reduction_mode = ReductionMode.HAB - output_name, group_output_name = get_output_name(state, ISISReductionMode.HAB, False) + output_name, group_output_name = get_output_name(state, ReductionMode.HAB, False) reduction_settings = "_front_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37" expected = custom_user_name + reduction_settings @@ -434,14 +434,14 @@ class SANSFunctionsTest(unittest.TestCase): def test_get_output_name_replaces_name_with_user_specified_name_with_appended_detector_for_All_reduction(self): state = self._get_state() custom_user_name = 'user_output_name' - state.reduction.reduction_mode = ISISReductionMode.All + state.reduction.reduction_mode = ReductionMode.ALL state.save.user_specified_output_name = custom_user_name reduction_settings = "rear_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37" expected = custom_user_name + '_' + reduction_settings - output_name, group_output_name = get_output_name(state, ISISReductionMode.LAB, False) + output_name, group_output_name = get_output_name(state, ReductionMode.LAB, False) self.assertEqual(expected, output_name) self.assertEqual(expected, group_output_name) @@ -450,12 +450,14 @@ class SANSFunctionsTest(unittest.TestCase): state = self._get_state() custom_name = 'user_output_name' state.save.user_specified_output_name = custom_name - state.reduction.reduction_mode = ISISReductionMode.Merged + state.reduction.reduction_mode = ReductionMode.MERGED reduction_settings = "_merged_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37" - output_name, group_output_name = get_output_name(state, ISISReductionMode.Merged, False) expected = custom_name + reduction_settings + + output_name, group_output_name = get_output_name(state, ReductionMode.MERGED, False) + self.assertEqual(expected, output_name) 
self.assertEqual(expected, group_output_name) @@ -463,12 +465,11 @@ class SANSFunctionsTest(unittest.TestCase): state = self._get_state() custom_name = 'user_output_name' state.save.user_specified_output_name = custom_name - state.reduction.reduction_mode = ISISReductionMode.Merged multi_reduction_type = {"period": False, "event_slice": True, "wavelength_range": False} - output_name, group_output_name = get_output_name(state, ISISReductionMode.Merged, True, + output_name, group_output_name = get_output_name(state, ReductionMode.MERGED, True, multi_reduction_type=multi_reduction_type) single_ws_name = custom_name + '_merged_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37_t4.57_T12.37' @@ -481,12 +482,11 @@ class SANSFunctionsTest(unittest.TestCase): state = self._get_state() custom_name = 'user_output_name' state.save.user_specified_output_name = custom_name - state.reduction.reduction_mode = ISISReductionMode.Merged multi_reduction_type = {"period": False, "event_slice": False, "wavelength_range": True} - output_name, group_output_name = get_output_name(state, ISISReductionMode.Merged, True, + output_name, group_output_name = get_output_name(state, ReductionMode.MERGED, True, multi_reduction_type=multi_reduction_type) single_ws_name = custom_name + '_merged_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37_12.0_34.0' @@ -499,12 +499,11 @@ class SANSFunctionsTest(unittest.TestCase): state = self._get_state() custom_name = 'user_output_name' state.save.user_specified_output_name = custom_name - state.reduction.reduction_mode = ISISReductionMode.Merged multi_reduction_type = {"period": True, "event_slice": False, "wavelength_range": False} - output_name, group_output_name = get_output_name(state, ISISReductionMode.Merged, True, + output_name, group_output_name = get_output_name(state, ReductionMode.MERGED, True, multi_reduction_type=multi_reduction_type) single_ws_name = custom_name + '_merged_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37_p0' @@ -517,12 +516,11 @@ class 
SANSFunctionsTest(unittest.TestCase): state = self._get_state() custom_name = 'user_output_name' state.save.user_specified_output_name = custom_name - state.reduction.reduction_mode = ISISReductionMode.Merged multi_reduction_type = {"period": True, "event_slice": True, "wavelength_range": True} - output_name, group_output_name = get_output_name(state, ISISReductionMode.Merged, True, + output_name, group_output_name = get_output_name(state, ReductionMode.MERGED, True, multi_reduction_type=multi_reduction_type) single_ws_name = custom_name + "_merged_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37_p0_t4.57_T12.37_12.0_34.0" @@ -535,12 +533,11 @@ class SANSFunctionsTest(unittest.TestCase): state = self._get_state() custom_name = 'user_output_name' state.save.user_specified_output_name = custom_name - state.reduction.reduction_mode = ISISReductionMode.LAB multi_reduction_type = {"period": True, "event_slice": True, "wavelength_range": True} - output_name, group_output_name = get_output_name(state, ISISReductionMode.LAB, True, + output_name, group_output_name = get_output_name(state, ReductionMode.LAB, True, multi_reduction_type=multi_reduction_type) single_ws_name = custom_name + "_rear_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37_p0_t4.57_T12.37_12.0_34.0" @@ -552,12 +549,12 @@ class SANSFunctionsTest(unittest.TestCase): def test_returned_name_for_time_sliced_merged_reduction_without_user_specified_name_correct(self): state = self._get_state() state.save.user_specified_output_name = '' - state.reduction.reduction_mode = ISISReductionMode.Merged + state.reduction.reduction_mode = ReductionMode.MERGED multi_reduction_type = {"period": False, "event_slice": True, "wavelength_range": False} - output_name, group_output_name = get_output_name(state, ISISReductionMode.Merged, True, + output_name, group_output_name = get_output_name(state, ReductionMode.MERGED, True, multi_reduction_type=multi_reduction_type) self.assertEqual(output_name, '12345_merged_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37') 
@@ -566,12 +563,12 @@ class SANSFunctionsTest(unittest.TestCase): def test_returned_name_for_all_multi_reduction_without_user_specified_name_correct(self): state = self._get_state() state.save.user_specified_output_name = '' - state.reduction.reduction_mode = ISISReductionMode.Merged + state.reduction.reduction_mode = ReductionMode.MERGED multi_reduction_type = {"period": True, "event_slice": True, "wavelength_range": True} - output_name, group_output_name = get_output_name(state, ISISReductionMode.Merged, True, + output_name, group_output_name = get_output_name(state, ReductionMode.MERGED, True, multi_reduction_type=multi_reduction_type) self.assertEqual(output_name, '12345_merged_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37') @@ -580,12 +577,12 @@ class SANSFunctionsTest(unittest.TestCase): def test_returned_name_for_all_multi_reduction_without_user_specified_name_correct_LAB_reduction(self): state = self._get_state() state.save.user_specified_output_name = '' - state.reduction.reduction_mode = ISISReductionMode.LAB + state.reduction.reduction_mode = ReductionMode.LAB multi_reduction_type = {"period": True, "event_slice": True, "wavelength_range": True} - output_name, group_output_name = get_output_name(state, ISISReductionMode.LAB, True, + output_name, group_output_name = get_output_name(state, ReductionMode.LAB, True, multi_reduction_type=multi_reduction_type) self.assertEqual(output_name, '12345_rear_1D_12.0_34.0Phi12.0_56.0_t4.57_T12.37') diff --git a/scripts/test/SANS/common/log_tagger_test.py b/scripts/test/SANS/common/log_tagger_test.py index 9f27f5801c6162938e74280c6788459481e8e52f..1889345740765fbbcae9f0c2e228cbbbfc5c10ff 100644 --- a/scripts/test/SANS/common/log_tagger_test.py +++ b/scripts/test/SANS/common/log_tagger_test.py @@ -5,6 +5,7 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest from mantid.api import AlgorithmManager, FrameworkManager diff 
--git a/scripts/test/SANS/common/xml_parsing_test.py b/scripts/test/SANS/common/xml_parsing_test.py index b14956158c267d1f7b260e0363fc4104c1881cc3..726930933a4bd53dba13518756805940fbedc430 100644 --- a/scripts/test/SANS/common/xml_parsing_test.py +++ b/scripts/test/SANS/common/xml_parsing_test.py @@ -5,8 +5,8 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid from mantid.kernel import DateAndTime from sans.common.file_information import (SANSFileInformationFactory, get_instrument_paths_for_sans_file) diff --git a/scripts/test/SANS/gui_logic/add_runs_presenter_test.py b/scripts/test/SANS/gui_logic/add_runs_presenter_test.py index 144d1f73ae2c59f0eb42ca658fb41d934b8007b7..acabe4954c1f732e08c59e4662ffb08c471787f5 100644 --- a/scripts/test/SANS/gui_logic/add_runs_presenter_test.py +++ b/scripts/test/SANS/gui_logic/add_runs_presenter_test.py @@ -7,19 +7,19 @@ import os import unittest +from assert_called import assert_called +from fake_signal import FakeSignal from mantid.kernel import ConfigService from mantid.py3compat import mock -from sans.gui_logic.models.run_summation import RunSummation from sans.gui_logic.models.run_file import SummableRunFile from sans.gui_logic.models.run_selection import RunSelection +from sans.gui_logic.models.run_summation import RunSummation from sans.gui_logic.models.summation_settings import SummationSettings from sans.gui_logic.presenter.add_runs_presenter import AddRunsPagePresenter, AddRunsFilenameManager from sans.gui_logic.presenter.run_selector_presenter import RunSelectorPresenter from sans.gui_logic.presenter.summation_settings_presenter import SummationSettingsPresenter from ui.sans_isis.add_runs_page import AddRunsPage from ui.sans_isis.sans_data_processor_gui import SANSDataProcessorGui -from fake_signal import FakeSignal -from assert_called import assert_called class 
MockedOutAddRunsFilenameManager(AddRunsFilenameManager): diff --git a/scripts/test/SANS/gui_logic/batch_process_runner_test.py b/scripts/test/SANS/gui_logic/batch_process_runner_test.py index d8cb1f8d2b79b7eac45e9c0bca7faa3d6565a824..7c6e19f041b92909eab5af1982fa60eb66f27651 100644 --- a/scripts/test/SANS/gui_logic/batch_process_runner_test.py +++ b/scripts/test/SANS/gui_logic/batch_process_runner_test.py @@ -6,8 +6,8 @@ # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) -from qtpy.QtCore import QThreadPool import unittest +from qtpy.QtCore import QThreadPool from mantid.py3compat import mock from sans.common.enums import (OutputMode) @@ -50,7 +50,7 @@ class BatchProcessRunnerTest(unittest.TestCase): self.batch_process_runner.process_states(self.states, get_thickness_for_rows_func=mock.MagicMock, get_states_func=get_states_mock, - use_optimizations=False, output_mode=OutputMode.Both, + use_optimizations=False, output_mode=OutputMode.BOTH, plot_results=False, output_graph='') QThreadPool.globalInstance().waitForDone() @@ -69,7 +69,7 @@ class BatchProcessRunnerTest(unittest.TestCase): self.batch_process_runner.process_states(self.states, get_thickness_for_rows_func=mock.MagicMock, get_states_func=get_states_mock, - use_optimizations=False, output_mode=OutputMode.Both, + use_optimizations=False, output_mode=OutputMode.BOTH, plot_results=False, output_graph='') QThreadPool.globalInstance().waitForDone() @@ -94,7 +94,7 @@ class BatchProcessRunnerTest(unittest.TestCase): self.batch_process_runner.process_states(self.states, get_thickness_for_rows_func=mock.MagicMock, get_states_func=get_states_mock, - use_optimizations=False, output_mode=OutputMode.Both, + use_optimizations=False, output_mode=OutputMode.BOTH, plot_results=False, output_graph='') QThreadPool.globalInstance().waitForDone() diff --git a/scripts/test/SANS/gui_logic/beam_centre_model_test.py b/scripts/test/SANS/gui_logic/beam_centre_model_test.py index 
0a5c6f2247914a1d2271e29b2f341e666e942b10..c1956fb50f6473e8fff3fe35b8a3f4e8630a5a3b 100644 --- a/scripts/test/SANS/gui_logic/beam_centre_model_test.py +++ b/scripts/test/SANS/gui_logic/beam_centre_model_test.py @@ -9,10 +9,8 @@ from __future__ import (absolute_import, division, print_function) import unittest from mantid.py3compat import mock +from sans.common.enums import FindDirectionEnum, SANSInstrument, DetectorType from sans.gui_logic.models.beam_centre_model import BeamCentreModel -from sans.common.enums import FindDirectionEnum, SANSInstrument, DetectorType, SANSFacility -from sans.test_helper.test_director import TestDirector -from sans.state.data import get_data_builder from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -75,7 +73,7 @@ class BeamCentreModelTest(unittest.TestCase): x_start=self.beam_centre_model.lab_pos_1, y_start=self.beam_centre_model.lab_pos_2, tolerance=self.beam_centre_model.tolerance, - find_direction=FindDirectionEnum.All, + find_direction=FindDirectionEnum.ALL, reduction_method=True, verbose=False, component=DetectorType.LAB) @@ -93,7 +91,7 @@ class BeamCentreModelTest(unittest.TestCase): x_start=self.result['pos1'], y_start=self.result['pos2'], tolerance=self.beam_centre_model.tolerance, - find_direction=FindDirectionEnum.All, + find_direction=FindDirectionEnum.ALL, reduction_method=True, verbose=False, component=DetectorType.LAB) @@ -103,7 +101,7 @@ class BeamCentreModelTest(unittest.TestCase): x_start=self.beam_centre_model.lab_pos_1, y_start=self.beam_centre_model.lab_pos_2, tolerance=self.beam_centre_model.tolerance, - find_direction=FindDirectionEnum.All, + find_direction=FindDirectionEnum.ALL, reduction_method=False, component=DetectorType.LAB) diff --git a/scripts/test/SANS/gui_logic/beam_centre_presenter_test.py b/scripts/test/SANS/gui_logic/beam_centre_presenter_test.py index 9f2c0a2d7e9fcb54c76ec5ac89f6684144da382b..79c21a37674b8495fa72d5fb7eccf4aca0489261 100644 --- 
a/scripts/test/SANS/gui_logic/beam_centre_presenter_test.py +++ b/scripts/test/SANS/gui_logic/beam_centre_presenter_test.py @@ -9,9 +9,9 @@ from __future__ import (absolute_import, division, print_function) import unittest from mantid.py3compat import mock -from sans.test_helper.mock_objects import create_mock_beam_centre_tab from sans.common.enums import SANSInstrument from sans.gui_logic.presenter.beam_centre_presenter import BeamCentrePresenter +from sans.test_helper.mock_objects import create_mock_beam_centre_tab from sans.test_helper.mock_objects import (create_run_tab_presenter_mock) @@ -85,7 +85,6 @@ class BeamCentrePresenterTest(unittest.TestCase): self.presenter._beam_centre_model.update_lab = True self.presenter._beam_centre_model.update_hab = True - self.presenter.on_processing_finished_centre_finder(result) self.assertEqual(result['pos1'], self.presenter._beam_centre_model.lab_pos_1) self.assertEqual(result['pos2'], self.presenter._beam_centre_model.lab_pos_2) diff --git a/scripts/test/SANS/gui_logic/create_state_test.py b/scripts/test/SANS/gui_logic/create_state_test.py index 5a274ccc769929b79e37faa03759be8c88e5d0c3..4c227422a524a32fc4e39b60ce9f93449aca3f18 100644 --- a/scripts/test/SANS/gui_logic/create_state_test.py +++ b/scripts/test/SANS/gui_logic/create_state_test.py @@ -7,14 +7,14 @@ from __future__ import (absolute_import, division, print_function) import unittest +from qtpy.QtCore import QCoreApplication from mantid.py3compat import mock +from sans.common.enums import (SANSInstrument, SANSFacility, SaveType) from sans.gui_logic.models.create_state import (create_states, create_gui_state_from_userfile) -from sans.common.enums import (SANSInstrument, ISISReductionMode, SANSFacility, SaveType) from sans.gui_logic.models.state_gui_model import StateGuiModel from sans.gui_logic.models.table_model import TableModel, TableIndexModel from sans.state.state import State -from qtpy.QtCore import QCoreApplication class GuiCommonTest(unittest.TestCase): @@ 
-81,7 +81,7 @@ class GuiCommonTest(unittest.TestCase): def test_create_gui_state_from_userfile_adds_save_format_from_gui(self): gui_state = StateGuiModel({}) - gui_state.save_types = [SaveType.NXcanSAS] + gui_state.save_types = [SaveType.NX_CAN_SAS] row_state = create_gui_state_from_userfile('MaskLOQData.txt', gui_state) diff --git a/scripts/test/SANS/gui_logic/diagnostics_page_model_test.py b/scripts/test/SANS/gui_logic/diagnostics_page_model_test.py index db01a2bc890ab9f8640fa96aac3880d4d27e0285..e87b465bf58691bd4ec5549b3c7cd56b6d96afe0 100644 --- a/scripts/test/SANS/gui_logic/diagnostics_page_model_test.py +++ b/scripts/test/SANS/gui_logic/diagnostics_page_model_test.py @@ -5,11 +5,13 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest + from sans.common.enums import SANSFacility from sans.gui_logic.models.diagnostics_page_model import create_state -from sans.test_helper.user_file_test_helper import sample_user_file, create_user_file from sans.gui_logic.models.state_gui_model import StateGuiModel +from sans.test_helper.user_file_test_helper import sample_user_file, create_user_file from sans.user_file.user_file_reader import UserFileReader diff --git a/scripts/test/SANS/gui_logic/diagnostics_page_presenter_test.py b/scripts/test/SANS/gui_logic/diagnostics_page_presenter_test.py index e06af9e7fefe219f72319f1c06af14d683a8ada4..002f6aa61314f9c7d2767ce7bfc6e4f9cf33011f 100644 --- a/scripts/test/SANS/gui_logic/diagnostics_page_presenter_test.py +++ b/scripts/test/SANS/gui_logic/diagnostics_page_presenter_test.py @@ -9,9 +9,9 @@ from __future__ import (absolute_import, division, print_function) import unittest from mantid.py3compat import mock -from sans.test_helper.mock_objects import create_mock_diagnostics_tab from sans.common.enums import SANSInstrument, DetectorType, IntegralEnum, SANSFacility from sans.gui_logic.presenter.diagnostic_presenter import 
DiagnosticsPagePresenter +from sans.test_helper.mock_objects import create_mock_diagnostics_tab from sans.test_helper.mock_objects import (create_run_tab_presenter_mock) diff --git a/scripts/test/SANS/gui_logic/gui_common_test.py b/scripts/test/SANS/gui_logic/gui_common_test.py index 6ce0d45fdb001d5cd5bcd3105ffe341e5425daad..960ae34f27b74afc91d5ca9cc0eb5e0b2bc33033 100644 --- a/scripts/test/SANS/gui_logic/gui_common_test.py +++ b/scripts/test/SANS/gui_logic/gui_common_test.py @@ -9,13 +9,13 @@ from __future__ import (absolute_import, division, print_function) import unittest from mantid.py3compat import mock +from sans.common.enums import (SANSInstrument, ReductionMode) from sans.gui_logic.gui_common import (get_reduction_mode_strings_for_gui, get_reduction_selection, get_string_for_gui_from_reduction_mode, get_batch_file_dir_from_path, add_dir_to_datasearch, remove_dir_from_datasearch, SANSGuiPropertiesHandler, get_reduction_mode_from_gui_selection) -from sans.common.enums import (SANSInstrument, ISISReductionMode) class GuiCommonTest(unittest.TestCase): @@ -44,31 +44,31 @@ class GuiCommonTest(unittest.TestCase): larmor_settings = get_reduction_mode_strings_for_gui(SANSInstrument.LARMOR) self._assert_same(larmor_settings, ["DetectorBench"]) - default_settings = get_reduction_mode_strings_for_gui(SANSInstrument.NoInstrument) + default_settings = get_reduction_mode_strings_for_gui(SANSInstrument.NO_INSTRUMENT) self._assert_same(default_settings, ["LAB", "HAB", "Merged", "All"]) def test_that_gets_correct_reduction_selection(self): sans_settings = get_reduction_selection(SANSInstrument.SANS2D) - self._assert_same_map(sans_settings, {ISISReductionMode.LAB: "rear", ISISReductionMode.HAB: "front", - ISISReductionMode.Merged: "Merged", ISISReductionMode.All: "All"}) + self._assert_same_map(sans_settings, {ReductionMode.LAB: "rear", ReductionMode.HAB: "front", + ReductionMode.MERGED: "Merged", ReductionMode.ALL: "All"}) loq_settings = 
get_reduction_selection(SANSInstrument.LOQ) - self._assert_same_map(loq_settings, {ISISReductionMode.LAB: "main-detector", ISISReductionMode.HAB: "Hab", - ISISReductionMode.Merged: "Merged", ISISReductionMode.All: "All"}) + self._assert_same_map(loq_settings, {ReductionMode.LAB: "main-detector", ReductionMode.HAB: "Hab", + ReductionMode.MERGED: "Merged", ReductionMode.ALL: "All"}) larmor_settings = get_reduction_selection(SANSInstrument.LARMOR) - self._assert_same_map(larmor_settings, {ISISReductionMode.LAB: "DetectorBench"}) + self._assert_same_map(larmor_settings, {ReductionMode.LAB: "DetectorBench"}) - default_settings = get_reduction_selection(SANSInstrument.NoInstrument) - self._assert_same_map(default_settings, {ISISReductionMode.LAB: "LAB", ISISReductionMode.HAB: "HAB", - ISISReductionMode.Merged: "Merged", ISISReductionMode.All: "All"}) + default_settings = get_reduction_selection(SANSInstrument.NO_INSTRUMENT) + self._assert_same_map(default_settings, {ReductionMode.LAB: "LAB", ReductionMode.HAB: "HAB", + ReductionMode.MERGED: "Merged", ReductionMode.ALL: "All"}) def test_that_can_get_reduction_mode_string(self): - self.run_reduction_mode_string_case(SANSInstrument.SANS2D, ISISReductionMode.LAB, "rear") - self.run_reduction_mode_string_case(SANSInstrument.LOQ, ISISReductionMode.HAB, "Hab") - self.run_reduction_mode_string_case(SANSInstrument.LARMOR, ISISReductionMode.LAB, "DetectorBench") - self.run_reduction_mode_string_case(SANSInstrument.NoInstrument, ISISReductionMode.LAB, "LAB") - self.run_reduction_mode_string_case(SANSInstrument.NoInstrument, ISISReductionMode.HAB, "HAB") + self.run_reduction_mode_string_case(SANSInstrument.SANS2D, ReductionMode.LAB, "rear") + self.run_reduction_mode_string_case(SANSInstrument.LOQ, ReductionMode.HAB, "Hab") + self.run_reduction_mode_string_case(SANSInstrument.LARMOR, ReductionMode.LAB, "DetectorBench") + self.run_reduction_mode_string_case(SANSInstrument.NO_INSTRUMENT, ReductionMode.LAB, "LAB") + 
self.run_reduction_mode_string_case(SANSInstrument.NO_INSTRUMENT, ReductionMode.HAB, "HAB") def test_get_reduction_mode_from_gui_selection(self): # Terminology for SANS2D / LOQ / Larmor / Zoom / generic respectively: @@ -84,10 +84,10 @@ class GuiCommonTest(unittest.TestCase): input_str = input_str.upper() self.assertEqual(expected_outcome, get_reduction_mode_from_gui_selection(input_str)) - check_all_match(lab_strings, ISISReductionMode.LAB) - check_all_match(hab_strings, ISISReductionMode.HAB) - check_all_match(merged_strings, ISISReductionMode.Merged) - check_all_match(all_strings, ISISReductionMode.All) + check_all_match(lab_strings, ReductionMode.LAB) + check_all_match(hab_strings, ReductionMode.HAB) + check_all_match(merged_strings, ReductionMode.MERGED) + check_all_match(all_strings, ReductionMode.ALL) def test_that_batch_file_dir_returns_none_if_no_forwardslash(self): a_path = "test_batch_file_path.csv" @@ -155,5 +155,3 @@ class SANSGuiPropertiesHandlerTest(unittest.TestCase): if __name__ == '__main__': unittest.main() - - diff --git a/scripts/test/SANS/gui_logic/gui_state_director_test.py b/scripts/test/SANS/gui_logic/gui_state_director_test.py index cd46531383425eb9858ac3a77b2d1fd66b61beff..ff597c169e38297ee352cfc057924c733be7f146 100644 --- a/scripts/test/SANS/gui_logic/gui_state_director_test.py +++ b/scripts/test/SANS/gui_logic/gui_state_director_test.py @@ -6,16 +6,16 @@ # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) -import unittest import os +import unittest -from sans.gui_logic.presenter.gui_state_director import GuiStateDirector -from sans.gui_logic.models.table_model import (TableModel, TableIndexModel) -from sans.gui_logic.models.state_gui_model import StateGuiModel -from sans.user_file.user_file_reader import UserFileReader from sans.common.enums import SANSFacility +from sans.gui_logic.models.state_gui_model import StateGuiModel +from sans.gui_logic.models.table_model import 
(TableModel, TableIndexModel) +from sans.gui_logic.presenter.gui_state_director import GuiStateDirector from sans.state.state import State from sans.test_helper.user_file_test_helper import create_user_file, sample_user_file +from sans.user_file.user_file_reader import UserFileReader class GuiStateDirectorTest(unittest.TestCase): diff --git a/scripts/test/SANS/gui_logic/masking_table_presenter_test.py b/scripts/test/SANS/gui_logic/masking_table_presenter_test.py index 375ec5293893e65e0751e37b4fbc1b399871cba9..2eceb9d76a7000df888fa435b1b78db88306f21e 100644 --- a/scripts/test/SANS/gui_logic/masking_table_presenter_test.py +++ b/scripts/test/SANS/gui_logic/masking_table_presenter_test.py @@ -10,7 +10,7 @@ import unittest from mantid.py3compat import mock from sans.gui_logic.presenter.masking_table_presenter import (MaskingTablePresenter, masking_information) -from sans.test_helper.mock_objects import (FakeParentPresenter, FakeState, create_mock_masking_table, create_run_tab_presenter_mock) +from sans.test_helper.mock_objects import (FakeState, create_mock_masking_table, create_run_tab_presenter_mock) class MaskingTablePresenterTest(unittest.TestCase): diff --git a/scripts/test/SANS/gui_logic/model_tests/settings_adjustment_model_test.py b/scripts/test/SANS/gui_logic/model_tests/settings_adjustment_model_test.py index 58cd809747c617250494e5c64bd62572f3e0b30e..5205c07e9be56ea4b91911f896c56223147e1e54 100644 --- a/scripts/test/SANS/gui_logic/model_tests/settings_adjustment_model_test.py +++ b/scripts/test/SANS/gui_logic/model_tests/settings_adjustment_model_test.py @@ -12,19 +12,19 @@ class SettingsTransmissionModelTest(unittest.TestCase): return model_under_test def test_monitor_5_reported_for_zoom(self): - user_file = {DetectorId.instrument: [SANSInstrument.ZOOM]} + user_file = {DetectorId.INSTRUMENT: [SANSInstrument.ZOOM]} model_under_test = self.create_model(user_file) self.assertTrue(model_under_test.does_instrument_support_monitor_5()) def 
test_monitor_5_disabled_for_no_inst(self): - user_file = {DetectorId.instrument: [SANSInstrument.NoInstrument]} + user_file = {DetectorId.INSTRUMENT: [SANSInstrument.NO_INSTRUMENT]} model_under_test = self.create_model(user_file) self.assertFalse(model_under_test.does_instrument_support_monitor_5()) def test_monitor_5_disabled_for_sans(self): - user_file = {DetectorId.instrument: [SANSInstrument.SANS2D]} + user_file = {DetectorId.INSTRUMENT: [SANSInstrument.SANS2D]} model_under_test = self.create_model(user_file) self.assertFalse(model_under_test.does_instrument_support_monitor_5()) @@ -116,19 +116,19 @@ class SettingsTransmissionModelTest(unittest.TestCase): def test_transmission_fit_defaults(self): state_gui_model = self.create_model() - self.assertEqual(state_gui_model.transmission_sample_fit_type, FitType.NoFit) - self.assertEqual(state_gui_model.transmission_can_fit_type, FitType.NoFit) + self.assertEqual(state_gui_model.transmission_sample_fit_type, FitType.NO_FIT) + self.assertEqual(state_gui_model.transmission_can_fit_type, FitType.NO_FIT) self.assertEqual(state_gui_model.transmission_sample_polynomial_order, 2) self.assertEqual(state_gui_model.transmission_can_polynomial_order, 2) def test_that_can_set_transmission_fit_options(self): state_gui_model = self.create_model() - state_gui_model.transmission_sample_fit_type = FitType.Logarithmic - state_gui_model.transmission_can_fit_type = FitType.Linear + state_gui_model.transmission_sample_fit_type = FitType.LOGARITHMIC + state_gui_model.transmission_can_fit_type = FitType.LINEAR state_gui_model.transmission_sample_polynomial_order = 2 state_gui_model.transmission_can_polynomial_order = 2 - self.assertEqual(state_gui_model.transmission_sample_fit_type, FitType.Logarithmic) - self.assertEqual(state_gui_model.transmission_can_fit_type, FitType.Linear) + self.assertEqual(state_gui_model.transmission_sample_fit_type, FitType.LOGARITHMIC) + self.assertEqual(state_gui_model.transmission_can_fit_type, FitType.LINEAR) 
self.assertEqual(state_gui_model.transmission_sample_polynomial_order, 2) self.assertEqual(state_gui_model.transmission_can_polynomial_order, 2) diff --git a/scripts/test/SANS/gui_logic/property_manager_service_test.py b/scripts/test/SANS/gui_logic/property_manager_service_test.py index 7c14b0fe9b48e95bced3cf4c32410e94fcd3c002..e1e0e612b0bd1e0541556a00abd7f4a92c839723 100644 --- a/scripts/test/SANS/gui_logic/property_manager_service_test.py +++ b/scripts/test/SANS/gui_logic/property_manager_service_test.py @@ -8,15 +8,13 @@ from __future__ import (absolute_import, division, print_function) import unittest -import mantid -from mantid.kernel import PropertyManagerDataService, PropertyManagerProperty from mantid.api import Algorithm - +from mantid.kernel import PropertyManagerDataService, PropertyManagerProperty +from sans.common.enums import SANSFacility, SANSInstrument from sans.gui_logic.presenter.property_manager_service import PropertyManagerService from sans.state.data import get_data_builder -from sans.test_helper.test_director import TestDirector -from sans.common.enums import SANSFacility, SANSInstrument from sans.test_helper.file_information_mock import SANSFileInformationMock +from sans.test_helper.test_director import TestDirector class FakeAlgorithm(Algorithm): diff --git a/scripts/test/SANS/gui_logic/run_selector_presenter_test.py b/scripts/test/SANS/gui_logic/run_selector_presenter_test.py index e91b7cd520141352d325774a7c6d3574f9547d38..1847d72458dff64338e608ecdd3f7ef4421142c2 100644 --- a/scripts/test/SANS/gui_logic/run_selector_presenter_test.py +++ b/scripts/test/SANS/gui_logic/run_selector_presenter_test.py @@ -6,15 +6,14 @@ # SPDX - License - Identifier: GPL - 3.0 + import unittest +from assert_called import assert_called +from fake_signal import FakeSignal from mantid.py3compat import mock -from sans.gui_logic.presenter.run_selector_presenter import RunSelectorPresenter -from sans.gui_logic.models.run_selection import RunSelection -from 
sans.gui_logic.models.run_finder import SummableRunFinder from sans.gui_logic.models.run_file import SummableRunFile +from sans.gui_logic.models.run_finder import SummableRunFinder +from sans.gui_logic.models.run_selection import RunSelection +from sans.gui_logic.presenter.run_selector_presenter import RunSelectorPresenter from ui.sans_isis.run_selector_widget import RunSelectorWidget -from fake_signal import FakeSignal - -from assert_called import assert_called class RunSelectorPresenterTest(unittest.TestCase): diff --git a/scripts/test/SANS/gui_logic/run_tab_presenter_test.py b/scripts/test/SANS/gui_logic/run_tab_presenter_test.py index 24020f986994969cfe42f820be338622d10f43d0..fda6911df8a2f2ea52272a0bbe03e48c08a1f853 100644 --- a/scripts/test/SANS/gui_logic/run_tab_presenter_test.py +++ b/scripts/test/SANS/gui_logic/run_tab_presenter_test.py @@ -9,53 +9,50 @@ from __future__ import (absolute_import, division, print_function) import unittest -from mantid.kernel import config from mantid.kernel import PropertyManagerDataService +from mantid.kernel import config from mantid.py3compat import mock - -from sans.gui_logic.presenter.run_tab_presenter import RunTabPresenter -from sans.common.enums import (SANSFacility, ReductionDimensionality, SaveType, ISISReductionMode, - RangeStepType, FitType, SANSInstrument, RowState) -from sans.test_helper.user_file_test_helper import (create_user_file, sample_user_file, sample_user_file_gravity_OFF, - sample_user_file_with_instrument) -from sans.test_helper.mock_objects import (create_mock_view) -from sans.test_helper.common import (remove_file) from sans.common.enums import BatchReductionEntry, SANSInstrument +from sans.common.enums import (SANSFacility, ReductionDimensionality, SaveType, ReductionMode, + RangeStepType, FitType, RowState) from sans.gui_logic.models.table_model import TableModel, TableIndexModel +from sans.gui_logic.presenter.run_tab_presenter import RunTabPresenter +from sans.test_helper.common import 
(remove_file) from sans.test_helper.file_information_mock import SANSFileInformationMock - - -BATCH_FILE_TEST_CONTENT_1 = [{BatchReductionEntry.SampleScatter: 1, BatchReductionEntry.SampleTransmission: 2, - BatchReductionEntry.SampleDirect: 3, BatchReductionEntry.Output: 'test_file', - BatchReductionEntry.UserFile: 'user_test_file'}, - {BatchReductionEntry.SampleScatter: 1, BatchReductionEntry.CanScatter: 2, - BatchReductionEntry.Output: 'test_file2'}] - -BATCH_FILE_TEST_CONTENT_2 = [{BatchReductionEntry.SampleScatter: 'SANS2D00022024', - BatchReductionEntry.SampleTransmission: 'SANS2D00022048', - BatchReductionEntry.SampleDirect: 'SANS2D00022048', - BatchReductionEntry.Output: 'test_file'}, - {BatchReductionEntry.SampleScatter: 'SANS2D00022024', - BatchReductionEntry.Output: 'test_file2'}] - -BATCH_FILE_TEST_CONTENT_3 = [{BatchReductionEntry.SampleScatter: 'SANS2D00022024', - BatchReductionEntry.SampleScatterPeriod: '3', - BatchReductionEntry.Output: 'test_file'}] - -BATCH_FILE_TEST_CONTENT_4 = [{BatchReductionEntry.SampleScatter: 'SANS2D00022024', - BatchReductionEntry.SampleTransmission: 'SANS2D00022048', - BatchReductionEntry.SampleDirect: 'SANS2D00022048', - BatchReductionEntry.Output: 'test_file'}, - {BatchReductionEntry.SampleScatter: 'SANS2D00022024', - BatchReductionEntry.Output: 'test_file2'}] - -BATCH_FILE_TEST_CONTENT_5 = [{BatchReductionEntry.SampleScatter: 'SANS2D00022024', - BatchReductionEntry.SampleTransmission: 'SANS2D00022048', - BatchReductionEntry.SampleDirect: 'SANS2D00022048', - BatchReductionEntry.Output: 'test_file', - BatchReductionEntry.SampleThickness: '5', - BatchReductionEntry.SampleHeight: '2', - BatchReductionEntry.SampleWidth: '8'}] +from sans.test_helper.mock_objects import (create_mock_view) +from sans.test_helper.user_file_test_helper import (create_user_file, sample_user_file, sample_user_file_gravity_OFF) + +BATCH_FILE_TEST_CONTENT_1 = [{BatchReductionEntry.SAMPLE_SCATTER: 1, BatchReductionEntry.SAMPLE_TRANSMISSION: 2, + 
BatchReductionEntry.SAMPLE_DIRECT: 3, BatchReductionEntry.OUTPUT: 'test_file', + BatchReductionEntry.USER_FILE: 'user_test_file'}, + {BatchReductionEntry.SAMPLE_SCATTER: 1, BatchReductionEntry.CAN_SCATTER: 2, + BatchReductionEntry.OUTPUT: 'test_file2'}] + +BATCH_FILE_TEST_CONTENT_2 = [{BatchReductionEntry.SAMPLE_SCATTER: 'SANS2D00022024', + BatchReductionEntry.SAMPLE_TRANSMISSION: 'SANS2D00022048', + BatchReductionEntry.SAMPLE_DIRECT: 'SANS2D00022048', + BatchReductionEntry.OUTPUT: 'test_file'}, + {BatchReductionEntry.SAMPLE_SCATTER: 'SANS2D00022024', + BatchReductionEntry.OUTPUT: 'test_file2'}] + +BATCH_FILE_TEST_CONTENT_3 = [{BatchReductionEntry.SAMPLE_SCATTER: 'SANS2D00022024', + BatchReductionEntry.SAMPLE_SCATTER_PERIOD: '3', + BatchReductionEntry.OUTPUT: 'test_file'}] + +BATCH_FILE_TEST_CONTENT_4 = [{BatchReductionEntry.SAMPLE_SCATTER: 'SANS2D00022024', + BatchReductionEntry.SAMPLE_TRANSMISSION: 'SANS2D00022048', + BatchReductionEntry.SAMPLE_DIRECT: 'SANS2D00022048', + BatchReductionEntry.OUTPUT: 'test_file'}, + {BatchReductionEntry.SAMPLE_SCATTER: 'SANS2D00022024', + BatchReductionEntry.OUTPUT: 'test_file2'}] + +BATCH_FILE_TEST_CONTENT_5 = [{BatchReductionEntry.SAMPLE_SCATTER: 'SANS2D00022024', + BatchReductionEntry.SAMPLE_TRANSMISSION: 'SANS2D00022048', + BatchReductionEntry.SAMPLE_DIRECT: 'SANS2D00022048', + BatchReductionEntry.OUTPUT: 'test_file', + BatchReductionEntry.SAMPLE_THICKNESS: '5', + BatchReductionEntry.SAMPLE_HEIGHT: '2', + BatchReductionEntry.SAMPLE_WIDTH: '8'}] def get_non_empty_row_mock(value): @@ -116,17 +113,17 @@ class RunTabPresenterTest(unittest.TestCase): # Assert # Note that the event slices are not set in the user file self.assertFalse(view.event_slices) - self.assertEqual(view.reduction_dimensionality, ReductionDimensionality.OneDim) - self.assertEqual(view.save_types[0], SaveType.NXcanSAS) + self.assertEqual(view.reduction_dimensionality, ReductionDimensionality.ONE_DIM) + self.assertEqual(view.save_types[0], SaveType.NX_CAN_SAS) 
self.assertTrue(view.zero_error_free) self.assertTrue(view.use_optimizations) - self.assertEqual(view.reduction_mode, ISISReductionMode.LAB) + self.assertEqual(view.reduction_mode, ReductionMode.LAB) self.assertEqual(view.merge_scale, 1.) self.assertEqual(view.merge_shift, 0.) self.assertFalse(view.merge_scale_fit) self.assertFalse(view.merge_shift_fit) self.assertEqual(view.event_binning, "7000.0,500.0,60000.0") - self.assertEqual(view.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(view.wavelength_step_type, RangeStepType.LIN) self.assertEqual(view.wavelength_min, 1.5) self.assertEqual(view.wavelength_max, 12.5) self.assertEqual(view.wavelength_step, 0.125) @@ -142,7 +139,7 @@ class RunTabPresenterTest(unittest.TestCase): self.assertEqual(view.transmission_monitor, 4) self.assertEqual(view.transmission_mn_4_shift, -70) self.assertTrue(view.transmission_sample_use_fit) - self.assertEqual(view.transmission_sample_fit_type, FitType.Logarithmic) + self.assertEqual(view.transmission_sample_fit_type, FitType.LOGARITHMIC) self.assertEqual(view.transmission_sample_polynomial_order, 2) self.assertEqual(view.transmission_sample_wavelength_min, 1.5) self.assertEqual(view.transmission_sample_wavelength_max, 12.5) @@ -154,7 +151,7 @@ class RunTabPresenterTest(unittest.TestCase): self.assertEqual(view.q_1d_min_or_rebin_string, "0.001,0.001,0.0126,-0.08,0.2") self.assertEqual(view.q_xy_max, 0.05) self.assertEqual(view.q_xy_step, 0.001) - self.assertEqual(view.q_xy_step_type, RangeStepType.Lin) + self.assertEqual(view.q_xy_step_type, RangeStepType.LIN) self.assertTrue(view.gravity_on_off) self.assertTrue(view.use_q_resolution) self.assertEqual(view.q_resolution_sample_a, 14.) 
@@ -340,7 +337,7 @@ class RunTabPresenterTest(unittest.TestCase): self.assertEqual(state0.slice.start_time, None) self.assertEqual(state0.slice.end_time, None) - self.assertEqual(state0.reduction.reduction_dimensionality, ReductionDimensionality.OneDim) + self.assertEqual(state0.reduction.reduction_dimensionality, ReductionDimensionality.ONE_DIM) self.assertEqual(state0.move.detectors['LAB'].sample_centre_pos1, 0.15544999999999998) # Clean up @@ -890,7 +887,7 @@ class RunTabPresenterTest(unittest.TestCase): presenter.notify_progress(0, [0.0], [1.0]) - self.assertEqual(presenter._table_model.get_table_entry(0).row_state, RowState.Processed) + self.assertEqual(presenter._table_model.get_table_entry(0).row_state, RowState.PROCESSED) self.assertEqual(presenter._table_model.get_table_entry(0).options_column_model.get_options_string(), 'MergeScale=1.0, MergeShift=0.0') @@ -917,7 +914,7 @@ class RunTabPresenterTest(unittest.TestCase): presenter.notify_progress(0, [], []) - self.assertEqual(presenter._table_model.get_table_entry(0).row_state, RowState.Processed) + self.assertEqual(presenter._table_model.get_table_entry(0).row_state, RowState.PROCESSED) self.assertEqual(presenter._table_model.get_table_entry(0).tool_tip, '') def test_that_process_selected_does_nothing_if_no_states_selected(self): @@ -1135,7 +1132,7 @@ class RunTabPresenterTest(unittest.TestCase): presenter = RunTabPresenter(SANSFacility.ISIS) view = mock.MagicMock() - view.save_types = [SaveType.NoType] + view.save_types = [SaveType.NO_TYPE] view.output_mode_memory_radio_button.isChecked = mock.MagicMock(return_value=False) view.output_mode_file_radio_button.isChecked = mock.MagicMock(return_value=True) @@ -1148,7 +1145,7 @@ class RunTabPresenterTest(unittest.TestCase): presenter = RunTabPresenter(SANSFacility.ISIS) view = mock.MagicMock() - view.save_types = [SaveType.NoType] + view.save_types = [SaveType.NO_TYPE] view.output_mode_memory_radio_button.isChecked = mock.MagicMock(return_value=False) 
view.output_mode_file_radio_button.isChecked = mock.MagicMock(return_value=False) @@ -1160,7 +1157,7 @@ class RunTabPresenterTest(unittest.TestCase): def test_that_validate_output_modes_does_not_raise_if_no_file_types_selected_for_memory_mode(self): presenter = RunTabPresenter(SANSFacility.ISIS) view = mock.MagicMock() - view.save_types = [SaveType.NoType] + view.save_types = [SaveType.NO_TYPE] view.output_mode_memory_radio_button.isChecked = mock.Mock(return_value=True) view.output_mode_file_radio_button.isChecked = mock.Mock(return_value=False) @@ -1262,7 +1259,7 @@ class RunTabPresenterTest(unittest.TestCase): def _get_files_and_mock_presenter(self, content, is_multi_period=True, row_user_file_path=""): if row_user_file_path: - content[1].update({BatchReductionEntry.UserFile : row_user_file_path}) + content[1].update({BatchReductionEntry.USER_FILE : row_user_file_path}) batch_parser = mock.MagicMock() batch_parser.parse_batch_file = mock.MagicMock(return_value=content) diff --git a/scripts/test/SANS/gui_logic/sans_data_processor_gui_algorithm_test.py b/scripts/test/SANS/gui_logic/sans_data_processor_gui_algorithm_test.py index 668109130b709efbcba8d6d8e839972b7fa82d25..b035399cdd6c8adf4eb40fcbcdc3da1d5ac9be32 100644 --- a/scripts/test/SANS/gui_logic/sans_data_processor_gui_algorithm_test.py +++ b/scripts/test/SANS/gui_logic/sans_data_processor_gui_algorithm_test.py @@ -8,9 +8,9 @@ from __future__ import (absolute_import, division, print_function) import unittest +from sans.common.enums import (SANSFacility) from sans.gui_logic.sans_data_processor_gui_algorithm import (create_properties, create_option_column_properties, get_gui_algorithm_name, get_white_list, get_black_list) -from sans.common.enums import (SANSFacility) class SANSGuiDataProcessorAlgorithmTest(unittest.TestCase): diff --git a/scripts/test/SANS/gui_logic/settings_diagnostic_presenter_test.py b/scripts/test/SANS/gui_logic/settings_diagnostic_presenter_test.py index 
20df8a072834ee2140a0ff99c7de26b40421826f..46a6854740190e325b68fc5c7d3a22b48a62773a 100644 --- a/scripts/test/SANS/gui_logic/settings_diagnostic_presenter_test.py +++ b/scripts/test/SANS/gui_logic/settings_diagnostic_presenter_test.py @@ -6,16 +6,15 @@ # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) +import json +import os import tempfile import unittest -import os -import json - -import mantid from mantid.py3compat import mock from sans.gui_logic.presenter.settings_diagnostic_presenter import SettingsDiagnosticPresenter -from sans.test_helper.mock_objects import (create_run_tab_presenter_mock, FakeState, create_mock_settings_diagnostic_tab) +from sans.test_helper.mock_objects import (create_run_tab_presenter_mock, FakeState, + create_mock_settings_diagnostic_tab) class SettingsDiagnosticPresenterTest(unittest.TestCase): diff --git a/scripts/test/SANS/gui_logic/state_gui_model_test.py b/scripts/test/SANS/gui_logic/state_gui_model_test.py index 33b274d90f01153172505debe6af1c80fbe6b7ae..df6d63ddad2677144cf3d99dffbca7b24a5c4dea 100644 --- a/scripts/test/SANS/gui_logic/state_gui_model_test.py +++ b/scripts/test/SANS/gui_logic/state_gui_model_test.py @@ -5,11 +5,13 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest + +from sans.common.enums import (ReductionDimensionality, ReductionMode, RangeStepType, SampleShape, SaveType, + SANSInstrument) from sans.gui_logic.models.state_gui_model import StateGuiModel -from sans.user_file.settings_tags import (OtherId, event_binning_string_values, DetectorId, det_fit_range) -from sans.common.enums import (ReductionDimensionality, ISISReductionMode, RangeStepType, SampleShape, SaveType, - FitType, SANSInstrument) +from sans.user_file.settings_tags import (OtherId, event_binning_string_values, DetectorId) from sans.user_file.settings_tags import (det_fit_range) @@ 
-21,7 +23,7 @@ class StateGuiModelTest(unittest.TestCase): # ================================================================================================================== def test_that_default_instrument_is_NoInstrument(self): state_gui_model = StateGuiModel({"test": [1]}) - self.assertEqual(state_gui_model.instrument, SANSInstrument.NoInstrument) + self.assertEqual(state_gui_model.instrument, SANSInstrument.NO_INSTRUMENT) # ------------------------------------------------------------------------------------------------------------------ # Compatibility Mode @@ -52,18 +54,18 @@ class StateGuiModelTest(unittest.TestCase): self.assertTrue(state_gui_model.zero_error_free) def test_that_can_zero_error_free_saving_can_be_changed(self): - state_gui_model = StateGuiModel({OtherId.save_as_zero_error_free: [True]}) + state_gui_model = StateGuiModel({OtherId.SAVE_AS_ZERO_ERROR_FREE: [True]}) state_gui_model.zero_error_free = False self.assertFalse(state_gui_model.zero_error_free) def test_that_default_save_type_is_NXcanSAS(self): state_gui_model = StateGuiModel({"test": [1]}) - self.assertEqual(state_gui_model.save_types, [SaveType.NXcanSAS]) + self.assertEqual(state_gui_model.save_types, [SaveType.NX_CAN_SAS]) def test_that_can_select_multiple_save_types(self): state_gui_model = StateGuiModel({"test": [1]}) - state_gui_model.save_types = [SaveType.RKH, SaveType.NXcanSAS] - self.assertEqual(state_gui_model.save_types, [SaveType.RKH, SaveType.NXcanSAS]) + state_gui_model.save_types = [SaveType.RKH, SaveType.NX_CAN_SAS] + self.assertEqual(state_gui_model.save_types, [SaveType.RKH, SaveType.NX_CAN_SAS]) # ================================================================================================================== # ================================================================================================================== @@ -79,11 +81,11 @@ class StateGuiModelTest(unittest.TestCase): self.assertEqual(state_gui_model.event_slices, "") def 
test_that_slice_event_can_be_retrieved_if_it_exists(self): - state_gui_model = StateGuiModel({OtherId.event_slices: [event_binning_string_values(value="test")]}) + state_gui_model = StateGuiModel({OtherId.EVENT_SLICES: [event_binning_string_values(value="test")]}) self.assertEqual(state_gui_model.event_slices, "test") def test_that_slice_event_can_be_updated(self): - state_gui_model = StateGuiModel({OtherId.event_slices: [event_binning_string_values(value="test")]}) + state_gui_model = StateGuiModel({OtherId.EVENT_SLICES: [event_binning_string_values(value="test")]}) state_gui_model.event_slices = "test2" self.assertEqual(state_gui_model.event_slices, "test2") @@ -92,12 +94,12 @@ class StateGuiModelTest(unittest.TestCase): # ------------------------------------------------------------------------------------------------------------------ def test_that_is_1D_reduction_by_default(self): state_gui_model = StateGuiModel({"test": [1]}) - self.assertEqual(state_gui_model.reduction_dimensionality, ReductionDimensionality.OneDim) + self.assertEqual(state_gui_model.reduction_dimensionality, ReductionDimensionality.ONE_DIM) def test_that_is_set_to_2D_reduction(self): state_gui_model = StateGuiModel({"test": [1]}) - state_gui_model.reduction_dimensionality = ReductionDimensionality.TwoDim - self.assertEqual(state_gui_model.reduction_dimensionality, ReductionDimensionality.TwoDim) + state_gui_model.reduction_dimensionality = ReductionDimensionality.TWO_DIM + self.assertEqual(state_gui_model.reduction_dimensionality, ReductionDimensionality.TWO_DIM) def test_that_raises_when_not_setting_with_reduction_dim_enum(self): def red_dim_wrapper(): @@ -106,10 +108,10 @@ class StateGuiModelTest(unittest.TestCase): self.assertRaises(ValueError, red_dim_wrapper) def test_that_can_update_reduction_dimensionality(self): - state_gui_model = StateGuiModel({OtherId.reduction_dimensionality: [ReductionDimensionality.OneDim]}) - self.assertEqual(state_gui_model.reduction_dimensionality, 
ReductionDimensionality.OneDim) - state_gui_model.reduction_dimensionality = ReductionDimensionality.TwoDim - self.assertEqual(state_gui_model.reduction_dimensionality, ReductionDimensionality.TwoDim) + state_gui_model = StateGuiModel({OtherId.REDUCTION_DIMENSIONALITY: [ReductionDimensionality.ONE_DIM]}) + self.assertEqual(state_gui_model.reduction_dimensionality, ReductionDimensionality.ONE_DIM) + state_gui_model.reduction_dimensionality = ReductionDimensionality.TWO_DIM + self.assertEqual(state_gui_model.reduction_dimensionality, ReductionDimensionality.TWO_DIM) # ------------------------------------------------------------------------------------------------------------------ # Event binning for compatibility mode @@ -128,12 +130,12 @@ class StateGuiModelTest(unittest.TestCase): # ------------------------------------------------------------------------------------------------------------------ def test_that_is_set_to_lab_by_default(self): state_gui_model = StateGuiModel({"test": [1]}) - self.assertEqual(state_gui_model.reduction_mode, ISISReductionMode.LAB) + self.assertEqual(state_gui_model.reduction_mode, ReductionMode.LAB) def test_that_can_be_set_to_something_else(self): state_gui_model = StateGuiModel({"test": [1]}) - state_gui_model.reduction_mode = ISISReductionMode.Merged - self.assertEqual(state_gui_model.reduction_mode, ISISReductionMode.Merged) + state_gui_model.reduction_mode = ReductionMode.MERGED + self.assertEqual(state_gui_model.reduction_mode, ReductionMode.MERGED) def test_that_raises_when_setting_with_wrong_input(self): def red_mode_wrapper(): @@ -142,10 +144,10 @@ class StateGuiModelTest(unittest.TestCase): self.assertRaises(ValueError, red_mode_wrapper) def test_that_can_update_reduction_mode(self): - state_gui_model = StateGuiModel({DetectorId.reduction_mode: [ISISReductionMode.HAB]}) - self.assertEqual(state_gui_model.reduction_mode, ISISReductionMode.HAB) - state_gui_model.reduction_mode = ISISReductionMode.All - 
self.assertEqual(state_gui_model.reduction_mode, ISISReductionMode.All) + state_gui_model = StateGuiModel({DetectorId.REDUCTION_MODE: [ReductionMode.HAB]}) + self.assertEqual(state_gui_model.reduction_mode, ReductionMode.HAB) + state_gui_model.reduction_mode = ReductionMode.ALL + self.assertEqual(state_gui_model.reduction_mode, ReductionMode.ALL) # ------------------------------------------------------------------------------------------------------------------ # Merge range @@ -173,7 +175,7 @@ class StateGuiModelTest(unittest.TestCase): self.assertEqual(state_gui_model.merge_min, 78.9) def test_that_merge_range_set_correctly(self): - state_gui_model = StateGuiModel({DetectorId.merge_range: [det_fit_range(use_fit=True, start=0.13, stop=0.15)]}) + state_gui_model = StateGuiModel({DetectorId.MERGE_RANGE: [det_fit_range(use_fit=True, start=0.13, stop=0.15)]}) self.assertEqual(state_gui_model.merge_min, 0.13) self.assertEqual(state_gui_model.merge_max, 0.15) self.assertTrue(state_gui_model.merge_mask) @@ -191,10 +193,10 @@ class StateGuiModelTest(unittest.TestCase): self.assertEqual(state_gui_model.merge_q_range_stop, "") def test_that_can_set_and_reset_merged_settings(self): - state_gui_model = StateGuiModel({DetectorId.shift_fit: [det_fit_range(start=1., stop=2., use_fit=True)], - DetectorId.rescale_fit: [det_fit_range(start=1.4, stop=7., use_fit=False)], - DetectorId.rescale: [12.], - DetectorId.shift: [234.]}) + state_gui_model = StateGuiModel({DetectorId.SHIFT_FIT: [det_fit_range(start=1., stop=2., use_fit=True)], + DetectorId.RESCALE_FIT: [det_fit_range(start=1.4, stop=7., use_fit=False)], + DetectorId.RESCALE: [12.], + DetectorId.SHIFT: [234.]}) self.assertEqual(state_gui_model.merge_scale, 12.) self.assertEqual(state_gui_model.merge_shift, 234.) 
self.assertFalse(state_gui_model.merge_scale_fit) @@ -227,19 +229,19 @@ class StateGuiModelTest(unittest.TestCase): def test_that_default_wavelength_step_type_is_linear(self): state_gui_model = StateGuiModel({"test": [1]}) - self.assertEqual(state_gui_model.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(state_gui_model.wavelength_step_type, RangeStepType.LIN) def test_that_can_set_wavelength(self): state_gui_model = StateGuiModel({"test": [1]}) state_gui_model.wavelength_min = 1. state_gui_model.wavelength_max = 2. state_gui_model.wavelength_step = .5 - state_gui_model.wavelength_step_type = RangeStepType.Lin - state_gui_model.wavelength_step_type = RangeStepType.Log + state_gui_model.wavelength_step_type = RangeStepType.LIN + state_gui_model.wavelength_step_type = RangeStepType.LOG self.assertEqual(state_gui_model.wavelength_min, 1.) self.assertEqual(state_gui_model.wavelength_max, 2.) self.assertEqual(state_gui_model.wavelength_step, .5) - self.assertEqual(state_gui_model.wavelength_step_type, RangeStepType.Log) + self.assertEqual(state_gui_model.wavelength_step_type, RangeStepType.LOG) # ------------------------------------------------------------------------------------------------------------------ # Scale @@ -270,12 +272,12 @@ class StateGuiModelTest(unittest.TestCase): state_gui_model.sample_height = 1.6 state_gui_model.sample_thickness = 1.8 state_gui_model.z_offset = 1.78 - state_gui_model.sample_shape = SampleShape.FlatPlate + state_gui_model.sample_shape = SampleShape.FLAT_PLATE self.assertEqual(state_gui_model.sample_width, 1.2) self.assertEqual(state_gui_model.sample_height, 1.6) self.assertEqual(state_gui_model.sample_thickness, 1.8) self.assertEqual(state_gui_model.z_offset, 1.78) - self.assertEqual(state_gui_model.sample_shape, SampleShape.FlatPlate) + self.assertEqual(state_gui_model.sample_shape, SampleShape.FLAT_PLATE) # ================================================================================================================== 
# ================================================================================================================== @@ -301,14 +303,14 @@ class StateGuiModelTest(unittest.TestCase): state_gui_model.q_1d_rebin_string = "test" state_gui_model.q_xy_max = 1. state_gui_model.q_xy_step = 122. - state_gui_model.q_xy_step_type = RangeStepType.Log + state_gui_model.q_xy_step_type = RangeStepType.LOG state_gui_model.r_cut = 45. state_gui_model.w_cut = 890. self.assertEqual(state_gui_model.q_1d_rebin_string, "test") self.assertEqual(state_gui_model.q_xy_max, 1.) self.assertEqual(state_gui_model.q_xy_step, 122.) - self.assertEqual(state_gui_model.q_xy_step_type, RangeStepType.Log) + self.assertEqual(state_gui_model.q_xy_step_type, RangeStepType.LOG) self.assertEqual(state_gui_model.r_cut, 45.) self.assertEqual(state_gui_model.w_cut, 890.) diff --git a/scripts/test/SANS/gui_logic/summation_settings_model_test.py b/scripts/test/SANS/gui_logic/summation_settings_model_test.py index 58c929f89be3a2da351fa80fa104062af348680d..d17e813230e9e5cbd1ecb03c66347d00234e890d 100644 --- a/scripts/test/SANS/gui_logic/summation_settings_model_test.py +++ b/scripts/test/SANS/gui_logic/summation_settings_model_test.py @@ -5,7 +5,6 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + import unittest -import sys from sans.common.enums import BinningType from sans.gui_logic.models.summation_settings import SummationSettings @@ -16,7 +15,7 @@ class SummationSettingsTestCase(unittest.TestCase): self.summation_settings = SummationSettings(initial_type) def setUp(self): - self.setUpWithInitialType(BinningType.Custom) + self.setUpWithInitialType(BinningType.CUSTOM) class SummationSettingsOverlayEventWorkspaceTestCase(unittest.TestCase): @@ -46,15 +45,15 @@ class SummationSettingsBinSettingsTest(SummationSettingsTestCase): self.assertEqual(bin_settings, self.summation_settings.bin_settings) def test_custom_binning_has_bin_settings(self): - 
self.setUpWithInitialType(BinningType.Custom) + self.setUpWithInitialType(BinningType.CUSTOM) self.assertHasBinSettings() def test_save_as_event_data_does_not_have_bin_settings(self): - self.setUpWithInitialType(BinningType.SaveAsEventData) + self.setUpWithInitialType(BinningType.SAVE_AS_EVENT_DATA) self.assertDoesNotHaveBinSettings() def test_from_monitors_does_not_have_bin_settings(self): - self.setUpWithInitialType(BinningType.FromMonitors) + self.setUpWithInitialType(BinningType.FROM_MONITORS) self.assertDoesNotHaveBinSettings() @@ -67,11 +66,11 @@ class SummationSettingsAdditionalTimeShiftsTest(SummationSettingsTestCase, \ self.assertFalse(self.summation_settings.has_additional_time_shifts()) def test_custom_binning_does_not_have_additional_time_shifts(self): - self.setUpWithInitialType(BinningType.Custom) + self.setUpWithInitialType(BinningType.CUSTOM) self.assertDoesNotHaveAdditionalTimeShifts() def test_save_as_event_data_has_additional_time_shifts_if_overlay_event_workspaces_enabled(self): - self.setUpWithInitialType(BinningType.SaveAsEventData) + self.setUpWithInitialType(BinningType.SAVE_AS_EVENT_DATA) self.assertDoesNotHaveAdditionalTimeShifts() self.summation_settings.enable_overlay_event_workspaces() self.assertHasAdditionalTimeShifts() @@ -79,11 +78,11 @@ class SummationSettingsAdditionalTimeShiftsTest(SummationSettingsTestCase, \ self.assertDoesNotHaveAdditionalTimeShifts() def test_from_monitors_does_not_have_additional_time_shifts(self): - self.setUpWithInitialType(BinningType.FromMonitors) + self.setUpWithInitialType(BinningType.FROM_MONITORS) self.assertDoesNotHaveAdditionalTimeShifts() def test_can_set_additional_time_shifts_when_available(self): - self.setUpWithInitialType(BinningType.SaveAsEventData) + self.setUpWithInitialType(BinningType.SAVE_AS_EVENT_DATA) self.summation_settings.enable_overlay_event_workspaces() additional_time_shifts = '1,24,545,23' self.summation_settings.additional_time_shifts = additional_time_shifts @@ -92,42 +91,42 
@@ class SummationSettingsAdditionalTimeShiftsTest(SummationSettingsTestCase, \ def test_stores_additional_time_shifts_between_mode_switches(self): bin_settings = '232,2132,123' additional_time_shifts = '32,252,12' - self.setUpWithInitialType(BinningType.Custom) + self.setUpWithInitialType(BinningType.CUSTOM) self.summation_settings.bin_settings = bin_settings - self.summation_settings.set_histogram_binning_type(BinningType.SaveAsEventData) + self.summation_settings.set_histogram_binning_type(BinningType.SAVE_AS_EVENT_DATA) self.summation_settings.additional_time_shifts = additional_time_shifts - self.summation_settings.set_histogram_binning_type(BinningType.Custom) + self.summation_settings.set_histogram_binning_type(BinningType.CUSTOM) self.assertEqual(bin_settings, self.summation_settings.bin_settings) - self.summation_settings.set_histogram_binning_type(BinningType.SaveAsEventData) + self.summation_settings.set_histogram_binning_type(BinningType.SAVE_AS_EVENT_DATA) self.assertEqual(additional_time_shifts, self.summation_settings.additional_time_shifts) class SummationSettingsOverlayEventWorkspace(SummationSettingsTestCase, \ SummationSettingsOverlayEventWorkspaceTestCase): def test_custom_binning_does_not_have_overlay_event_workspaces(self): - self.setUpWithInitialType(BinningType.Custom) + self.setUpWithInitialType(BinningType.CUSTOM) self.assertDoesNotHaveOverlayEventWorkspaces() def test_save_as_event_data_has_overlay_event_workspaces(self): - self.setUpWithInitialType(BinningType.SaveAsEventData) + self.setUpWithInitialType(BinningType.SAVE_AS_EVENT_DATA) self.assertHasOverlayEventWorkspaces() def test_from_monitors_does_not_have_overlay_event_workspaces(self): - self.setUpWithInitialType(BinningType.FromMonitors) + self.setUpWithInitialType(BinningType.FROM_MONITORS) self.assertDoesNotHaveOverlayEventWorkspaces() def test_switching_to_save_as_event_data_enables_overlay_event_workspaces_option(self): - self.setUpWithInitialType(BinningType.FromMonitors) - 
self.summation_settings.set_histogram_binning_type(BinningType.SaveAsEventData) + self.setUpWithInitialType(BinningType.FROM_MONITORS) + self.summation_settings.set_histogram_binning_type(BinningType.SAVE_AS_EVENT_DATA) self.assertHasOverlayEventWorkspaces() def test_can_enable_overlay_event_workspaces_when_available(self): - self.setUpWithInitialType(BinningType.SaveAsEventData) + self.setUpWithInitialType(BinningType.SAVE_AS_EVENT_DATA) self.summation_settings.enable_overlay_event_workspaces() self.assertOverlayEventWorkspacesEnabled() def test_can_disable_overlay_event_workspaces_when_available(self): - self.setUpWithInitialType(BinningType.SaveAsEventData) + self.setUpWithInitialType(BinningType.SAVE_AS_EVENT_DATA) self.summation_settings.enable_overlay_event_workspaces() self.summation_settings.disable_overlay_event_workspaces() self.assertOverlayEventWorkspacesDisabled() diff --git a/scripts/test/SANS/gui_logic/summation_settings_presenter_test.py b/scripts/test/SANS/gui_logic/summation_settings_presenter_test.py index 1f849e1612e1c57be932b0d0c073f8e6ca74b47d..181afa58b307bb0e37d9cb047db83d7db0508de9 100644 --- a/scripts/test/SANS/gui_logic/summation_settings_presenter_test.py +++ b/scripts/test/SANS/gui_logic/summation_settings_presenter_test.py @@ -6,14 +6,13 @@ # SPDX - License - Identifier: GPL - 3.0 + import unittest +from assert_called import assert_called +from fake_signal import FakeSignal from mantid.py3compat import mock from sans.common.enums import BinningType -from sans.gui_logic.presenter.summation_settings_presenter import SummationSettingsPresenter from sans.gui_logic.models.summation_settings import SummationSettings +from sans.gui_logic.presenter.summation_settings_presenter import SummationSettingsPresenter from ui.sans_isis.summation_settings_widget import SummationSettingsWidget -from fake_signal import FakeSignal - -from assert_called import assert_called class SummationSettingsPresenterTest(unittest.TestCase): @@ -39,7 +38,7 @@ class 
SummationSettingsPresenterTest(unittest.TestCase): def test_sets_binning_type_when_changed(self): new_binning_type = 0 self.view.binningTypeChanged.emit(new_binning_type) - self.summation_settings.set_histogram_binning_type.assert_called_with(BinningType.Custom) + self.summation_settings.set_histogram_binning_type.assert_called_with(BinningType.CUSTOM) def test_retrieves_additional_time_shifts_when_changed(self): self.view.additionalTimeShiftsChanged.emit() diff --git a/scripts/test/SANS/gui_logic/table_model_test.py b/scripts/test/SANS/gui_logic/table_model_test.py index 6d1bfb7af700e6e4c0ee8acbf740b5b918cc182f..d0240662c8cddd6405317bb62f68c529146ce239 100644 --- a/scripts/test/SANS/gui_logic/table_model_test.py +++ b/scripts/test/SANS/gui_logic/table_model_test.py @@ -60,7 +60,7 @@ class TableModelTest(unittest.TestCase): sample_shape_enum = table_index_model.sample_shape sample_shape_text = table_index_model.sample_shape_string - self.assertEqual(sample_shape_enum, SampleShape.FlatPlate) + self.assertEqual(sample_shape_enum, SampleShape.FLAT_PLATE) self.assertEqual(sample_shape_text, "FlatPlate") def test_that_sample_shape_can_be_set_as_enum(self): @@ -68,21 +68,21 @@ class TableModelTest(unittest.TestCase): # So SampleShapeColumnModel must be able to parse this. 
table_index_model = TableIndexModel('0', "", "", "", "", "", "", "", "", "", "", "", "", "", "", - sample_shape=SampleShape.FlatPlate) + sample_shape=SampleShape.FLAT_PLATE) sample_shape_enum = table_index_model.sample_shape sample_shape_text = table_index_model.sample_shape_string - self.assertEqual(sample_shape_enum, SampleShape.FlatPlate) + self.assertEqual(sample_shape_enum, SampleShape.FLAT_PLATE) self.assertEqual(sample_shape_text, "FlatPlate") - def test_that_incorrect_sample_shape_reverts_to_previous_sampleshape(self): + def test_that_incorrect_sample_shape_turns_to_not_set(self): table_index_model = TableIndexModel('0', "", "", "", "", "", "", "", "", "", "", "", "", "", "", sample_shape="Disc") table_index_model.sample_shape = "not a sample shape" - self.assertEqual("Disc", table_index_model.sample_shape_string) + self.assertEqual("", table_index_model.sample_shape_string) - def test_that_empty_string_is_acceptable_sample_shape(self): + def test_that_empty_string_turns_to_not_set(self): table_index_model = TableIndexModel('0', "", "", "", "", "", "", "", "", "", "", "", "", "", "", sample_shape="Disc") @@ -91,8 +91,8 @@ class TableModelTest(unittest.TestCase): sample_shape_enum = table_index_model.sample_shape sample_shape_text = table_index_model.sample_shape_string - self.assertEqual(sample_shape_enum, "") - self.assertEqual(sample_shape_text, "") + self.assertEqual(SampleShape.NOT_SET, sample_shape_enum) + self.assertEqual('', sample_shape_text) def test_that_table_model_completes_partial_sample_shape(self): table_index_model = TableIndexModel('0', "", "", "", "", "", "", @@ -102,7 +102,7 @@ class TableModelTest(unittest.TestCase): sample_shape_enum = table_index_model.sample_shape sample_shape_text = table_index_model.sample_shape_string - self.assertEqual(sample_shape_enum, SampleShape.Cylinder) + self.assertEqual(sample_shape_enum, SampleShape.CYLINDER) self.assertEqual(sample_shape_text, "Cylinder") def 
test_that_querying_nonexistent_row_index_raises_IndexError_exception(self): @@ -311,7 +311,7 @@ class TableModelTest(unittest.TestCase): table_index_model = TableIndexModel(0, "", "", "", "", "", "", "", "", "", "", "", "") - self.assertEqual(table_index_model.row_state, RowState.Unprocessed) + self.assertEqual(table_index_model.row_state, RowState.UNPROCESSED) self.assertEqual(table_index_model.tool_tip, '') def test_that_set_processed_sets_state_to_processed(self): @@ -324,7 +324,7 @@ class TableModelTest(unittest.TestCase): table_model.set_row_to_processed(row, tool_tip) - self.assertEqual(table_index_model.row_state, RowState.Processed) + self.assertEqual(table_index_model.row_state, RowState.PROCESSED) self.assertEqual(table_index_model.tool_tip, tool_tip) def test_that_reset_row_state_sets_row_to_unproceesed_and_sets_tool_tip_to_empty(self): @@ -338,7 +338,7 @@ class TableModelTest(unittest.TestCase): table_model.reset_row_state(row) - self.assertEqual(table_index_model.row_state, RowState.Unprocessed) + self.assertEqual(table_index_model.row_state, RowState.UNPROCESSED) self.assertEqual(table_index_model.tool_tip, '') def test_that_set_row_to_error_sets_row_to_error_and_tool_tip(self): @@ -351,7 +351,7 @@ class TableModelTest(unittest.TestCase): table_model.set_row_to_error(row, tool_tip) - self.assertEqual(table_index_model.row_state, RowState.Error) + self.assertEqual(table_index_model.row_state, RowState.ERROR) self.assertEqual(table_index_model.tool_tip, tool_tip) def test_serialise_options_dict_correctly(self): diff --git a/scripts/test/SANS/state/adjustment_test.py b/scripts/test/SANS/state/adjustment_test.py index 5f9d1ff98bae5cb8cdacf166bbe70c411c21c15e..2423f9a4e18b6fdcf11568c2d28d71ee5c9a8d48 100644 --- a/scripts/test/SANS/state/adjustment_test.py +++ b/scripts/test/SANS/state/adjustment_test.py @@ -5,15 +5,15 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, 
print_function) + import unittest -import mantid +from sans.common.enums import (SANSFacility, SANSInstrument) from sans.state.adjustment import (StateAdjustment, get_adjustment_builder) -from sans.state.data import (get_data_builder) from sans.state.calculate_transmission import StateCalculateTransmission +from sans.state.data import (get_data_builder) from sans.state.normalize_to_monitor import StateNormalizeToMonitor from sans.state.wavelength_and_pixel_adjustment import StateWavelengthAndPixelAdjustment -from sans.common.enums import (SANSFacility, SANSInstrument, FitType) from sans.test_helper.file_information_mock import SANSFileInformationMock diff --git a/scripts/test/SANS/state/calculate_transmission_test.py b/scripts/test/SANS/state/calculate_transmission_test.py index 0c2fb38bdaf88ffe17a18ddefb8b7634efe9c397..e3a26dade8cd3b604bbc84b8cdd144b71bfe82ef 100644 --- a/scripts/test/SANS/state/calculate_transmission_test.py +++ b/scripts/test/SANS/state/calculate_transmission_test.py @@ -8,8 +8,6 @@ from __future__ import (absolute_import, division, print_function) import unittest -from state_test_helper import assert_validate_error, assert_raises_nothing - from sans.common.enums import (RebinType, RangeStepType, FitType, DataType, SANSFacility, SANSInstrument) from sans.state.calculate_transmission import (StateCalculateTransmission, StateCalculateTransmissionLOQ, get_calculate_transmission_builder) @@ -23,13 +21,12 @@ from sans.test_helper.file_information_mock import SANSFileInformationMock class StateCalculateTransmissionTest(unittest.TestCase): @staticmethod def _set_fit(state, default_settings, custom_settings, fit_key): - fit = state.fit[fit_key] for key, value in list(default_settings.items()): if key in custom_settings: value = custom_settings[key] + if value is not None: # If the value is None, then don't set it - setattr(fit, key, value) - state.fit[fit_key] = fit + setattr(state.fit[fit_key], key, value) @staticmethod def 
_get_calculate_transmission_state(trans_entries, fit_entries): @@ -40,8 +37,8 @@ class StateCalculateTransmissionTest(unittest.TestCase): "transmission_mask_files": ["test.xml"], "default_transmission_monitor": 3, "transmission_monitor": 4, "default_incident_monitor": 1, "incident_monitor": 2, "prompt_peak_correction_min": 123., "prompt_peak_correction_max": 1234., - "rebin_type": RebinType.Rebin, "wavelength_low": [1.], "wavelength_high": [2.7], - "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.Lin, + "rebin_type": RebinType.REBIN, "wavelength_low": [1.], "wavelength_high": [2.7], + "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.LIN, "use_full_wavelength_range": True, "wavelength_full_range_low": 12., "wavelength_full_range_high": 434., "background_TOF_general_start": 1.4, "background_TOF_general_stop": 24.5, "background_TOF_monitor_start": {"1": 123, "2": 123}, @@ -54,14 +51,12 @@ class StateCalculateTransmissionTest(unittest.TestCase): if value is not None: # If the value is None, then don't set it setattr(state, key, value) - fit_settings = {"fit_type": FitType.Polynomial, "polynomial_order": 1, "wavelength_low": 12., + fit_settings = {"fit_type": FitType.POLYNOMIAL, "polynomial_order": 1, "wavelength_low": 12., "wavelength_high": 232.} if fit_entries is None: fit_entries = {} - StateCalculateTransmissionTest._set_fit(state, fit_settings, fit_entries, - DataType.to_string(DataType.Sample)) - StateCalculateTransmissionTest._set_fit(state, fit_settings, fit_entries, - DataType.to_string(DataType.Can)) + StateCalculateTransmissionTest._set_fit(state, fit_settings, fit_entries, DataType.SAMPLE) + StateCalculateTransmissionTest._set_fit(state, fit_settings, fit_entries, DataType.CAN) return state @staticmethod @@ -74,11 +69,12 @@ class StateCalculateTransmissionTest(unittest.TestCase): def check_bad_and_good_values(self, bad_trans=None, bad_fit=None, good_trans=None, good_fit=None): # Bad values state = 
self._get_calculate_transmission_state(bad_trans, bad_fit) - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() # Good values state = self._get_calculate_transmission_state(good_trans, good_fit) - assert_raises_nothing(self, state) + self.assertIsNone(state.validate()) def test_that_is_sans_state_data_object(self): state = StateCalculateTransmissionLOQ() @@ -114,13 +110,13 @@ class StateCalculateTransmissionTest(unittest.TestCase): self.check_bad_and_good_values(bad_trans={"wavelength_low": [1.], "wavelength_high": [2.], "wavelength_step": 0.5, "wavelength_step_type": None}, good_trans={"wavelength_low": [1.], "wavelength_high": [2.], - "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.Lin}) + "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.LIN}) def test_that_raises_for_lower_bound_larger_than_upper_bound_for_wavelength(self): self.check_bad_and_good_values(bad_trans={"wavelength_low": [2.], "wavelength_high": [1.], - "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.Lin}, + "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.LIN}, good_trans={"wavelength_low": [1.], "wavelength_high": [2.], - "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.Lin}) + "wavelength_step": 0.5, "wavelength_step_type": RangeStepType.LIN}) def test_that_raises_for_missing_full_wavelength_entry(self): self.check_bad_and_good_values(bad_trans={"use_full_wavelength_range": True, "wavelength_full_range_low": None, @@ -183,8 +179,8 @@ class StateCalculateTransmissionTest(unittest.TestCase): "background_TOF_monitor_stop": {"1": 2., "2": 2.}}) def test_that_polynomial_order_can_only_be_set_with_polynomial_setting(self): - self.check_bad_and_good_values(bad_fit={"fit_type": FitType.Polynomial, "polynomial_order": 0}, - good_fit={"fit_type": FitType.Polynomial, "polynomial_order": 4}) + self.check_bad_and_good_values(bad_fit={"fit_type": FitType.POLYNOMIAL, "polynomial_order": 0}, + 
good_fit={"fit_type": FitType.POLYNOMIAL, "polynomial_order": 4}) def test_that_raises_for_inconsistent_wavelength_in_fit(self): self.check_bad_and_good_values(bad_trans={"wavelength_low": None, "wavelength_high": [2.]}, @@ -222,11 +218,11 @@ class StateCalculateTransmissionBuilderTest(unittest.TestCase): builder.set_transmission_roi_files(["sdfs", "sddfsdf"]) builder.set_transmission_mask_files(["sdfs", "bbbbbb"]) - builder.set_rebin_type(RebinType.Rebin) + builder.set_rebin_type(RebinType.REBIN) builder.set_wavelength_low([1.5]) builder.set_wavelength_high([2.7]) builder.set_wavelength_step(0.5) - builder.set_wavelength_step_type(RangeStepType.Lin) + builder.set_wavelength_step_type(RangeStepType.LIN) builder.set_use_full_wavelength_range(True) builder.set_wavelength_full_range_low(12.) builder.set_wavelength_full_range_high(24.) @@ -238,15 +234,15 @@ class StateCalculateTransmissionBuilderTest(unittest.TestCase): builder.set_background_TOF_roi_start(1.4) builder.set_background_TOF_roi_stop(34.4) - builder.set_Sample_fit_type(FitType.Linear) - builder.set_Sample_polynomial_order(0) - builder.set_Sample_wavelength_low(10.0) - builder.set_Sample_wavelength_high(20.0) + builder.set_sample_fit_type(FitType.LINEAR) + builder.set_sample_polynomial_order(0) + builder.set_sample_wavelength_low(10.0) + builder.set_sample_wavelength_high(20.0) - builder.set_Can_fit_type(FitType.Polynomial) - builder.set_Can_polynomial_order(3) - builder.set_Can_wavelength_low(10.0) - builder.set_Can_wavelength_high(20.0) + builder.set_can_fit_type(FitType.POLYNOMIAL) + builder.set_can_polynomial_order(3) + builder.set_can_wavelength_low(10.0) + builder.set_can_wavelength_high(20.0) state = builder.build() @@ -262,11 +258,11 @@ class StateCalculateTransmissionBuilderTest(unittest.TestCase): self.assertEqual(state.transmission_roi_files, ["sdfs", "sddfsdf"]) self.assertEqual(state.transmission_mask_files, ["sdfs", "bbbbbb"]) - self.assertEqual(state.rebin_type, RebinType.Rebin) + 
self.assertEqual(state.rebin_type, RebinType.REBIN) self.assertEqual(state.wavelength_low, [1.5]) self.assertEqual(state.wavelength_high, [2.7]) self.assertEqual(state.wavelength_step, 0.5) - self.assertEqual(state.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(state.wavelength_step_type, RangeStepType.LIN) self.assertEqual(state.use_full_wavelength_range, True) self.assertEqual(state.wavelength_full_range_low, 12.) self.assertEqual(state.wavelength_full_range_high, 24.) @@ -278,15 +274,15 @@ class StateCalculateTransmissionBuilderTest(unittest.TestCase): self.assertEqual(state.background_TOF_roi_start, 1.4) self.assertEqual(state.background_TOF_roi_stop, 34.4) - self.assertEqual(state.fit[DataType.to_string(DataType.Sample)].fit_type, FitType.Linear) - self.assertEqual(state.fit[DataType.to_string(DataType.Sample)].polynomial_order, 0) - self.assertEqual(state.fit[DataType.to_string(DataType.Sample)].wavelength_low, 10.) - self.assertEqual(state.fit[DataType.to_string(DataType.Sample)].wavelength_high, 20.) + self.assertEqual(state.fit[DataType.SAMPLE].fit_type, FitType.LINEAR) + self.assertEqual(state.fit[DataType.SAMPLE].polynomial_order, 0) + self.assertEqual(state.fit[DataType.SAMPLE].wavelength_low, 10.) + self.assertEqual(state.fit[DataType.SAMPLE].wavelength_high, 20.) - self.assertEqual(state.fit[DataType.to_string(DataType.Can)].fit_type, FitType.Polynomial) - self.assertEqual(state.fit[DataType.to_string(DataType.Can)].polynomial_order, 3) - self.assertEqual(state.fit[DataType.to_string(DataType.Can)].wavelength_low, 10.) - self.assertEqual(state.fit[DataType.to_string(DataType.Can)].wavelength_high, 20.) + self.assertEqual(state.fit[DataType.CAN].fit_type, FitType.POLYNOMIAL) + self.assertEqual(state.fit[DataType.CAN].polynomial_order, 3) + self.assertEqual(state.fit[DataType.CAN].wavelength_low, 10.) + self.assertEqual(state.fit[DataType.CAN].wavelength_high, 20.) 
if __name__ == '__main__': diff --git a/scripts/test/SANS/state/convert_to_q_test.py b/scripts/test/SANS/state/convert_to_q_test.py index 0d0d85a402ae14b53c48b07802d929a32ffc35f3..6b5d744849bfcef002c50b55462fc09f9c255a48 100644 --- a/scripts/test/SANS/state/convert_to_q_test.py +++ b/scripts/test/SANS/state/convert_to_q_test.py @@ -5,13 +5,12 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid +from sans.common.enums import (RangeStepType, ReductionDimensionality, SANSFacility, SANSInstrument) from sans.state.convert_to_q import (StateConvertToQ, get_convert_to_q_builder) from sans.state.data import get_data_builder -from sans.common.enums import (RangeStepType, ReductionDimensionality, SANSFacility, SANSInstrument) -from state_test_helper import (assert_validate_error, assert_raises_nothing) from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -22,11 +21,11 @@ class StateConvertToQTest(unittest.TestCase): @staticmethod def _get_convert_to_q_state(convert_to_q_entries): state = StateConvertToQ() - default_entries = {"reduction_dimensionality": ReductionDimensionality.OneDim, "use_gravity": True, + default_entries = {"reduction_dimensionality": ReductionDimensionality.ONE_DIM, "use_gravity": True, "gravity_extra_length": 12., "radius_cutoff": 1.5, "wavelength_cutoff": 2.7, "q_min": 0.5, "q_max": 1., "q_1d_rebin_string": "0.5,0.2,1.", - "q_step2": 1., "q_step_type2": RangeStepType.Lin, "q_mid": 1., - "q_xy_max": 1.4, "q_xy_step": 24.5, "q_xy_step_type": RangeStepType.Lin, + "q_step2": 1., "q_step_type2": RangeStepType.LIN, "q_mid": 1., + "q_xy_max": 1.4, "q_xy_step": 24.5, "q_xy_step_type": RangeStepType.LIN, "use_q_resolution": True, "q_resolution_collimation_length": 12., "q_resolution_delta_r": 12., "moderator_file": "test.txt", "q_resolution_a1": 1., "q_resolution_a2": 2., "q_resolution_h1": 1., "q_resolution_h2": 2., 
"q_resolution_w1": 1., @@ -42,11 +41,12 @@ class StateConvertToQTest(unittest.TestCase): def check_bad_and_good_value(self, bad_convert_to_q, good_convert_to_q): # Bad values state = self._get_convert_to_q_state(bad_convert_to_q) - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() # Good values state = self._get_convert_to_q_state(good_convert_to_q) - assert_raises_nothing(self, state) + self.assertIsNone(state.validate()) def test_that_raises_with_inconsistent_1D_q_values(self): self.check_bad_and_good_value({"q_min": None, "q_max": 2.}, {"q_min": 1., "q_max": 2.}) @@ -56,9 +56,9 @@ class StateConvertToQTest(unittest.TestCase): def test_that_raises_when_no_q_bounds_are_set_for_explicit_1D_reduction(self): self.check_bad_and_good_value({"q_min": None, "q_max": None, - "reduction_dimensionality": ReductionDimensionality.OneDim}, + "reduction_dimensionality": ReductionDimensionality.ONE_DIM}, {"q_min": 1., "q_max": 2., - "reduction_dimensionality": ReductionDimensionality.OneDim}) + "reduction_dimensionality": ReductionDimensionality.ONE_DIM}) def test_that_raises_when_q_rebin_string_is_invalid(self): self.check_bad_and_good_value({"q_1d_rebin_string": ""}, {"q_1d_rebin_string": "1.0,2.0"}) @@ -68,9 +68,9 @@ class StateConvertToQTest(unittest.TestCase): def test_that_raises_when_no_q_bounds_are_set_for_explicit_2D_reduction(self): self.check_bad_and_good_value({"q_xy_max": None, "q_xy_step": None, - "reduction_dimensionality": ReductionDimensionality.TwoDim}, + "reduction_dimensionality": ReductionDimensionality.TWO_DIM}, {"q_xy_max": 1., "q_xy_step": 2., - "reduction_dimensionality": ReductionDimensionality.TwoDim}) + "reduction_dimensionality": ReductionDimensionality.TWO_DIM}) def test_that_raises_when_inconsistent_circular_values_for_q_resolution_are_specified(self): self.check_bad_and_good_value({"use_q_resolution": True, "q_resolution_a1": None, @@ -115,7 +115,7 @@ class 
StateConvertToQBuilderTest(unittest.TestCase): builder.set_q_min(12.0) builder.set_q_max(17.0) builder.set_q_1d_rebin_string("12.0,-1.2,17.0") - builder.set_reduction_dimensionality(ReductionDimensionality.OneDim) + builder.set_reduction_dimensionality(ReductionDimensionality.ONE_DIM) state = builder.build() @@ -123,7 +123,7 @@ class StateConvertToQBuilderTest(unittest.TestCase): self.assertEqual(state.q_min, 12.0) self.assertEqual(state.q_max, 17.0) self.assertEqual(state.q_1d_rebin_string, "12.0,-1.2,17.0") - self.assertEqual(state.reduction_dimensionality, ReductionDimensionality.OneDim) + self.assertEqual(state.reduction_dimensionality, ReductionDimensionality.ONE_DIM) if __name__ == '__main__': diff --git a/scripts/test/SANS/state/data_test.py b/scripts/test/SANS/state/data_test.py index 51ea9fda5fbe559ae8751ec9ff37b619264bbc12..449583d626365c6388be38f195afd7ccb75b84f1 100644 --- a/scripts/test/SANS/state/data_test.py +++ b/scripts/test/SANS/state/data_test.py @@ -8,8 +8,6 @@ from __future__ import (absolute_import, division, print_function) import unittest -from state_test_helper import (assert_validate_error, assert_raises_nothing) - from sans.common.enums import (SANSFacility, SANSInstrument) from sans.state.data import (StateData, get_data_builder) from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -37,11 +35,12 @@ class StateDataTest(unittest.TestCase): data_entries_good): # Bad values state = StateDataTest._get_data_state(**data_entries_bad) - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() # Good values state = StateDataTest._get_data_state(**data_entries_good) - assert_raises_nothing(self, state) + self.assertIsNone(state.validate()) def test_that_raises_when_sample_scatter_is_missing(self): self.assert_raises_for_bad_value_and_raises_nothing_for_good_value({"sample_scatter": None}, diff --git a/scripts/test/SANS/state/mask_test.py b/scripts/test/SANS/state/mask_test.py 
index cf83a011225b5d2c083ab84626b8f53d518d6718..5df51fec9bbc58bb1d36919cf3f1f7507b148597 100644 --- a/scripts/test/SANS/state/mask_test.py +++ b/scripts/test/SANS/state/mask_test.py @@ -5,17 +5,14 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid -from sans.state.mask import (StateMaskSANS2D, get_mask_builder) +from sans.common.enums import (SANSFacility, DetectorType) from sans.state.data import get_data_builder -from sans.common.enums import (SANSFacility, SANSInstrument, DetectorType) -from state_test_helper import (assert_validate_error, assert_raises_nothing) +from sans.state.mask import (StateMaskSANS2D, get_mask_builder) from sans.test_helper.file_information_mock import SANSFileInformationMock - - # ---------------------------------------------------------------------------------------------------------------------- # State # ---------------------------------------------------------------------------------------------------------------------- @@ -64,9 +61,9 @@ class StateMaskTest(unittest.TestCase): "spectrum_range_start": [1, 5, 7], "spectrum_range_stop": [2, 6, 8]} StateMaskTest._set_detector(state, detector_settings, detector_entries, - DetectorType.to_string(DetectorType.LAB)) + DetectorType.LAB.value) StateMaskTest._set_detector(state, detector_settings, detector_entries, - DetectorType.to_string(DetectorType.HAB)) + DetectorType.HAB.value) return state @@ -87,13 +84,14 @@ class StateMaskTest(unittest.TestCase): bad_value_detector_dict = StateMaskTest._get_dict(entry_name, bad_value_detector) state = self._get_mask_state(bad_value_general_dict, bad_value_detector_dict) - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() # Good values good_value_general_dict = StateMaskTest._get_dict(entry_name, good_value_general) good_value_detector_dict = StateMaskTest._get_dict(entry_name, 
good_value_detector) state = self._get_mask_state(good_value_general_dict, good_value_detector_dict) - assert_raises_nothing(self, state) + self.assertIsNone(state.validate()) def test_that_raises_when_lower_radius_bound_larger_than_upper_bound(self): self.assert_raises_for_bad_value_and_raises_nothing_for_good_value("radius_min", 500., None, 12., None) @@ -226,7 +224,7 @@ class StateMaskBuilderTest(unittest.TestCase): self.assertEqual(state.bin_mask_general_stop[0], end_time[0]) self.assertEqual(state.bin_mask_general_stop[1], end_time[1]) - strip_mask = state.detectors[DetectorType.to_string(DetectorType.LAB)].single_vertical_strip_mask + strip_mask = state.detectors[DetectorType.LAB.value].single_vertical_strip_mask self.assertEqual(len(strip_mask), 3) self.assertEqual(strip_mask[2], 3) diff --git a/scripts/test/SANS/state/move_test.py b/scripts/test/SANS/state/move_test.py index 9dd86685a84bfe03b879d81fd75ecd336e3f3865..1b8e05176ed744a808197f91df11e45bc5321590 100644 --- a/scripts/test/SANS/state/move_test.py +++ b/scripts/test/SANS/state/move_test.py @@ -5,14 +5,13 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid +from sans.common.enums import (CanonicalCoordinates, SANSFacility, DetectorType, SANSInstrument) +from sans.state.data import get_data_builder from sans.state.move import (StateMoveLOQ, StateMoveSANS2D, StateMoveLARMOR, StateMoveZOOM, StateMove, StateMoveDetector, get_move_builder) -from sans.state.data import get_data_builder -from sans.common.enums import (CanonicalCoordinates, SANSFacility, DetectorType, SANSInstrument) -from state_test_helper import assert_validate_error, assert_raises_nothing from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -22,43 +21,47 @@ from sans.test_helper.file_information_mock import SANSFileInformationMock class StateMoveWorkspaceTest(unittest.TestCase): def 
test_that_raises_if_the_detector_name_is_not_set_up(self): state = StateMove() - state.detectors = {DetectorType.to_string(DetectorType.LAB): StateMoveDetector(), - DetectorType.to_string(DetectorType.HAB): StateMoveDetector()} - state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name = "test" - state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short = "test" - state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name_short = "test" - assert_validate_error(self, ValueError, state) - state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name = "test" - assert_raises_nothing(self, state) + state.detectors = {DetectorType.LAB.value: StateMoveDetector(), + DetectorType.HAB.value: StateMoveDetector()} + state.detectors[DetectorType.LAB.value].detector_name = "test" + state.detectors[DetectorType.HAB.value].detector_name_short = "test" + state.detectors[DetectorType.LAB.value].detector_name_short = "test" + with self.assertRaises(ValueError): + state.validate() + + state.detectors[DetectorType.HAB.value].detector_name = "test" + self.assertIsNone(state.validate()) def test_that_raises_if_the_short_detector_name_is_not_set_up(self): state = StateMove() - state.detectors = {DetectorType.to_string(DetectorType.LAB): StateMoveDetector(), - DetectorType.to_string(DetectorType.HAB): StateMoveDetector()} - state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name = "test" - state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name = "test" - state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short = "test" - assert_validate_error(self, ValueError, state) - state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name_short = "test" - assert_raises_nothing(self, state) + state.detectors = {DetectorType.LAB.value: StateMoveDetector(), + DetectorType.HAB.value: StateMoveDetector()} + state.detectors[DetectorType.HAB.value].detector_name = "test" + 
state.detectors[DetectorType.LAB.value].detector_name = "test" + state.detectors[DetectorType.HAB.value].detector_name_short = "test" + with self.assertRaises(ValueError): + state.validate() + + state.detectors[DetectorType.LAB.value].detector_name_short = "test" + self.assertIsNone(state.validate()) def test_that_general_isis_default_values_are_set_up(self): state = StateMove() - state.detectors = {DetectorType.to_string(DetectorType.LAB): StateMoveDetector(), - DetectorType.to_string(DetectorType.HAB): StateMoveDetector()} + state.detectors = {DetectorType.LAB.value: StateMoveDetector(), + DetectorType.HAB.value: StateMoveDetector()} self.assertEqual(state.sample_offset, 0.0) self.assertEqual(state.sample_offset_direction, CanonicalCoordinates.Z) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_translation_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].y_translation_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].z_translation_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].rotation_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].side_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].radius_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_tilt_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].y_tilt_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].z_tilt_correction, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].sample_centre_pos1, 0.0) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].sample_centre_pos2, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].x_translation_correction, 0.0) + 
self.assertEqual(state.detectors[DetectorType.HAB.value].y_translation_correction, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].z_translation_correction, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].rotation_correction, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].side_correction, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].radius_correction, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].x_tilt_correction, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].y_tilt_correction, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].z_tilt_correction, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].sample_centre_pos1, 0.0) + self.assertEqual(state.detectors[DetectorType.HAB.value].sample_centre_pos2, 0.0) class StateMoveWorkspaceLOQTest(unittest.TestCase): @@ -139,12 +142,12 @@ class StateMoveBuilderTest(unittest.TestCase): # Assert state = builder.build() self.assertEqual(state.center_position, value) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_translation_correction, value) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short, "HAB") - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name, "main-detector-bank") + self.assertEqual(state.detectors[DetectorType.HAB.value].x_translation_correction, value) + self.assertEqual(state.detectors[DetectorType.HAB.value].detector_name_short, "HAB") + self.assertEqual(state.detectors[DetectorType.LAB.value].detector_name, "main-detector-bank") self.assertEqual(state.monitor_names[str(2)], "monitor2") self.assertEqual(len(state.monitor_names), 2) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.LAB)].sample_centre_pos1, value) + self.assertEqual(state.detectors[DetectorType.LAB.value].sample_centre_pos1, value) def test_that_state_for_sans2d_can_be_built(self): # 
Arrange @@ -162,9 +165,9 @@ class StateMoveBuilderTest(unittest.TestCase): # Assert state = builder.build() - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].x_translation_correction, value) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.HAB)].detector_name_short, "front") - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name, "rear-detector") + self.assertEqual(state.detectors[DetectorType.HAB.value].x_translation_correction, value) + self.assertEqual(state.detectors[DetectorType.HAB.value].detector_name_short, "front") + self.assertEqual(state.detectors[DetectorType.LAB.value].detector_name, "rear-detector") self.assertEqual(state.monitor_names[str(4)], "monitor4") self.assertEqual(len(state.monitor_names), 4) @@ -184,9 +187,9 @@ class StateMoveBuilderTest(unittest.TestCase): # Assert state = builder.build() - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.LAB)].x_translation_correction, value) - self.assertEqual(state.detectors[DetectorType.to_string(DetectorType.LAB)].detector_name, "DetectorBench") - self.assertTrue(DetectorType.to_string(DetectorType.HAB) not in state.detectors) + self.assertEqual(state.detectors[DetectorType.LAB.value].x_translation_correction, value) + self.assertEqual(state.detectors[DetectorType.LAB.value].detector_name, "DetectorBench") + self.assertTrue(DetectorType.HAB.value not in state.detectors) self.assertEqual(state.monitor_names[str(5)], "monitor5") self.assertEqual(len(state.monitor_names), 5) diff --git a/scripts/test/SANS/state/normalize_to_monitor_test.py b/scripts/test/SANS/state/normalize_to_monitor_test.py index c3480a89bc0e8d55f2d0ac710e803c272186e4e0..cf2fb86e9b09f1fd6b2688e23993ad757b540e26 100644 --- a/scripts/test/SANS/state/normalize_to_monitor_test.py +++ b/scripts/test/SANS/state/normalize_to_monitor_test.py @@ -5,14 +5,13 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from 
__future__ import (absolute_import, division, print_function) + import unittest -import mantid +from sans.common.enums import (RebinType, RangeStepType, SANSFacility, SANSInstrument) +from sans.state.data import get_data_builder from sans.state.normalize_to_monitor import (StateNormalizeToMonitor, StateNormalizeToMonitorLOQ, get_normalize_to_monitor_builder) -from sans.state.data import get_data_builder -from sans.common.enums import (RebinType, RangeStepType, SANSFacility, SANSInstrument) -from state_test_helper import assert_validate_error, assert_raises_nothing from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -24,8 +23,8 @@ class StateNormalizeToMonitorTest(unittest.TestCase): def _get_normalize_to_monitor_state(**kwargs): state = StateNormalizeToMonitor() default_entries = {"prompt_peak_correction_min": 12., "prompt_peak_correction_max": 17., - "rebin_type": RebinType.Rebin, "wavelength_low": [1.5], "wavelength_high": [2.7], - "wavelength_step": 0.5, "incident_monitor": 1, "wavelength_step_type": RangeStepType.Lin, + "rebin_type": RebinType.REBIN, "wavelength_low": [1.5], "wavelength_high": [2.7], + "wavelength_step": 0.5, "incident_monitor": 1, "wavelength_step_type": RangeStepType.LIN, "background_TOF_general_start": 1.4, "background_TOF_general_stop": 24.5, "background_TOF_monitor_start": {"1": 123, "2": 123}, "background_TOF_monitor_stop": {"1": 234, "2": 2323}} @@ -40,9 +39,10 @@ class StateNormalizeToMonitorTest(unittest.TestCase): def assert_raises_for_bad_value_and_raises_nothing_for_good_value(self, entry_name, bad_value, good_value): kwargs = {entry_name: bad_value} state = self._get_normalize_to_monitor_state(**kwargs) - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() setattr(state, entry_name, good_value) - assert_raises_nothing(self, state) + self.assertIsNone(state.validate()) def test_that_is_sans_state_normalize_to_monitor_object(self): state = 
StateNormalizeToMonitorLOQ() @@ -105,11 +105,11 @@ class StateReductionBuilderTest(unittest.TestCase): builder.set_prompt_peak_correction_min(12.0) builder.set_prompt_peak_correction_max(17.0) - builder.set_rebin_type(RebinType.Rebin) + builder.set_rebin_type(RebinType.REBIN) builder.set_wavelength_low([1.5]) builder.set_wavelength_high([2.7]) builder.set_wavelength_step(0.5) - builder.set_wavelength_step_type(RangeStepType.Lin) + builder.set_wavelength_step_type(RangeStepType.LIN) builder.set_incident_monitor(1) builder.set_background_TOF_general_start(1.4) builder.set_background_TOF_general_stop(34.4) @@ -121,11 +121,11 @@ class StateReductionBuilderTest(unittest.TestCase): # Assert self.assertEqual(state.prompt_peak_correction_min, 12.0) self.assertEqual(state.prompt_peak_correction_max, 17.0) - self.assertEqual(state.rebin_type, RebinType.Rebin) + self.assertEqual(state.rebin_type, RebinType.REBIN) self.assertEqual(state.wavelength_low, [1.5]) self.assertEqual(state.wavelength_high, [2.7]) self.assertEqual(state.wavelength_step, 0.5) - self.assertEqual(state.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(state.wavelength_step_type, RangeStepType.LIN) self.assertEqual(state.background_TOF_general_start, 1.4) self.assertEqual(state.background_TOF_general_stop, 34.4) self.assertEqual(len(set(state.background_TOF_monitor_start.items()) & set({"1": 123, "2": 123}.items())), 2) diff --git a/scripts/test/SANS/state/reduction_mode_test.py b/scripts/test/SANS/state/reduction_mode_test.py index bee457a42bd96f21a1f1181dcea29c6efbde0390..53e2e2fdadb6ddc108845036ffb494237eebea42 100644 --- a/scripts/test/SANS/state/reduction_mode_test.py +++ b/scripts/test/SANS/state/reduction_mode_test.py @@ -5,13 +5,13 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid -from sans.state.reduction_mode import (StateReductionMode, 
get_reduction_mode_builder) -from sans.state.data import get_data_builder -from sans.common.enums import (ISISReductionMode, ReductionDimensionality, FitModeForMerge, +from sans.common.enums import (ReductionMode, ReductionDimensionality, FitModeForMerge, SANSFacility, SANSInstrument, DetectorType) +from sans.state.data import get_data_builder +from sans.state.reduction_mode import (StateReductionMode, get_reduction_mode_builder) from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -23,33 +23,32 @@ class StateReductionModeTest(unittest.TestCase): # Arrange state = StateReductionMode() - state.reduction_mode = ISISReductionMode.Merged - state.dimensionality = ReductionDimensionality.TwoDim + state.reduction_mode = ReductionMode.MERGED + state.dimensionality = ReductionDimensionality.TWO_DIM state.merge_shift = 12.65 state.merge_scale = 34.6 - state.merge_fit_mode = FitModeForMerge.ShiftOnly + state.merge_fit_mode = FitModeForMerge.SHIFT_ONLY - state.detector_names[DetectorType.to_string(DetectorType.LAB)] = "Test1" - state.detector_names[DetectorType.to_string(DetectorType.HAB)] = "Test2" + state.detector_names[DetectorType.LAB.value] = "Test1" + state.detector_names[DetectorType.HAB.value] = "Test2" state.merge_mask = True state.merge_min = 78.89 state.merge_max = 56.4 - # Assert merge_strategy = state.get_merge_strategy() - self.assertEqual(merge_strategy[0], ISISReductionMode.LAB) - self.assertEqual(merge_strategy[1], ISISReductionMode.HAB) + self.assertEqual(merge_strategy[0], ReductionMode.LAB) + self.assertEqual(merge_strategy[1], ReductionMode.HAB) all_reductions = state.get_all_reduction_modes() self.assertEqual(len(all_reductions), 2) - self.assertEqual(all_reductions[0], ISISReductionMode.LAB) - self.assertEqual(all_reductions[1], ISISReductionMode.HAB) + self.assertEqual(all_reductions[0], ReductionMode.LAB) + self.assertEqual(all_reductions[1], ReductionMode.HAB) - result_lab = 
state.get_detector_name_for_reduction_mode(ISISReductionMode.LAB) + result_lab = state.get_detector_name_for_reduction_mode(ReductionMode.LAB) self.assertEqual(result_lab, "Test1") - result_hab = state.get_detector_name_for_reduction_mode(ISISReductionMode.HAB) + result_hab = state.get_detector_name_for_reduction_mode(ReductionMode.HAB) self.assertEqual(result_hab, "Test2") self.assertRaises(RuntimeError, state.get_detector_name_for_reduction_mode, "non_sense") @@ -71,14 +70,14 @@ class StateReductionModeBuilderTest(unittest.TestCase): builder = get_reduction_mode_builder(data_info) self.assertTrue(builder) - mode = ISISReductionMode.Merged - dim = ReductionDimensionality.OneDim + mode = ReductionMode.MERGED + dim = ReductionDimensionality.ONE_DIM builder.set_reduction_mode(mode) builder.set_reduction_dimensionality(dim) merge_shift = 324.2 merge_scale = 3420.98 - fit_mode = FitModeForMerge.Both + fit_mode = FitModeForMerge.BOTH builder.set_merge_fit_mode(fit_mode) builder.set_merge_shift(merge_shift) builder.set_merge_scale(merge_scale) @@ -99,7 +98,7 @@ class StateReductionModeBuilderTest(unittest.TestCase): self.assertEqual(state.merge_shift, merge_shift) self.assertEqual(state.merge_scale, merge_scale) detector_names = state.detector_names - self.assertEqual(detector_names[DetectorType.to_string(DetectorType.LAB)], "main-detector-bank") + self.assertEqual(detector_names[DetectorType.LAB.value], "main-detector-bank") self.assertTrue(state.merge_mask) self.assertEqual(state.merge_min, merge_min) self.assertEqual(state.merge_max, merge_max) diff --git a/scripts/test/SANS/state/save_test.py b/scripts/test/SANS/state/save_test.py index 513489119323cd24304c366da8d97e3ccb7f5c25..d2d5c639daa6dfd77c0b3c1388a322ed45fb85f7 100644 --- a/scripts/test/SANS/state/save_test.py +++ b/scripts/test/SANS/state/save_test.py @@ -5,12 +5,12 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + 
import unittest -import mantid -from sans.state.save import (get_save_builder) -from sans.state.data import (get_data_builder) from sans.common.enums import (SANSFacility, SaveType, SANSInstrument) +from sans.state.data import (get_data_builder) +from sans.state.save import (get_save_builder) from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -38,7 +38,7 @@ class StateReductionBuilderTest(unittest.TestCase): user_specified_output_name = "test_file_name" zero_free_correction = True - file_format = [SaveType.Nexus, SaveType.CanSAS] + file_format = [SaveType.NEXUS, SaveType.CAN_SAS] builder.set_user_specified_output_name(user_specified_output_name) builder.set_zero_free_correction(zero_free_correction) diff --git a/scripts/test/SANS/state/scale_test.py b/scripts/test/SANS/state/scale_test.py index 28213eddbc1c7f79b74b3a48f113d2886c02063c..f27e9ac55ea89257dd8a8dbc969df15ccc18eae3 100644 --- a/scripts/test/SANS/state/scale_test.py +++ b/scripts/test/SANS/state/scale_test.py @@ -5,16 +5,15 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid -from sans.state.scale import get_scale_builder +from sans.common.enums import (SANSFacility, SampleShape) from sans.state.data import get_data_builder -from sans.common.enums import (SANSFacility, SANSInstrument, SampleShape) +from sans.state.scale import get_scale_builder from sans.test_helper.file_information_mock import SANSFileInformationMock - # ---------------------------------------------------------------------------------------------------------------------- # State # No tests required for the current states @@ -38,14 +37,14 @@ class StateSliceEventBuilderTest(unittest.TestCase): self.assertTrue(builder) builder.set_scale(1.0) - builder.set_shape(SampleShape.FlatPlate) + builder.set_shape(SampleShape.FLAT_PLATE) builder.set_thickness(3.6) builder.set_width(3.7) 
builder.set_height(5.8) # Assert state = builder.build() - self.assertEqual(state.shape, SampleShape.FlatPlate) + self.assertEqual(state.shape, SampleShape.FLAT_PLATE) self.assertEqual(state.scale, 1.0) self.assertEqual(state.thickness, 3.6) self.assertEqual(state.width, 3.7) diff --git a/scripts/test/SANS/state/slice_event_test.py b/scripts/test/SANS/state/slice_event_test.py index 5a9b8e5db8f68aad7ad4331ed27fd26345982c67..d6ba7e7e3deb3b3c80b1ab751f5ffeb05dcdf876 100644 --- a/scripts/test/SANS/state/slice_event_test.py +++ b/scripts/test/SANS/state/slice_event_test.py @@ -5,13 +5,12 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid -from sans.state.slice_event import (StateSliceEvent, get_slice_event_builder) -from sans.state.data import get_data_builder from sans.common.enums import (SANSFacility, SANSInstrument) -from state_test_helper import (assert_validate_error) +from sans.state.data import get_data_builder +from sans.state.slice_event import (StateSliceEvent, get_slice_event_builder) from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -22,7 +21,8 @@ class StateSliceEventTest(unittest.TestCase): def test_that_raises_when_only_one_time_is_set(self): state = StateSliceEvent() state.start_time = [1.0, 2.0] - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() state.end_time = [2.0, 3.0] def test_validate_method_raises_value_error_for_mismatching_start_and_end_time_length(self): diff --git a/scripts/test/SANS/state/state_base_test.py b/scripts/test/SANS/state/state_base_test.py index a806cab32373f3de3357810f625ebe5a4793bc6b..bddcbcff2fcc1fd9201c22210cdf5f6b4d1f8bc3 100644 --- a/scripts/test/SANS/state/state_base_test.py +++ b/scripts/test/SANS/state/state_base_test.py @@ -7,23 +7,16 @@ from __future__ import (absolute_import, division, print_function) import 
unittest -import mantid -from mantid.kernel import (PropertyManagerProperty, PropertyManager) from mantid.api import Algorithm - +from mantid.kernel import (PropertyManagerProperty, PropertyManager) +from mantid.py3compat import Enum from sans.state.state_base import (StringParameter, BoolParameter, FloatParameter, PositiveFloatParameter, - PositiveIntegerParameter, DictParameter, ClassTypeParameter, - FloatWithNoneParameter, PositiveFloatWithNoneParameter, FloatListParameter, - StringListParameter, PositiveIntegerListParameter, ClassTypeListParameter, - StateBase, rename_descriptor_names, TypedParameter, validator_sub_state, + PositiveIntegerParameter, DictParameter, FloatWithNoneParameter, + PositiveFloatWithNoneParameter, FloatListParameter, + StringListParameter, PositiveIntegerListParameter, StateBase, + rename_descriptor_names, TypedParameter, validator_sub_state, create_deserialized_sans_state_from_property_manager) -from sans.common.enums import serializable_enum - - -@serializable_enum("TypeA", "TypeB") -class TestType(object): - pass # ---------------------------------------------------------------------------------------------------------------------- @@ -44,8 +37,6 @@ class StateBaseTestClass(StateBase): float_list_parameter = FloatListParameter() string_list_parameter = StringListParameter() positive_integer_list_parameter = PositiveIntegerListParameter() - class_type_parameter = ClassTypeParameter(TestType) - class_type_list_parameter = ClassTypeListParameter(TestType) def __init__(self): super(StateBaseTestClass, self).__init__() @@ -54,6 +45,20 @@ class StateBaseTestClass(StateBase): pass +class FakeEnumClass(Enum): + FOO = 1 + BAR = "2" + + +class ExampleWrapper(StateBase): + # This has to be at the top module level, else the module name finding will fail + _foo = FakeEnumClass.FOO + bar = FakeEnumClass.BAR + + def validate(self): + return True + + class TypedParameterTest(unittest.TestCase): def _check_that_raises(self, error_type, obj, 
descriptor_name, value): try: @@ -81,8 +86,6 @@ class TypedParameterTest(unittest.TestCase): test_class.float_list_parameter = [12., -123., 2355.] test_class.string_list_parameter = ["test", "test"] test_class.positive_integer_list_parameter = [1, 2, 4] - test_class.class_type_parameter = TestType.TypeA - test_class.class_type_list_parameter = [TestType.TypeA, TestType.TypeB] except ValueError: self.fail() @@ -100,8 +103,6 @@ class TypedParameterTest(unittest.TestCase): self._check_that_raises(TypeError, test_class, "float_list_parameter", [1.23, "test"]) self._check_that_raises(TypeError, test_class, "string_list_parameter", ["test", "test", 123.]) self._check_that_raises(TypeError, test_class, "positive_integer_list_parameter", [1, "test"]) - self._check_that_raises(TypeError, test_class, "class_type_parameter", "test") - self._check_that_raises(TypeError, test_class, "class_type_list_parameter", ["test", TestType.TypeA]) def test_that_will_raise_if_set_with_wrong_value(self): # Note that this check does not apply to all parameter, it checks the validator @@ -180,9 +181,6 @@ class SimpleState(StateBase): float_list_parameter = FloatListParameter() string_list_parameter = StringListParameter() positive_integer_list_parameter = PositiveIntegerListParameter() - class_type_parameter = ClassTypeParameter(TestType) - class_type_list_parameter = ClassTypeListParameter(TestType) - sub_state_very_simple = TypedParameter(VerySimpleState, validator_sub_state) def __init__(self): @@ -198,8 +196,6 @@ class SimpleState(StateBase): self.float_list_parameter = [123., 234.] 
self.string_list_parameter = ["test1", "test2"] self.positive_integer_list_parameter = [1, 2, 3] - self.class_type_parameter = TestType.TypeA - self.class_type_list_parameter = [TestType.TypeA, TestType.TypeB] self.sub_state_very_simple = VerySimpleState() def validate(self): @@ -225,6 +221,13 @@ class ComplexState(StateBase): class TestStateBase(unittest.TestCase): + class FakeAlgorithm(Algorithm): + def PyInit(self): + self.declareProperty(PropertyManagerProperty("Args")) + + def PyExec(self): + pass + def _assert_simple_state(self, state): self.assertEqual(state.string_parameter, "String_in_SimpleState") self.assertFalse(state.bool_parameter) @@ -249,27 +252,54 @@ class TestStateBase(unittest.TestCase): self.assertEqual(state.positive_integer_list_parameter[1], 2) self.assertEqual(state.positive_integer_list_parameter[2], 3) - self.assertEqual(state.class_type_parameter, TestType.TypeA) - self.assertEqual(len(state.class_type_list_parameter), 2) - self.assertEqual(state.class_type_list_parameter[0], TestType.TypeA) - self.assertEqual(state.class_type_list_parameter[1], TestType.TypeB) - self.assertEqual(state.sub_state_very_simple.string_parameter, "test_in_very_simple") - - def test_that_sans_state_can_be_serialized_and_deserialized_when_going_through_an_algorithm(self): - class FakeAlgorithm(Algorithm): - def PyInit(self): - self.declareProperty(PropertyManagerProperty("Args")) - def PyExec(self): - pass + def test_that_enum_can_be_serialized(self): + original_obj = ExampleWrapper() + + # Serializing test + serialized = original_obj.property_manager + self.assertTrue("bar" in serialized) + self.assertFalse("_foo" in serialized) + self.assertTrue(isinstance(serialized["bar"], str), "The type was not converted to a string") + + # Deserializing Test + fake = TestStateBase.FakeAlgorithm() + fake.initialize() + fake.setProperty("Args", serialized) + property_manager = fake.getProperty("Args").value + + new_obj = 
create_deserialized_sans_state_from_property_manager(property_manager) + self.assertEqual(FakeEnumClass.BAR, new_obj.bar) + self.assertEqual(FakeEnumClass.FOO, new_obj._foo) + + def test_that_enum_list_can_be_serialized(self): + original_obj = ExampleWrapper() + original_obj.bar = [FakeEnumClass.BAR, FakeEnumClass.BAR] + # Serializing test + serialized = original_obj.property_manager + self.assertTrue("bar" in serialized) + self.assertFalse("_foo" in serialized) + self.assertTrue(isinstance(serialized["bar"], list), "The type was not converted to a list of strings") + + # Deserializing Test + fake = TestStateBase.FakeAlgorithm() + fake.initialize() + fake.setProperty("Args", serialized) + property_manager = fake.getProperty("Args").value + + new_obj = create_deserialized_sans_state_from_property_manager(property_manager) + self.assertEqual(original_obj.bar, new_obj.bar) + self.assertEqual(original_obj._foo, new_obj._foo) + + def test_that_sans_state_can_be_serialized_and_deserialized_when_going_through_an_algorithm(self): # Arrange state = ComplexState() # Act serialized = state.property_manager - fake = FakeAlgorithm() + fake = TestStateBase.FakeAlgorithm() fake.initialize() fake.setProperty("Args", serialized) property_manager = fake.getProperty("Args").value diff --git a/scripts/test/SANS/state/state_functions_test.py b/scripts/test/SANS/state/state_functions_test.py index 6be7f1fff857ce6d82e767fa626a7b3e38efaa3d..68cf0deeb5c6624ff5d7c117ddc5b81acc1053e8 100644 --- a/scripts/test/SANS/state/state_functions_test.py +++ b/scripts/test/SANS/state/state_functions_test.py @@ -5,14 +5,14 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid +from sans.common.enums import (ReductionDimensionality) +from sans.state.data import StateData from sans.state.state_functions import (is_pure_none_or_not_none, one_is_none, validation_message, 
is_not_none_and_first_larger_than_second) from sans.test_helper.test_director import TestDirector -from sans.state.data import StateData -from sans.common.enums import (ReductionDimensionality, ISISReductionMode, OutputParts) class StateFunctionsTest(unittest.TestCase): @@ -24,7 +24,7 @@ class StateFunctionsTest(unittest.TestCase): state.data.sample_scatter_run_number = 12345 state.data.sample_scatter_period = StateData.ALL_PERIODS - state.reduction.dimensionality = ReductionDimensionality.OneDim + state.reduction.dimensionality = ReductionDimensionality.ONE_DIM state.wavelength.wavelength_low = 12.0 state.wavelength.wavelength_high = 34.0 diff --git a/scripts/test/SANS/state/state_test.py b/scripts/test/SANS/state/state_test.py index ca5c6fde3308990403d715ed3ea3df665c58110d..42104221171a24658ee921281745204ff79f31e9 100644 --- a/scripts/test/SANS/state/state_test.py +++ b/scripts/test/SANS/state/state_test.py @@ -5,26 +5,24 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid -from sans.state.state import (State) +from sans.common.enums import (SANSInstrument, SANSFacility) +from sans.state.adjustment import (StateAdjustment) +from sans.state.calculate_transmission import (StateCalculateTransmission) +from sans.state.convert_to_q import (StateConvertToQ) from sans.state.data import (StateData) +from sans.state.mask import (StateMask) from sans.state.move import (StateMove) +from sans.state.normalize_to_monitor import (StateNormalizeToMonitor) from sans.state.reduction_mode import (StateReductionMode) -from sans.state.slice_event import (StateSliceEvent) -from sans.state.mask import (StateMask) -from sans.state.wavelength import (StateWavelength) from sans.state.save import (StateSave) -from sans.state.normalize_to_monitor import (StateNormalizeToMonitor) from sans.state.scale import (StateScale) -from sans.state.calculate_transmission import 
(StateCalculateTransmission) +from sans.state.slice_event import (StateSliceEvent) +from sans.state.state import (State) +from sans.state.wavelength import (StateWavelength) from sans.state.wavelength_and_pixel_adjustment import (StateWavelengthAndPixelAdjustment) -from sans.state.adjustment import (StateAdjustment) -from sans.state.convert_to_q import (StateConvertToQ) - -from state_test_helper import assert_validate_error, assert_raises_nothing -from sans.common.enums import (SANSInstrument, SANSFacility) # ---------------------------------------------------------------------------------------------------------------------- @@ -125,11 +123,12 @@ class StateTest(unittest.TestCase): def check_bad_and_good_values(self, bad_state, good_state): # Bad values state = self._get_state(bad_state) - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() # Good values state = self._get_state(good_state) - assert_raises_nothing(self, state) + self.assertIsNone(state.validate()) def test_that_raises_when_move_has_not_been_set(self): self.check_bad_and_good_values({"move": None}, {"move": MockStateMove()}) diff --git a/scripts/test/SANS/state/state_test_helper.py b/scripts/test/SANS/state/state_test_helper.py deleted file mode 100644 index ead505e3f34aa4cf0680a5a90244c9754559698a..0000000000000000000000000000000000000000 --- a/scripts/test/SANS/state/state_test_helper.py +++ /dev/null @@ -1,25 +0,0 @@ -# Mantid Repository : https://github.com/mantidproject/mantid -# -# Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI, -# NScD Oak Ridge National Laboratory, European Spallation Source -# & Institut Laue - Langevin -# SPDX - License - Identifier: GPL - 3.0 + -from __future__ import (absolute_import, division, print_function) - - -def assert_validate_error(caller, error_type, obj): - try: - obj.validate() - raised_correct = False - except error_type: - raised_correct = True - except: # noqa - raised_correct = False - 
caller.assertTrue(raised_correct) - - -def assert_raises_nothing(caller, obj): - try: - obj.validate() - except: # noqa - caller.fail() diff --git a/scripts/test/SANS/state/wavelength_and_pixel_adjustment_test.py b/scripts/test/SANS/state/wavelength_and_pixel_adjustment_test.py index ab24a379ac5fd98206a4cdda5b43a6606a615ad9..dcb5218a8a150c661ce52dbf322b533b9095aaab 100644 --- a/scripts/test/SANS/state/wavelength_and_pixel_adjustment_test.py +++ b/scripts/test/SANS/state/wavelength_and_pixel_adjustment_test.py @@ -5,14 +5,13 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid +from sans.common.enums import (RangeStepType, DetectorType, SANSFacility, SANSInstrument) +from sans.state.data import get_data_builder from sans.state.wavelength_and_pixel_adjustment import (StateWavelengthAndPixelAdjustment, get_wavelength_and_pixel_adjustment_builder) -from sans.state.data import get_data_builder -from sans.common.enums import (RebinType, RangeStepType, DetectorType, SANSFacility, SANSInstrument) -from state_test_helper import assert_validate_error, assert_raises_nothing from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -23,23 +22,32 @@ class StateWavelengthAndPixelAdjustmentTest(unittest.TestCase): def test_that_raises_when_wavelength_entry_is_missing(self): # Arrange state = StateWavelengthAndPixelAdjustment() - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() + state.wavelength_low = [1.] - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() + state.wavelength_high = [2.] - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() + state.wavelength_step = 2. 
- assert_validate_error(self, ValueError, state) - state.wavelength_step_type = RangeStepType.Lin - assert_raises_nothing(self, state) + with self.assertRaises(ValueError): + state.validate() + + state.wavelength_step_type = RangeStepType.LIN + self.assertIsNone(state.validate()) def test_that_raises_when_lower_wavelength_is_smaller_than_high_wavelength(self): state = StateWavelengthAndPixelAdjustment() state.wavelength_low = [2.] state.wavelength_high = [1.] state.wavelength_step = 2. - state.wavelength_step_type = RangeStepType.Lin - assert_validate_error(self, ValueError, state) + state.wavelength_step_type = RangeStepType.LIN + with self.assertRaises(ValueError): + state.validate() # ---------------------------------------------------------------------------------------------------------------------- @@ -63,19 +71,17 @@ class StateWavelengthAndPixelAdjustmentBuilderTest(unittest.TestCase): builder.set_wavelength_low([1.5]) builder.set_wavelength_high([2.7]) builder.set_wavelength_step(0.5) - builder.set_wavelength_step_type(RangeStepType.Lin) + builder.set_wavelength_step_type(RangeStepType.LIN) state = builder.build() # Assert - self.assertTrue(state.adjustment_files[DetectorType.to_string( - DetectorType.HAB)].pixel_adjustment_file == "test") - self.assertTrue(state.adjustment_files[DetectorType.to_string( - DetectorType.HAB)].wavelength_adjustment_file == "test2") + self.assertTrue(state.adjustment_files[DetectorType.HAB.value].pixel_adjustment_file == "test") + self.assertTrue(state.adjustment_files[DetectorType.HAB.value].wavelength_adjustment_file == "test2") self.assertEqual(state.wavelength_low, [1.5]) self.assertEqual(state.wavelength_high, [2.7]) self.assertEqual(state.wavelength_step, 0.5) - self.assertEqual(state.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(state.wavelength_step_type, RangeStepType.LIN) if __name__ == '__main__': diff --git a/scripts/test/SANS/state/wavelength_test.py b/scripts/test/SANS/state/wavelength_test.py index 
03c0e71ed6b03422f4f4b409779c4b8b2ef59bd7..8fe4210a56385616da84c3722a85a6218d93f73e 100644 --- a/scripts/test/SANS/state/wavelength_test.py +++ b/scripts/test/SANS/state/wavelength_test.py @@ -5,13 +5,12 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid -from sans.state.wavelength import (StateWavelength, get_wavelength_builder) -from sans.state.data import get_data_builder from sans.common.enums import (SANSFacility, SANSInstrument, RebinType, RangeStepType) -from state_test_helper import assert_validate_error, assert_raises_nothing +from sans.state.data import get_data_builder +from sans.state.wavelength import (StateWavelength, get_wavelength_builder) from sans.test_helper.file_information_mock import SANSFileInformationMock @@ -27,20 +26,27 @@ class StateWavelengthTest(unittest.TestCase): def test_that_raises_when_wavelength_entry_is_missing(self): # Arrange state = StateWavelength() - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() + state.wavelength_low = [1.] - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() + state.wavelength_high = [2.] - assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() + state.wavelength_step = 2. - assert_raises_nothing(self, state) + self.assertIsNone(state.validate()) def test_that_raises_when_lower_wavelength_is_smaller_than_high_wavelength(self): state = StateWavelength() state.wavelength_low = [2.] state.wavelength_high = [1.] state.wavelength_step = 2. 
- assert_validate_error(self, ValueError, state) + with self.assertRaises(ValueError): + state.validate() # ---------------------------------------------------------------------------------------------------------------------- @@ -63,16 +69,16 @@ class StateSliceEventBuilderTest(unittest.TestCase): builder.set_wavelength_low([10.0]) builder.set_wavelength_high([20.0]) builder.set_wavelength_step(3.0) - builder.set_wavelength_step_type(RangeStepType.Lin) - builder.set_rebin_type(RebinType.Rebin) + builder.set_wavelength_step_type(RangeStepType.LIN) + builder.set_rebin_type(RebinType.REBIN) # Assert state = builder.build() self.assertEqual(state.wavelength_low, [10.0]) self.assertEqual(state.wavelength_high, [20.0]) - self.assertEqual(state.wavelength_step_type, RangeStepType.Lin) - self.assertEqual(state.rebin_type, RebinType.Rebin) + self.assertEqual(state.wavelength_step_type, RangeStepType.LIN) + self.assertEqual(state.rebin_type, RebinType.REBIN) if __name__ == '__main__': diff --git a/scripts/test/SANS/user_file/state_director_test.py b/scripts/test/SANS/user_file/state_director_test.py index 28ff3362fb973d6574b771b12bcb8e67568f1814..a512b3c5a882fb7da74e385d76bef9211ce5443a 100644 --- a/scripts/test/SANS/user_file/state_director_test.py +++ b/scripts/test/SANS/user_file/state_director_test.py @@ -5,19 +5,17 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import os import unittest -import mantid - -from sans.user_file.state_director import StateDirectorISIS -from sans.common.enums import (SANSFacility, ISISReductionMode, RangeStepType, RebinType, DataType, FitType, - DetectorType, SampleShape, SANSInstrument) from sans.common.configurations import Configurations +from sans.common.enums import (SANSFacility, ReductionMode, RangeStepType, RebinType, DataType, FitType, + DetectorType, SampleShape, SANSInstrument) from sans.state.data import get_data_builder - -from 
sans.test_helper.user_file_test_helper import create_user_file, sample_user_file from sans.test_helper.file_information_mock import SANSFileInformationMock +from sans.test_helper.user_file_test_helper import create_user_file, sample_user_file +from sans.user_file.state_director import StateDirectorISIS # ----------------------------------------------------------------- @@ -29,15 +27,14 @@ class UserFileStateDirectorISISTest(unittest.TestCase): self.assertEqual(data.calibration, "TUBE_SANS2D_BOTH_31681_25Sept15.nxs") self.assertEqual(data.user_file, "USER_SANS2D_154E_2p4_4m_M3_Xpress_8mm_SampleChanger_FRONT.txt") - def _assert_move(self, state): move = state.move # Check the elements which were set on move self.assertEqual(move.sample_offset, 53.0/1000.) # Detector specific - lab = move.detectors[DetectorType.to_string(DetectorType.LAB)] - hab = move.detectors[DetectorType.to_string(DetectorType.HAB)] + lab = move.detectors[DetectorType.LAB.value] + hab = move.detectors[DetectorType.HAB.value] self.assertEqual(lab.x_translation_correction, -16.0/1000.) self.assertEqual(lab.z_translation_correction, 47.0/1000.) self.assertEqual(hab.x_translation_correction, -44.0/1000.) @@ -54,22 +51,22 @@ class UserFileStateDirectorISISTest(unittest.TestCase): self.assertEqual(mask.radius_max, 15/1000.) 
self.assertEqual(mask.clear, True) self.assertEqual(mask.clear_time, True) - self.assertEqual(mask.detectors[DetectorType.to_string(DetectorType.LAB)].single_horizontal_strip_mask, [0]) - self.assertEqual(mask.detectors[DetectorType.to_string(DetectorType.LAB)].single_vertical_strip_mask, [0, 191]) - self.assertEqual(mask.detectors[DetectorType.to_string(DetectorType.HAB)].single_horizontal_strip_mask, [0]) - self.assertEqual(mask.detectors[DetectorType.to_string(DetectorType.HAB)].single_vertical_strip_mask, [0, 191]) - self.assertTrue(mask.detectors[DetectorType.to_string(DetectorType.LAB)].range_horizontal_strip_start + self.assertEqual(mask.detectors[DetectorType.LAB.value].single_horizontal_strip_mask, [0]) + self.assertEqual(mask.detectors[DetectorType.LAB.value].single_vertical_strip_mask, [0, 191]) + self.assertEqual(mask.detectors[DetectorType.HAB.value].single_horizontal_strip_mask, [0]) + self.assertEqual(mask.detectors[DetectorType.HAB.value].single_vertical_strip_mask, [0, 191]) + self.assertTrue(mask.detectors[DetectorType.LAB.value].range_horizontal_strip_start == [190, 167]) - self.assertTrue(mask.detectors[DetectorType.to_string(DetectorType.LAB)].range_horizontal_strip_stop + self.assertTrue(mask.detectors[DetectorType.LAB.value].range_horizontal_strip_stop == [191, 172]) - self.assertTrue(mask.detectors[DetectorType.to_string(DetectorType.HAB)].range_horizontal_strip_start + self.assertTrue(mask.detectors[DetectorType.HAB.value].range_horizontal_strip_start == [190, 156]) - self.assertTrue(mask.detectors[DetectorType.to_string(DetectorType.HAB)].range_horizontal_strip_stop + self.assertTrue(mask.detectors[DetectorType.HAB.value].range_horizontal_strip_stop == [191, 159]) def _assert_reduction(self, state): reduction = state.reduction - self.assertEqual(reduction.reduction_mode, ISISReductionMode.LAB) + self.assertEqual(reduction.reduction_mode, ReductionMode.LAB) self.assertFalse(reduction.merge_mask) self.assertEqual(reduction.merge_min, None) 
self.assertEqual(reduction.merge_max, None) @@ -83,7 +80,7 @@ class UserFileStateDirectorISISTest(unittest.TestCase): self.assertEqual(wavelength.wavelength_low, [1.5]) self.assertEqual(wavelength.wavelength_high, [12.5]) self.assertEqual(wavelength.wavelength_step, 0.125) - self.assertEqual(wavelength.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(wavelength.wavelength_step_type, RangeStepType.LIN) def _assert_convert_to_q(self, state): convert_to_q = state.convert_to_q @@ -108,11 +105,11 @@ class UserFileStateDirectorISISTest(unittest.TestCase): normalize_to_monitor = adjustment.normalize_to_monitor self.assertEqual(normalize_to_monitor.prompt_peak_correction_min, 1000) self.assertEqual(normalize_to_monitor.prompt_peak_correction_max, 2000) - self.assertEqual(normalize_to_monitor.rebin_type, RebinType.InterpolatingRebin) + self.assertEqual(normalize_to_monitor.rebin_type, RebinType.INTERPOLATING_REBIN) self.assertEqual(normalize_to_monitor.wavelength_low, [1.5]) self.assertEqual(normalize_to_monitor.wavelength_high, [12.5]) self.assertEqual(normalize_to_monitor.wavelength_step, 0.125) - self.assertEqual(normalize_to_monitor.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(normalize_to_monitor.wavelength_step_type, RangeStepType.LIN) self.assertEqual(normalize_to_monitor.background_TOF_general_start, 3500) self.assertEqual(normalize_to_monitor.background_TOF_general_stop, 4500) self.assertEqual(normalize_to_monitor.background_TOF_monitor_start["1"], 35000) @@ -132,11 +129,11 @@ class UserFileStateDirectorISISTest(unittest.TestCase): self.assertEqual(calculate_transmission.transmission_roi_files, ["test.xml", "test2.xml"]) self.assertEqual(calculate_transmission.transmission_mask_files, ["test3.xml", "test4.xml"]) self.assertEqual(calculate_transmission.transmission_monitor, 4) - self.assertEqual(calculate_transmission.rebin_type, RebinType.InterpolatingRebin) + self.assertEqual(calculate_transmission.rebin_type, 
RebinType.INTERPOLATING_REBIN) self.assertEqual(calculate_transmission.wavelength_low, [1.5]) self.assertEqual(calculate_transmission.wavelength_high, [12.5]) self.assertEqual(calculate_transmission.wavelength_step, 0.125) - self.assertEqual(calculate_transmission.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(calculate_transmission.wavelength_step_type, RangeStepType.LIN) self.assertFalse(calculate_transmission.use_full_wavelength_range) self.assertEqual(calculate_transmission.wavelength_full_range_low, Configurations.SANS2D.wavelength_full_range_low) @@ -150,26 +147,26 @@ class UserFileStateDirectorISISTest(unittest.TestCase): self.assertEqual(calculate_transmission.background_TOF_monitor_stop["2"], 98000) self.assertEqual(calculate_transmission.background_TOF_roi_start, 123) self.assertEqual(calculate_transmission.background_TOF_roi_stop, 466) - self.assertEqual(calculate_transmission.fit[DataType.to_string(DataType.Sample)].fit_type, FitType.Logarithmic) - self.assertEqual(calculate_transmission.fit[DataType.to_string(DataType.Sample)].wavelength_low, 1.5) - self.assertEqual(calculate_transmission.fit[DataType.to_string(DataType.Sample)].wavelength_high, 12.5) - self.assertEqual(calculate_transmission.fit[DataType.to_string(DataType.Sample)].polynomial_order, 0) - self.assertEqual(calculate_transmission.fit[DataType.to_string(DataType.Can)].fit_type, FitType.Logarithmic) - self.assertEqual(calculate_transmission.fit[DataType.to_string(DataType.Can)].wavelength_low, 1.5) - self.assertEqual(calculate_transmission.fit[DataType.to_string(DataType.Can)].wavelength_high, 12.5) - self.assertEqual(calculate_transmission.fit[DataType.to_string(DataType.Can)].polynomial_order, 0) + self.assertEqual(calculate_transmission.fit[DataType.SAMPLE].fit_type, FitType.LOGARITHMIC) + self.assertEqual(calculate_transmission.fit[DataType.SAMPLE].wavelength_low, 1.5) + self.assertEqual(calculate_transmission.fit[DataType.SAMPLE].wavelength_high, 12.5) + 
self.assertEqual(calculate_transmission.fit[DataType.SAMPLE].polynomial_order, 0) + self.assertEqual(calculate_transmission.fit[DataType.CAN].fit_type, FitType.LOGARITHMIC) + self.assertEqual(calculate_transmission.fit[DataType.CAN].wavelength_low, 1.5) + self.assertEqual(calculate_transmission.fit[DataType.CAN].wavelength_high, 12.5) + self.assertEqual(calculate_transmission.fit[DataType.CAN].polynomial_order, 0) # Wavelength and Pixel Adjustment wavelength_and_pixel_adjustment = adjustment.wavelength_and_pixel_adjustment self.assertEqual(wavelength_and_pixel_adjustment.wavelength_low, [1.5]) self.assertEqual(wavelength_and_pixel_adjustment.wavelength_high, [12.5]) self.assertEqual(wavelength_and_pixel_adjustment.wavelength_step, 0.125) - self.assertEqual(wavelength_and_pixel_adjustment.wavelength_step_type, RangeStepType.Lin) + self.assertEqual(wavelength_and_pixel_adjustment.wavelength_step_type, RangeStepType.LIN) self.assertTrue(wavelength_and_pixel_adjustment.adjustment_files[ - DetectorType.to_string(DetectorType.LAB)].wavelength_adjustment_file == + DetectorType.LAB.value].wavelength_adjustment_file == "DIRECTM1_15785_12m_31Oct12_v12.dat") self.assertTrue(wavelength_and_pixel_adjustment.adjustment_files[ - DetectorType.to_string(DetectorType.HAB)].wavelength_adjustment_file == + DetectorType.HAB.value].wavelength_adjustment_file == "DIRECTM1_15785_12m_31Oct12_v12.dat") # Assert wide angle correction @@ -222,7 +219,7 @@ class UserFileStateDirectorISISTest(unittest.TestCase): director.set_scale_builder_width(1.) director.set_scale_builder_height(1.5) director.set_scale_builder_thickness(12.) - director.set_scale_builder_shape(SampleShape.FlatPlate) + director.set_scale_builder_shape(SampleShape.FLAT_PLATE) # Act state = director.construct() @@ -233,7 +230,7 @@ class UserFileStateDirectorISISTest(unittest.TestCase): self.assertEqual(state.scale.width, 1.) self.assertEqual(state.scale.height, 1.5) self.assertEqual(state.scale.thickness, 12.) 
- self.assertEqual(state.scale.shape, SampleShape.FlatPlate) + self.assertEqual(state.scale.shape, SampleShape.FLAT_PLATE) # clean up if os.path.exists(user_file_path): diff --git a/scripts/test/SANS/user_file/user_file_parser_test.py b/scripts/test/SANS/user_file/user_file_parser_test.py index 5b618fddb8c7edd94dc7d47df0e4b16579d4a9b1..333bac755f20a30261282d9bfbaa95cf6794bc73 100644 --- a/scripts/test/SANS/user_file/user_file_parser_test.py +++ b/scripts/test/SANS/user_file/user_file_parser_test.py @@ -5,21 +5,23 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) + import unittest -import mantid -from sans.common.enums import (ISISReductionMode, DetectorType, RangeStepType, FitType, DataType, SANSInstrument) -from sans.user_file.user_file_parser import (InstrParser, DetParser, LimitParser, MaskParser, SampleParser, SetParser, TransParser, - TubeCalibFileParser, QResolutionParser, FitParser, GravityParser, - MaskFileParser, MonParser, PrintParser, BackParser, SANS2DParser, LOQParser, - UserFileParser, LARMORParser, CompatibilityParser) +from sans.common.enums import (ReductionMode, DetectorType, RangeStepType, FitType, DataType, SANSInstrument) from sans.user_file.settings_tags import (DetectorId, BackId, range_entry, back_single_monitor_entry, single_entry_with_detector, mask_angle_entry, LimitsId, - simple_range, complex_range, MaskId, mask_block, mask_block_cross, + simple_range, MaskId, mask_block, mask_block_cross, mask_line, range_entry_with_detector, SampleId, SetId, set_scales_entry, position_entry, TransId, TubeCalibrationFileId, QResolutionId, FitId, fit_general, MonId, monitor_length, monitor_file, GravityId, OtherId, monitor_spectrum, PrintId, det_fit_range, q_rebin_values) +from sans.user_file.user_file_parser import (InstrParser, DetParser, LimitParser, MaskParser, SampleParser, SetParser, + TransParser, + TubeCalibFileParser, QResolutionParser, FitParser, 
GravityParser, + MaskFileParser, MonParser, PrintParser, BackParser, SANS2DParser, + LOQParser, + UserFileParser, LARMORParser, CompatibilityParser) # ----------------------------------------------------------------- @@ -70,10 +72,10 @@ class InstrParserTest(unittest.TestCase): self.assertFalse(InstrParser.get_type_pattern("SANS2D/something else")) def test_that_instruments_are_parsed_correctly(self): - valid_settings = {"SANS2D": {DetectorId.instrument: SANSInstrument.SANS2D}, - "LOQ": {DetectorId.instrument: SANSInstrument.LOQ}, - "ZOOM": {DetectorId.instrument: SANSInstrument.ZOOM}, - "LARMOR": {DetectorId.instrument: SANSInstrument.LARMOR}} + valid_settings = {"SANS2D": {DetectorId.INSTRUMENT: SANSInstrument.SANS2D}, + "LOQ": {DetectorId.INSTRUMENT: SANSInstrument.LOQ}, + "ZOOM": {DetectorId.INSTRUMENT: SANSInstrument.ZOOM}, + "LARMOR": {DetectorId.INSTRUMENT: SANSInstrument.LARMOR}} invalid_settings = {"NOINSTRUMENT": RuntimeError, "SANS2D/HAB": RuntimeError, @@ -89,13 +91,13 @@ class DetParserTest(unittest.TestCase): def test_that_reduction_mode_is_parsed_correctly(self): # The dict below has the string to parse as the key and the expected result as a value - valid_settings = {"DET/HAB": {DetectorId.reduction_mode: ISISReductionMode.HAB}, - "dEt/ frONT ": {DetectorId.reduction_mode: ISISReductionMode.HAB}, - "dET/REAR": {DetectorId.reduction_mode: ISISReductionMode.LAB}, - "dEt/MAIn ": {DetectorId.reduction_mode: ISISReductionMode.LAB}, - " dEt/ BOtH": {DetectorId.reduction_mode: ISISReductionMode.All}, - "DeT /merge ": {DetectorId.reduction_mode: ISISReductionMode.Merged}, - " DEt / MERGED": {DetectorId.reduction_mode: ISISReductionMode.Merged}} + valid_settings = {"DET/HAB": {DetectorId.REDUCTION_MODE: ReductionMode.HAB}, + "dEt/ frONT ": {DetectorId.REDUCTION_MODE: ReductionMode.HAB}, + "dET/REAR": {DetectorId.REDUCTION_MODE: ReductionMode.LAB}, + "dEt/MAIn ": {DetectorId.REDUCTION_MODE: ReductionMode.LAB}, + " dEt/ BOtH": {DetectorId.REDUCTION_MODE: 
ReductionMode.ALL}, + "DeT /merge ": {DetectorId.REDUCTION_MODE: ReductionMode.MERGED}, + " DEt / MERGED": {DetectorId.REDUCTION_MODE: ReductionMode.MERGED}} invalid_settings = {"DET/HUB": RuntimeError, "DET/HAB/": RuntimeError} @@ -103,11 +105,11 @@ class DetParserTest(unittest.TestCase): do_test(det_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_merge_option_is_parsed_correctly(self): - valid_settings = {"DET/RESCALE 123": {DetectorId.rescale: 123}, - "dEt/ shiFt 48.5": {DetectorId.shift: 48.5}, - "dET/reSCale/FIT 23 34.6 ": {DetectorId.rescale_fit: det_fit_range(start=23, stop=34.6, + valid_settings = {"DET/RESCALE 123": {DetectorId.RESCALE: 123}, + "dEt/ shiFt 48.5": {DetectorId.SHIFT: 48.5}, + "dET/reSCale/FIT 23 34.6 ": {DetectorId.RESCALE_FIT: det_fit_range(start=23, stop=34.6, use_fit=True)}, - "dEt/SHIFT/FIT 235.2 341 ": {DetectorId.shift_fit: det_fit_range(start=235.2, stop=341, + "dEt/SHIFT/FIT 235.2 341 ": {DetectorId.SHIFT_FIT: det_fit_range(start=235.2, stop=341, use_fit=True)}} invalid_settings = {"DET/Ruscale": RuntimeError, @@ -120,39 +122,39 @@ class DetParserTest(unittest.TestCase): do_test(det_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_detector_setting_is_parsed_correctly(self): - valid_settings = {"Det/CORR/REAR/X 123": {DetectorId.correction_x: single_entry_with_detector(entry=123, - detector_type=DetectorType.LAB)}, # noqa - "DEt/CORR/ frOnt/X +95.7": {DetectorId.correction_x: + valid_settings = {"Det/CORR/REAR/X 123": {DetectorId.CORRECTION_X: single_entry_with_detector(entry=123, + detector_type=DetectorType.LAB)}, # noqa + "DEt/CORR/ frOnt/X +95.7": {DetectorId.CORRECTION_X: single_entry_with_detector(entry=95.7, detector_type=DetectorType.HAB)}, - "DeT/ CORR / ReAR/ y 12.3": {DetectorId.correction_y: + "DeT/ CORR / ReAR/ y 12.3": {DetectorId.CORRECTION_Y: single_entry_with_detector(entry=12.3, detector_type=DetectorType.LAB)}, - " 
DET/CoRR/fROnt/Y -957": {DetectorId.correction_y: + " DET/CoRR/fROnt/Y -957": {DetectorId.CORRECTION_Y: single_entry_with_detector(entry=-957, detector_type=DetectorType.HAB)}, - "DeT/ CORR /reAR/Z 12.3": {DetectorId.correction_z: + "DeT/ CORR /reAR/Z 12.3": {DetectorId.CORRECTION_Z: single_entry_with_detector(entry=12.3, detector_type=DetectorType.LAB)}, - " DET/CoRR/FRONT/ Z -957": {DetectorId.correction_z: + " DET/CoRR/FRONT/ Z -957": {DetectorId.CORRECTION_Z: single_entry_with_detector(entry=-957, detector_type=DetectorType.HAB)}, - "DeT/ CORR /reAR/SIDE 12.3": {DetectorId.correction_translation: + "DeT/ CORR /reAR/SIDE 12.3": {DetectorId.CORRECTION_TRANSLATION: single_entry_with_detector(entry=12.3, detector_type=DetectorType.LAB)}, - " DET/CoRR/FRONT/ SidE -957": {DetectorId.correction_translation: + " DET/CoRR/FRONT/ SidE -957": {DetectorId.CORRECTION_TRANSLATION: single_entry_with_detector(entry=-957, detector_type=DetectorType.HAB)}, - "DeT/ CORR /reAR/ROt 12.3": {DetectorId.correction_rotation: + "DeT/ CORR /reAR/ROt 12.3": {DetectorId.CORRECTION_ROTATION: single_entry_with_detector(entry=12.3, detector_type=DetectorType.LAB)}, - " DET/CoRR/FRONT/ROT -957": {DetectorId.correction_rotation: + " DET/CoRR/FRONT/ROT -957": {DetectorId.CORRECTION_ROTATION: single_entry_with_detector(entry=-957, detector_type=DetectorType.HAB)}, - "DeT/ CORR /reAR/Radius 12.3": {DetectorId.correction_radius: + "DeT/ CORR /reAR/Radius 12.3": {DetectorId.CORRECTION_RADIUS: single_entry_with_detector(entry=12.3, detector_type=DetectorType.LAB)}, - " DET/CoRR/FRONT/RADIUS 957": {DetectorId.correction_radius: + " DET/CoRR/FRONT/RADIUS 957": {DetectorId.CORRECTION_RADIUS: single_entry_with_detector(entry=957, detector_type=DetectorType.HAB)}} @@ -169,8 +171,8 @@ class DetParserTest(unittest.TestCase): do_test(det_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_DET_OVERLAP_option_is_parsed_correctly(self): - valid_settings = {"DET/OVERLAP 
0.13 0.15": {DetectorId.merge_range: det_fit_range(start=0.13, stop=0.15, use_fit=True)}, - "DeT/OverLAP 0.13 0.15": {DetectorId.merge_range: det_fit_range(start=0.13, stop=0.15, use_fit=True)} + valid_settings = {"DET/OVERLAP 0.13 0.15": {DetectorId.MERGE_RANGE: det_fit_range(start=0.13, stop=0.15, use_fit=True)}, + "DeT/OverLAP 0.13 0.15": {DetectorId.MERGE_RANGE: det_fit_range(start=0.13, stop=0.15, use_fit=True)} } invalid_settings = {"DET/OVERLAP 0.13 0.15 0.17": RuntimeError, @@ -187,9 +189,9 @@ class LimitParserTest(unittest.TestCase): self.assertTrue(LimitParser.get_type(), "L") def test_that_angle_limit_is_parsed_correctly(self): - valid_settings = {"L/PhI 123 345.2": {LimitsId.angle: mask_angle_entry(min=123, max=345.2, + valid_settings = {"L/PhI 123 345.2": {LimitsId.ANGLE: mask_angle_entry(min=123, max=345.2, use_mirror=True)}, - "L/PHI / NOMIRROR 123 -345.2": {LimitsId.angle: mask_angle_entry(min=123, max=-345.2, + "L/PHI / NOMIRROR 123 -345.2": {LimitsId.ANGLE: mask_angle_entry(min=123, max=-345.2, use_mirror=False)}} invalid_settings = {"L/PHI/NMIRROR/ 23 454": RuntimeError, @@ -201,9 +203,9 @@ class LimitParserTest(unittest.TestCase): do_test(limit_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_event_time_limit_is_parsed_correctly(self): - valid_settings = {"L / EVEnTStime 0,-10,32,434,34523,35": {LimitsId.events_binning: + valid_settings = {"L / EVEnTStime 0,-10,32,434,34523,35": {LimitsId.EVENTS_BINNING: "0.0,-10.0,32.0,434.0,34523.0,35.0"}, - "L / Eventstime 0 -10 32 434 34523 35": {LimitsId.events_binning: + "L / Eventstime 0 -10 32 434 34523 35": {LimitsId.EVENTS_BINNING: "0.0,-10.0,32.0,434.0,34523.0,35.0"}} invalid_settings = {"L / EEnTStime 0,-10,32,434,34523,35": RuntimeError, @@ -213,10 +215,10 @@ class LimitParserTest(unittest.TestCase): do_test(limit_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_cut_limits_are_parsed_correctly(self): - 
valid_settings = {"L/Q/RCUT 234.4": {LimitsId.radius_cut: 234.4}, - "L /q / RcUT -234.34": {LimitsId.radius_cut: -234.34}, - "l/Q/WCUT 234.4": {LimitsId.wavelength_cut: 234.4}, - "L /q / wcUT -234.34": {LimitsId.wavelength_cut: -234.34}} + valid_settings = {"L/Q/RCUT 234.4": {LimitsId.RADIUS_CUT: 234.4}, + "L /q / RcUT -234.34": {LimitsId.RADIUS_CUT: -234.34}, + "l/Q/WCUT 234.4": {LimitsId.WAVELENGTH_CUT: 234.4}, + "L /q / wcUT -234.34": {LimitsId.WAVELENGTH_CUT: -234.34}} invalid_settings = {"L/Q/Rcu 123": RuntimeError, "L/Q/RCUT/ 2134": RuntimeError, @@ -228,9 +230,9 @@ class LimitParserTest(unittest.TestCase): do_test(limit_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_radius_limits_are_parsed_correctly(self): - valid_settings = {"L/R 234 235": {LimitsId.radius: range_entry(start=234, stop=235)}, - "L / r -234 235": {LimitsId.radius: range_entry(start=-234, stop=235)}, - "L / r -234 235 454": {LimitsId.radius: range_entry(start=-234, stop=235)} + valid_settings = {"L/R 234 235": {LimitsId.RADIUS: range_entry(start=234, stop=235)}, + "L / r -234 235": {LimitsId.RADIUS: range_entry(start=-234, stop=235)}, + "L / r -234 235 454": {LimitsId.RADIUS: range_entry(start=-234, stop=235)} } invalid_settings = {"L/R/ 234 435": RuntimeError, "L/Rr 234 435": RuntimeError, @@ -242,31 +244,31 @@ class LimitParserTest(unittest.TestCase): do_test(limit_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_q_limits_are_parsed_correctly(self): - valid_settings = {"L/Q 12 34": {LimitsId.q: q_rebin_values(min=12., max=34., rebin_string="12.0,34.0")}, - "L/Q 12 34 2.7": {LimitsId.q: q_rebin_values(min=12., max=34., rebin_string="12.0,2.7,34.0")}, - "L/Q -12 34.6 2.7/LOG": {LimitsId.q: q_rebin_values(min=-12., max=34.6, + valid_settings = {"L/Q 12 34": {LimitsId.Q: q_rebin_values(min=12., max=34., rebin_string="12.0,34.0")}, + "L/Q 12 34 2.7": {LimitsId.Q: q_rebin_values(min=12., max=34., 
rebin_string="12.0,2.7,34.0")}, + "L/Q -12 34.6 2.7/LOG": {LimitsId.Q: q_rebin_values(min=-12., max=34.6, rebin_string="-12.0,-2.7,34.6")}, - "L/q -12 3.6 2 /LIN": {LimitsId.q: q_rebin_values(min=-12., max=3.6, + "L/q -12 3.6 2 /LIN": {LimitsId.Q: q_rebin_values(min=-12., max=3.6, rebin_string="-12.0,2.0,3.6")}, - "L/q -12 0.41 23 -34.8 3.6": {LimitsId.q: q_rebin_values(min=-12., max=3.6, - rebin_string="-12.0,0.41,23.0,-34.8,3.6")}, # noqa - "L/q -12 0.42 23 -34.8 3.6 /LIn": {LimitsId.q: q_rebin_values(min=-12., max=3.6, - rebin_string="-12.0,0.42,23.0,34.8,3.6")}, - "L/q -12 0.43 23 -34.8 3.6": {LimitsId.q: q_rebin_values(min=-12., max=3.6, - rebin_string="-12.0,0.43,23.0,-34.8,3.6")}, - "L/q -12 0.44 23 ,34.8,3.6 /Log": {LimitsId.q: q_rebin_values(min=-12., max=3.6, - rebin_string="-12.0,-0.44,23.0,-34.8,3.6")}, - "L/q -12 , 0.45 , 23 ,34.8 ,3.6, .123, 5.6 /Log": {LimitsId.q: q_rebin_values(min=-12., - max=5.6, - rebin_string="-12.0,-0.45,23.0,-34.8,3.6," + "L/q -12 0.41 23 -34.8 3.6": {LimitsId.Q: q_rebin_values(min=-12., max=3.6, + rebin_string="-12.0,0.41,23.0,-34.8,3.6")}, # noqa + "L/q -12 0.42 23 -34.8 3.6 /LIn": {LimitsId.Q: q_rebin_values(min=-12., max=3.6, + rebin_string="-12.0,0.42,23.0,34.8,3.6")}, + "L/q -12 0.43 23 -34.8 3.6": {LimitsId.Q: q_rebin_values(min=-12., max=3.6, + rebin_string="-12.0,0.43,23.0,-34.8,3.6")}, + "L/q -12 0.44 23 ,34.8,3.6 /Log": {LimitsId.Q: q_rebin_values(min=-12., max=3.6, + rebin_string="-12.0,-0.44,23.0,-34.8,3.6")}, + "L/q -12 , 0.45 , 23 ,34.8 ,3.6, .123, 5.6 /Log": {LimitsId.Q: q_rebin_values(min=-12., + max=5.6, + rebin_string="-12.0,-0.45,23.0,-34.8,3.6," "-0.123,5.6")}, - "L/q -12 , 0.46 , 23 ,34.8 ,3.6, -.123, 5.6": {LimitsId.q: q_rebin_values(min=-12., - max=5.6, - rebin_string="-12.0,0.46,23.0,34.8,3.6," + "L/q -12 , 0.46 , 23 ,34.8 ,3.6, -.123, 5.6": {LimitsId.Q: q_rebin_values(min=-12., + max=5.6, + rebin_string="-12.0,0.46,23.0,34.8,3.6," "-0.123,5.6")}, - "L/q -12 0.47 23 34.8 3.6, -.123 5.6": {LimitsId.q: 
q_rebin_values(min=-12., - max=5.6, - rebin_string="-12.0,0.47,23.0,34.8,3.6," + "L/q -12 0.47 23 34.8 3.6, -.123 5.6": {LimitsId.Q: q_rebin_values(min=-12., + max=5.6, + rebin_string="-12.0,0.47,23.0,34.8,3.6," "-0.123,5.6")} } @@ -278,13 +280,13 @@ class LimitParserTest(unittest.TestCase): do_test(limit_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_qxy_limits_are_parsed_correctly(self): - valid_settings = {"L/QXY 12 34": {LimitsId.qxy: simple_range(start=12, stop=34, step=None, step_type=None)}, - "L/QXY 12 34 2.7": {LimitsId.qxy: simple_range(start=12, stop=34, step=2.7, - step_type=RangeStepType.Lin)}, - "L/QXY -12 34.6 2.7/LOG": {LimitsId.qxy: simple_range(start=-12, stop=34.6, step=2.7, - step_type=RangeStepType.Log)}, - "L/qxY -12 3.6 2 /LIN": {LimitsId.qxy: simple_range(start=-12, stop=3.6, step=2, - step_type=RangeStepType.Lin)}} + valid_settings = {"L/QXY 12 34": {LimitsId.QXY: simple_range(start=12, stop=34, step=None, step_type=None)}, + "L/QXY 12 34 2.7": {LimitsId.QXY: simple_range(start=12, stop=34, step=2.7, + step_type=RangeStepType.LIN)}, + "L/QXY -12 34.6 2.7/LOG": {LimitsId.QXY: simple_range(start=-12, stop=34.6, step=2.7, + step_type=RangeStepType.LOG)}, + "L/qxY -12 3.6 2 /LIN": {LimitsId.QXY: simple_range(start=-12, stop=3.6, step=2, + step_type=RangeStepType.LIN)}} """ These tests should be added back to valid settings when SANS GUI can accept complex QXY strings. 
"L/qxy -12 , 0.4, 23, -3.48, 36": {LimitsId.qxy: complex_range(start=-12, step1=0.4, @@ -311,14 +313,14 @@ class LimitParserTest(unittest.TestCase): do_test(limit_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_wavelength_limits_are_parsed_correctly(self): - valid_settings = {"L/WAV 12 34": {LimitsId.wavelength: simple_range(start=12, stop=34, step=None, + valid_settings = {"L/WAV 12 34": {LimitsId.WAVELENGTH: simple_range(start=12, stop=34, step=None, step_type=None)}, - "L/waV 12 34 2.7": {LimitsId.wavelength: simple_range(start=12, stop=34, step=2.7, - step_type=RangeStepType.Lin)}, - "L/wAv -12 34.6 2.7/LOG": {LimitsId.wavelength: simple_range(start=-12, stop=34.6, step=2.7, - step_type=RangeStepType.Log)}, - "L/WaV -12 3.6 2 /LIN": {LimitsId.wavelength: simple_range(start=-12, stop=3.6, step=2, - step_type=RangeStepType.Lin)}} + "L/waV 12 34 2.7": {LimitsId.WAVELENGTH: simple_range(start=12, stop=34, step=2.7, + step_type=RangeStepType.LIN)}, + "L/wAv -12 34.6 2.7/LOG": {LimitsId.WAVELENGTH: simple_range(start=-12, stop=34.6, step=2.7, + step_type=RangeStepType.LOG)}, + "L/WaV -12 3.6 2 /LIN": {LimitsId.WAVELENGTH: simple_range(start=-12, stop=3.6, step=2, + step_type=RangeStepType.LIN)}} invalid_settings = {"L/WAV 12 2 3 4": RuntimeError, "L/WAV 12 2 3 4 23 3": RuntimeError, @@ -336,8 +338,8 @@ class MaskParserTest(unittest.TestCase): self.assertTrue(MaskParser.get_type(), "MASK") def test_that_masked_line_is_parsed_correctly(self): - valid_settings = {"MASK/LiNE 12 23.6": {MaskId.line: mask_line(width=12, angle=23.6, x=None, y=None)}, - "MASK/LiNE 12 23.6 2 346": {MaskId.line: mask_line(width=12, angle=23.6, x=2, y=346)} + valid_settings = {"MASK/LiNE 12 23.6": {MaskId.LINE: mask_line(width=12, angle=23.6, x=None, y=None)}, + "MASK/LiNE 12 23.6 2 346": {MaskId.LINE: mask_line(width=12, angle=23.6, x=2, y=346)} } invalid_settings = {"MASK/LiN 12 4": RuntimeError, "MASK/LINE 12": RuntimeError, @@ -350,18 +352,18 @@ 
class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_masked_time_is_parsed_correctly(self): - valid_settings = {"MASK/TIME 23 35": {MaskId.time: range_entry_with_detector(start=23, stop=35, + valid_settings = {"MASK/TIME 23 35": {MaskId.TIME: range_entry_with_detector(start=23, stop=35, detector_type=None)}, - "MASK/T 23 35": {MaskId.time: range_entry_with_detector(start=23, stop=35, + "MASK/T 23 35": {MaskId.TIME: range_entry_with_detector(start=23, stop=35, detector_type=None)}, - "MASK/REAR/T 13 35": {MaskId.time_detector: range_entry_with_detector(start=13, stop=35, - detector_type=DetectorType.LAB)}, - "MASK/FRONT/TIME 33 35": {MaskId.time_detector: range_entry_with_detector(start=33, stop=35, - detector_type=DetectorType.HAB)}, - "MASK/TIME/REAR 13 35": {MaskId.time_detector: range_entry_with_detector(start=13, stop=35, - detector_type=DetectorType.LAB)}, - "MASK/T/FRONT 33 35": {MaskId.time_detector: range_entry_with_detector(start=33, stop=35, - detector_type=DetectorType.HAB)} + "MASK/REAR/T 13 35": {MaskId.TIME_DETECTOR: range_entry_with_detector(start=13, stop=35, + detector_type=DetectorType.LAB)}, + "MASK/FRONT/TIME 33 35": {MaskId.TIME_DETECTOR: range_entry_with_detector(start=33, stop=35, + detector_type=DetectorType.HAB)}, + "MASK/TIME/REAR 13 35": {MaskId.TIME_DETECTOR: range_entry_with_detector(start=13, stop=35, + detector_type=DetectorType.LAB)}, + "MASK/T/FRONT 33 35": {MaskId.TIME_DETECTOR: range_entry_with_detector(start=33, stop=35, + detector_type=DetectorType.HAB)} } invalid_settings = {"MASK/TIME 12 34 4 ": RuntimeError, @@ -374,8 +376,8 @@ class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_clear_mask_is_parsed_correctly(self): - valid_settings = {"MASK/CLEAR": {MaskId.clear_detector_mask: True}, - "MASK/CLeaR /TIMe": {MaskId.clear_time_mask: True}} + 
valid_settings = {"MASK/CLEAR": {MaskId.CLEAR_DETECTOR_MASK: True}, + "MASK/CLeaR /TIMe": {MaskId.CLEAR_TIME_MASK: True}} invalid_settings = {"MASK/CLEAR/TIME/test": RuntimeError, "MASK/CLEAR/TIIE": RuntimeError, @@ -385,8 +387,8 @@ class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_single_spectrum_mask_is_parsed_correctly(self): - valid_settings = {"MASK S 12 ": {MaskId.single_spectrum_mask: 12}, - "MASK S234": {MaskId.single_spectrum_mask: 234}} + valid_settings = {"MASK S 12 ": {MaskId.SINGLE_SPECTRUM_MASK: 12}, + "MASK S234": {MaskId.SINGLE_SPECTRUM_MASK: 234}} invalid_settings = {"MASK B 12 ": RuntimeError, "MASK S 12 23 ": RuntimeError} @@ -394,8 +396,8 @@ class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_single_spectrum_range_is_parsed_correctly(self): - valid_settings = {"MASK S 12 > S23 ": {MaskId.spectrum_range_mask: range_entry(start=12, stop=23)}, - "MASK S234>S1234": {MaskId.spectrum_range_mask: range_entry(start=234, stop=1234)}} + valid_settings = {"MASK S 12 > S23 ": {MaskId.SPECTRUM_RANGE_MASK: range_entry(start=12, stop=23)}, + "MASK S234>S1234": {MaskId.SPECTRUM_RANGE_MASK: range_entry(start=234, stop=1234)}} invalid_settings = {"MASK S 12> S123.5 ": RuntimeError, "MASK S 12> 23 ": RuntimeError} @@ -403,18 +405,18 @@ class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_single_vertical_strip_mask_is_parsed_correctly(self): - valid_settings = {"MASK V 12 ": {MaskId.vertical_single_strip_mask: single_entry_with_detector(entry=12, - detector_type=DetectorType.LAB)}, - "MASK / Rear V 12 ": {MaskId.vertical_single_strip_mask: single_entry_with_detector( + valid_settings = {"MASK V 12 ": {MaskId.VERTICAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=12, + 
detector_type=DetectorType.LAB)}, + "MASK / Rear V 12 ": {MaskId.VERTICAL_SINGLE_STRIP_MASK: single_entry_with_detector( entry=12, detector_type=DetectorType.LAB)}, - "MASK/mAin V234": {MaskId.vertical_single_strip_mask: single_entry_with_detector(entry=234, - detector_type=DetectorType.LAB)}, - "MASK / LaB V 234": {MaskId.vertical_single_strip_mask: single_entry_with_detector(entry=234, - detector_type=DetectorType.LAB)}, - "MASK /frOnt V 12 ": {MaskId.vertical_single_strip_mask: single_entry_with_detector( + "MASK/mAin V234": {MaskId.VERTICAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=234, + detector_type=DetectorType.LAB)}, + "MASK / LaB V 234": {MaskId.VERTICAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=234, + detector_type=DetectorType.LAB)}, + "MASK /frOnt V 12 ": {MaskId.VERTICAL_SINGLE_STRIP_MASK: single_entry_with_detector( entry=12, detector_type=DetectorType.HAB)}, - "MASK/HAB V234": {MaskId.vertical_single_strip_mask: single_entry_with_detector(entry=234, - detector_type=DetectorType.HAB)}} + "MASK/HAB V234": {MaskId.VERTICAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=234, + detector_type=DetectorType.HAB)}} invalid_settings = {"MASK B 12 ": RuntimeError, "MASK V 12 23 ": RuntimeError, @@ -423,19 +425,19 @@ class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_range_vertical_strip_mask_is_parsed_correctly(self): - valid_settings = {"MASK V 12 > V23 ": {MaskId.vertical_range_strip_mask: range_entry_with_detector(start=12, - stop=23, detector_type=DetectorType.LAB)}, - "MASK V123>V234": {MaskId.vertical_range_strip_mask: range_entry_with_detector(start=123, - stop=234, detector_type=DetectorType.LAB)}, - "MASK / Rear V123>V234": {MaskId.vertical_range_strip_mask: range_entry_with_detector( + valid_settings = {"MASK V 12 > V23 ": {MaskId.VERTICAL_RANGE_STRIP_MASK: range_entry_with_detector(start=12, + stop=23, 
detector_type=DetectorType.LAB)}, + "MASK V123>V234": {MaskId.VERTICAL_RANGE_STRIP_MASK: range_entry_with_detector(start=123, + stop=234, detector_type=DetectorType.LAB)}, + "MASK / Rear V123>V234": {MaskId.VERTICAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.LAB)}, - "MASK/mAin V123>V234": {MaskId.vertical_range_strip_mask: range_entry_with_detector( + "MASK/mAin V123>V234": {MaskId.VERTICAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.LAB)}, - "MASK / LaB V123>V234": {MaskId.vertical_range_strip_mask: range_entry_with_detector( + "MASK / LaB V123>V234": {MaskId.VERTICAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.LAB)}, - "MASK/frOnt V123>V234": {MaskId.vertical_range_strip_mask: range_entry_with_detector( + "MASK/frOnt V123>V234": {MaskId.VERTICAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.HAB)}, - "MASK/HAB V123>V234": {MaskId.vertical_range_strip_mask: range_entry_with_detector( + "MASK/HAB V123>V234": {MaskId.VERTICAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.HAB)}} invalid_settings = {"MASK V 12> V123.5 ": RuntimeError, @@ -445,18 +447,18 @@ class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_single_horizontal_strip_mask_is_parsed_correctly(self): - valid_settings = {"MASK H 12 ": {MaskId.horizontal_single_strip_mask: single_entry_with_detector(entry=12, - detector_type=DetectorType.LAB)}, - "MASK / Rear H 12 ": {MaskId.horizontal_single_strip_mask: single_entry_with_detector( + valid_settings = {"MASK H 12 ": {MaskId.HORIZONTAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=12, + detector_type=DetectorType.LAB)}, + "MASK / Rear H 12 ": {MaskId.HORIZONTAL_SINGLE_STRIP_MASK: single_entry_with_detector( entry=12, 
detector_type=DetectorType.LAB)}, - "MASK/mAin H234": {MaskId.horizontal_single_strip_mask: single_entry_with_detector(entry=234, - detector_type=DetectorType.LAB)}, - "MASK / LaB H 234": {MaskId.horizontal_single_strip_mask: single_entry_with_detector( + "MASK/mAin H234": {MaskId.HORIZONTAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=234, + detector_type=DetectorType.LAB)}, + "MASK / LaB H 234": {MaskId.HORIZONTAL_SINGLE_STRIP_MASK: single_entry_with_detector( entry=234, detector_type=DetectorType.LAB)}, - "MASK /frOnt H 12 ": {MaskId.horizontal_single_strip_mask: single_entry_with_detector( + "MASK /frOnt H 12 ": {MaskId.HORIZONTAL_SINGLE_STRIP_MASK: single_entry_with_detector( entry=12, detector_type=DetectorType.HAB)}, - "MASK/HAB H234": {MaskId.horizontal_single_strip_mask: single_entry_with_detector(entry=234, - detector_type=DetectorType.HAB)}} + "MASK/HAB H234": {MaskId.HORIZONTAL_SINGLE_STRIP_MASK: single_entry_with_detector(entry=234, + detector_type=DetectorType.HAB)}} invalid_settings = {"MASK H/12 ": RuntimeError, "MASK H 12 23 ": RuntimeError, @@ -465,19 +467,19 @@ class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_range_horizontal_strip_mask_is_parsed_correctly(self): - valid_settings = {"MASK H 12 > H23 ": {MaskId.horizontal_range_strip_mask: range_entry_with_detector( + valid_settings = {"MASK H 12 > H23 ": {MaskId.HORIZONTAL_RANGE_STRIP_MASK: range_entry_with_detector( start=12, stop=23, detector_type=DetectorType.LAB)}, - "MASK H123>H234": {MaskId.horizontal_range_strip_mask: range_entry_with_detector( + "MASK H123>H234": {MaskId.HORIZONTAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.LAB)}, - "MASK / Rear H123>H234": {MaskId.horizontal_range_strip_mask: range_entry_with_detector( + "MASK / Rear H123>H234": {MaskId.HORIZONTAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, 
detector_type=DetectorType.LAB)}, - "MASK/mAin H123>H234": {MaskId.horizontal_range_strip_mask: range_entry_with_detector( + "MASK/mAin H123>H234": {MaskId.HORIZONTAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.LAB)}, - "MASK / LaB H123>H234": {MaskId.horizontal_range_strip_mask: range_entry_with_detector( + "MASK / LaB H123>H234": {MaskId.HORIZONTAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.LAB)}, - "MASK/frOnt H123>H234": {MaskId.horizontal_range_strip_mask: range_entry_with_detector( + "MASK/frOnt H123>H234": {MaskId.HORIZONTAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.HAB)}, - "MASK/HAB H123>H234": {MaskId.horizontal_range_strip_mask: range_entry_with_detector( + "MASK/HAB H123>H234": {MaskId.HORIZONTAL_RANGE_STRIP_MASK: range_entry_with_detector( start=123, stop=234, detector_type=DetectorType.HAB)}} invalid_settings = {"MASK H 12> H123.5 ": RuntimeError, @@ -487,18 +489,18 @@ class MaskParserTest(unittest.TestCase): do_test(mask_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_block_mask_is_parsed_correctly(self): - valid_settings = {"MASK H12>H23 + V14>V15 ": {MaskId.block: mask_block(horizontal1=12, horizontal2=23, + valid_settings = {"MASK H12>H23 + V14>V15 ": {MaskId.BLOCK: mask_block(horizontal1=12, horizontal2=23, vertical1=14, vertical2=15, detector_type=DetectorType.LAB)}, - "MASK/ HAB H12>H23 + V14>V15 ": {MaskId.block: mask_block(horizontal1=12, horizontal2=23, + "MASK/ HAB H12>H23 + V14>V15 ": {MaskId.BLOCK: mask_block(horizontal1=12, horizontal2=23, vertical1=14, vertical2=15, detector_type=DetectorType.HAB)}, - "MASK/ HAB V12>V23 + H14>H15 ": {MaskId.block: mask_block(horizontal1=14, horizontal2=15, + "MASK/ HAB V12>V23 + H14>H15 ": {MaskId.BLOCK: mask_block(horizontal1=14, horizontal2=15, vertical1=12, vertical2=23, detector_type=DetectorType.HAB)}, - 
"MASK V12 + H 14": {MaskId.block_cross: mask_block_cross(horizontal=14, vertical=12, + "MASK V12 + H 14": {MaskId.BLOCK_CROSS: mask_block_cross(horizontal=14, vertical=12, detector_type=DetectorType.LAB)}, - "MASK/HAB H12 + V 14": {MaskId.block_cross: mask_block_cross(horizontal=12, vertical=14, + "MASK/HAB H12 + V 14": {MaskId.BLOCK_CROSS: mask_block_cross(horizontal=12, vertical=14, detector_type=DetectorType.HAB)}} invalid_settings = {"MASK H12>H23 + V14 + V15 ": RuntimeError, @@ -513,8 +515,8 @@ class SampleParserTest(unittest.TestCase): self.assertTrue(SampleParser.get_type(), "SAMPLE") def test_that_setting_sample_path_is_parsed_correctly(self): - valid_settings = {"SAMPLE /PATH/ON": {SampleId.path: True}, - "SAMPLE / PATH / OfF": {SampleId.path: False}} + valid_settings = {"SAMPLE /PATH/ON": {SampleId.PATH: True}, + "SAMPLE / PATH / OfF": {SampleId.PATH: False}} invalid_settings = {"SAMPLE/PATH ON": RuntimeError, "SAMPLE /pATh ": RuntimeError, @@ -524,8 +526,8 @@ class SampleParserTest(unittest.TestCase): do_test(sample_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_setting_sample_offset_is_parsed_correctly(self): - valid_settings = {"SAMPLE /Offset 234.5": {SampleId.offset: 234.5}, - "SAMPLE / Offset 25": {SampleId.offset: 25}} + valid_settings = {"SAMPLE /Offset 234.5": {SampleId.OFFSET: 234.5}, + "SAMPLE / Offset 25": {SampleId.OFFSET: 25}} invalid_settings = {"SAMPL/offset fg": RuntimeError, "SAMPLE /Offset/ 23 ": RuntimeError, @@ -540,7 +542,7 @@ class SetParserTest(unittest.TestCase): self.assertTrue(SetParser.get_type(), "SET") def test_that_setting_scales_is_parsed_correctly(self): - valid_settings = {"SET scales 2 5 4 7 8": {SetId.scales: set_scales_entry(s=2, a=5, b=4, c=7, d=8)}} + valid_settings = {"SET scales 2 5 4 7 8": {SetId.SCALES: set_scales_entry(s=2, a=5, b=4, c=7, d=8)}} invalid_settings = {"SET scales 2 4 6 7 8 9": RuntimeError, "SET scales ": RuntimeError} @@ -549,18 +551,18 @@ class 
SetParserTest(unittest.TestCase): do_test(set_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_centre_is_parsed_correctly(self): - valid_settings = {"SET centre 23 45": {SetId.centre: position_entry(pos1=23, pos2=45, + valid_settings = {"SET centre 23 45": {SetId.CENTRE: position_entry(pos1=23, pos2=45, detector_type=DetectorType.LAB)}, - "SET centre /main 23 45": {SetId.centre: position_entry(pos1=23, pos2=45, + "SET centre /main 23 45": {SetId.CENTRE: position_entry(pos1=23, pos2=45, detector_type=DetectorType.LAB)}, - "SET centre / lAb 23 45": {SetId.centre: position_entry(pos1=23, pos2=45, + "SET centre / lAb 23 45": {SetId.CENTRE: position_entry(pos1=23, pos2=45, detector_type=DetectorType.LAB)}, - "SET centre / hAb 23 45": {SetId.centre_HAB: position_entry(pos1=23, pos2=45, - detector_type=DetectorType.HAB)}, - "SET centre /FRONT 23 45": {SetId.centre_HAB: position_entry(pos1=23, pos2=45, - detector_type=DetectorType.HAB)}, - "SET centre /FRONT 23 45 55 67": {SetId.centre_HAB: position_entry(pos1=23, pos2=45, - detector_type=DetectorType.HAB)}, + "SET centre / hAb 23 45": {SetId.CENTRE_HAB: position_entry(pos1=23, pos2=45, + detector_type=DetectorType.HAB)}, + "SET centre /FRONT 23 45": {SetId.CENTRE_HAB: position_entry(pos1=23, pos2=45, + detector_type=DetectorType.HAB)}, + "SET centre /FRONT 23 45 55 67": {SetId.CENTRE_HAB: position_entry(pos1=23, pos2=45, + detector_type=DetectorType.HAB)}, } invalid_settings = {"SET centre 23": RuntimeError, @@ -577,8 +579,8 @@ class TransParserTest(unittest.TestCase): self.assertTrue(TransParser.get_type(), "TRANS") def test_that_trans_spec_is_parsed_correctly(self): - valid_settings = {"TRANS/TRANSPEC=23": {TransId.spec: 23}, - "TRANS / TransPEC = 23": {TransId.spec: 23}} + valid_settings = {"TRANS/TRANSPEC=23": {TransId.SPEC: 23}, + "TRANS / TransPEC = 23": {TransId.SPEC: 23}} invalid_settings = {"TRANS/TRANSPEC 23": RuntimeError, "TRANS/TRANSPEC/23": RuntimeError, @@ -590,20 
+592,20 @@ class TransParserTest(unittest.TestCase): do_test(trans_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_mon_shift_is_parsed_correctly(self): - valid_settings = {"TRANS/SHIFT = 4000 5": {TransId.spec_5_shift: 4000}, - "TRANS/SHIFT=4000 4" : {TransId.spec_4_shift: 4000}, - "TRANS/SHIFT=4000 5": {TransId.spec_5_shift: 4000}, - "TRANS/SHIFT=-100 4": {TransId.spec_4_shift: -100}, - "TRANS/ SHIFT=4000 5": {TransId.spec_5_shift: 4000}, - "TRANS /SHIFT=4000 5": {TransId.spec_5_shift: 4000}, - "TRANS/SHIFT=4000 5": {TransId.spec_5_shift: 4000}, + valid_settings = {"TRANS/SHIFT = 4000 5": {TransId.SPEC_5_SHIFT: 4000}, + "TRANS/SHIFT=4000 4" : {TransId.SPEC_4_SHIFT: 4000}, + "TRANS/SHIFT=4000 5": {TransId.SPEC_5_SHIFT: 4000}, + "TRANS/SHIFT=-100 4": {TransId.SPEC_4_SHIFT: -100}, + "TRANS/ SHIFT=4000 5": {TransId.SPEC_5_SHIFT: 4000}, + "TRANS /SHIFT=4000 5": {TransId.SPEC_5_SHIFT: 4000}, + "TRANS/SHIFT=4000 5": {TransId.SPEC_5_SHIFT: 4000}, # An unrecognised monitor position (i.e. 
not 5) should be considered as 4 # see source code for details - "TRANS/SHIFT=1000 12": {TransId.spec_4_shift: 1000}, - "TRANS/SHIFT=4000 =12": {TransId.spec_4_shift: 4000}, - "TRANS/SHIFT=4000 =1": {TransId.spec_4_shift: 4000}, - "TRANS/SHIFT4000 120": {TransId.spec_4_shift: 4000}, - "TRANS/SHIFT 4000 999": {TransId.spec_4_shift: 4000}, + "TRANS/SHIFT=1000 12": {TransId.SPEC_4_SHIFT: 1000}, + "TRANS/SHIFT=4000 =12": {TransId.SPEC_4_SHIFT: 4000}, + "TRANS/SHIFT=4000 =1": {TransId.SPEC_4_SHIFT: 4000}, + "TRANS/SHIFT4000 120": {TransId.SPEC_4_SHIFT: 4000}, + "TRANS/SHIFT 4000 999": {TransId.SPEC_4_SHIFT: 4000}, } invalid_settings = { @@ -623,9 +625,9 @@ class TransParserTest(unittest.TestCase): do_test(trans_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_trans_spec_shift_is_parsed_correctly(self): - valid_settings = {"TRANS/TRANSPEC=4/SHIFT=23": {TransId.spec_4_shift: 23, TransId.spec: 4}, - "TRANS/TRANSPEC =4/ SHIFT = 23": {TransId.spec_4_shift: 23, TransId.spec: 4}, - "TRANS/TRANSPEC =900/ SHIFT = 23": {TransId.spec_4_shift: 23, TransId.spec: 900}, + valid_settings = {"TRANS/TRANSPEC=4/SHIFT=23": {TransId.SPEC_4_SHIFT: 23, TransId.SPEC: 4}, + "TRANS/TRANSPEC =4/ SHIFT = 23": {TransId.SPEC_4_SHIFT: 23, TransId.SPEC: 4}, + "TRANS/TRANSPEC =900/ SHIFT = 23": {TransId.SPEC_4_SHIFT: 23, TransId.SPEC: 900}, } @@ -639,17 +641,17 @@ class TransParserTest(unittest.TestCase): do_test(trans_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_radius_is_parsed_correctly(self): - valid_settings = {"TRANS / radius =23": {TransId.radius: 23}, - "TRANS /RADIUS= 245.7": {TransId.radius: 245.7}} + valid_settings = {"TRANS / radius =23": {TransId.RADIUS: 23}, + "TRANS /RADIUS= 245.7": {TransId.RADIUS: 245.7}} invalid_settings = {} trans_parser = TransParser() do_test(trans_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_roi_is_parsed_correctly(self): - 
valid_settings = {"TRANS/ROI =testFile.xml": {TransId.roi: ["testFile.xml"]}, + valid_settings = {"TRANS/ROI =testFile.xml": {TransId.ROI: ["testFile.xml"]}, "TRANS/ROI =testFile.xml, " - "TestFile2.XmL,testFile4.xml": {TransId.roi: ["testFile.xml", "TestFile2.XmL", + "TestFile2.XmL,testFile4.xml": {TransId.ROI: ["testFile.xml", "TestFile2.XmL", "testFile4.xml"]}} invalid_settings = {"TRANS/ROI =t estFile.xml": RuntimeError, "TRANS/ROI =testFile.txt": RuntimeError, @@ -660,9 +662,9 @@ class TransParserTest(unittest.TestCase): do_test(trans_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_mask_is_parsed_correctly(self): - valid_settings = {"TRANS/Mask =testFile.xml": {TransId.mask: ["testFile.xml"]}, + valid_settings = {"TRANS/Mask =testFile.xml": {TransId.MASK: ["testFile.xml"]}, "TRANS/ MASK =testFile.xml, " - "TestFile2.XmL,testFile4.xml": {TransId.mask: ["testFile.xml", "TestFile2.XmL", + "TestFile2.XmL,testFile4.xml": {TransId.MASK: ["testFile.xml", "TestFile2.XmL", "testFile4.xml"]}} invalid_settings = {"TRANS/MASK =t estFile.xml": RuntimeError, "TRANS/ MASK =testFile.txt": RuntimeError, @@ -673,10 +675,10 @@ class TransParserTest(unittest.TestCase): do_test(trans_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_workspaces_are_parsed_correctly(self): - valid_settings = {"TRANS/SampleWS =testworksaoe234Name": {TransId.sample_workspace: "testworksaoe234Name"}, - "TRANS/ SampleWS = testworksaoe234Name": {TransId.sample_workspace: "testworksaoe234Name"}, - "TRANS/ CanWS =testworksaoe234Name": {TransId.can_workspace: "testworksaoe234Name"}, - "TRANS/ CANWS = testworksaoe234Name": {TransId.can_workspace: "testworksaoe234Name"}} + valid_settings = {"TRANS/SampleWS =testworksaoe234Name": {TransId.SAMPLE_WORKSPACE: "testworksaoe234Name"}, + "TRANS/ SampleWS = testworksaoe234Name": {TransId.SAMPLE_WORKSPACE: "testworksaoe234Name"}, + "TRANS/ CanWS =testworksaoe234Name": 
{TransId.CAN_WORKSPACE: "testworksaoe234Name"}, + "TRANS/ CANWS = testworksaoe234Name": {TransId.CAN_WORKSPACE: "testworksaoe234Name"}} invalid_settings = {"TRANS/CANWS/ test": RuntimeError, "TRANS/SAMPLEWS =": RuntimeError} @@ -689,8 +691,8 @@ class TubeCalibFileParserTest(unittest.TestCase): self.assertTrue(TubeCalibFileParser.get_type(), "TRANS") def test_that_tube_calibration_file_is_parsed_correctly(self): - valid_settings = {"TUBECALIbfile= calib_file.nxs": {TubeCalibrationFileId.file: "calib_file.nxs"}, - " tUBECALIBfile= caAlib_file.Nxs": {TubeCalibrationFileId.file: "caAlib_file.Nxs"}} + valid_settings = {"TUBECALIbfile= calib_file.nxs": {TubeCalibrationFileId.FILE: "calib_file.nxs"}, + " tUBECALIBfile= caAlib_file.Nxs": {TubeCalibrationFileId.FILE: "caAlib_file.Nxs"}} invalid_settings = {"TUBECALIFILE file.nxs": RuntimeError, "TUBECALIBFILE=file.txt": RuntimeError, @@ -705,8 +707,8 @@ class QResolutionParserTest(unittest.TestCase): self.assertTrue(QResolutionParser.get_type(), "QRESOL") def test_that_q_resolution_on_off_is_parsed_correctly(self): - valid_settings = {"QRESOL/ON": {QResolutionId.on: True}, - "QREsoL / oFF": {QResolutionId.on: False}} + valid_settings = {"QRESOL/ON": {QResolutionId.ON: True}, + "QREsoL / oFF": {QResolutionId.ON: False}} invalid_settings = {"QRESOL= ON": RuntimeError} @@ -714,14 +716,14 @@ class QResolutionParserTest(unittest.TestCase): do_test(q_resolution_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_q_resolution_float_values_are_parsed_correctly(self): - valid_settings = {"QRESOL/deltaR = 23.546": {QResolutionId.delta_r: 23.546}, - "QRESOL/ Lcollim = 23.546": {QResolutionId.collimation_length: 23.546}, - "QRESOL/ a1 = 23.546": {QResolutionId.a1: 23.546}, - "QRESOL/ a2 = 23": {QResolutionId.a2: 23}, - "QRESOL / H1 = 23.546 ": {QResolutionId.h1: 23.546}, - "QRESOL /h2 = 23.546 ": {QResolutionId.h2: 23.546}, - "QRESOL / W1 = 23.546 ": {QResolutionId.w1: 23.546}, - "QRESOL /W2 = 
23.546 ": {QResolutionId.w2: 23.546} + valid_settings = {"QRESOL/deltaR = 23.546": {QResolutionId.DELTA_R: 23.546}, + "QRESOL/ Lcollim = 23.546": {QResolutionId.COLLIMATION_LENGTH: 23.546}, + "QRESOL/ a1 = 23.546": {QResolutionId.A1: 23.546}, + "QRESOL/ a2 = 23": {QResolutionId.A2: 23}, + "QRESOL / H1 = 23.546 ": {QResolutionId.H1: 23.546}, + "QRESOL /h2 = 23.546 ": {QResolutionId.H2: 23.546}, + "QRESOL / W1 = 23.546 ": {QResolutionId.W1: 23.546}, + "QRESOL /W2 = 23.546 ": {QResolutionId.W2: 23.546} } invalid_settings = {"QRESOL/DELTAR 23": RuntimeError, @@ -733,7 +735,7 @@ class QResolutionParserTest(unittest.TestCase): do_test(q_resolution_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_moderator_is_parsed_correctly(self): - valid_settings = {"QRESOL/MODERATOR = test_file.txt": {QResolutionId.moderator: "test_file.txt"}} + valid_settings = {"QRESOL/MODERATOR = test_file.txt": {QResolutionId.MODERATOR: "test_file.txt"}} invalid_settings = {"QRESOL/MODERATOR = test_file.nxs": RuntimeError, "QRESOL/MODERATOR/test_file.txt": RuntimeError, @@ -748,38 +750,38 @@ class FitParserTest(unittest.TestCase): self.assertTrue(FitParser.get_type(), "FIT") def test_that_general_fit_is_parsed_correctly(self): - valid_settings = {"FIT/ trans / LIN 123 3556": {FitId.general: fit_general(start=123, stop=3556, - fit_type=FitType.Linear, data_type=None, polynomial_order=0)}, - "FIT/ tranS/linear 123 3556": {FitId.general: fit_general(start=123, stop=3556, - fit_type=FitType.Linear, data_type=None, polynomial_order=0)}, - "FIT/TRANS/Straight 123 3556": {FitId.general: fit_general(start=123, stop=3556, - fit_type=FitType.Linear, data_type=None, polynomial_order=0)}, - "FIT/ tranS/LoG 123 3556.6 ": {FitId.general: fit_general(start=123, stop=3556.6, - fit_type=FitType.Logarithmic, data_type=None, polynomial_order=0)}, # noqa - "FIT/TRANS/ YlOG 123 3556": {FitId.general: fit_general(start=123, stop=3556, - fit_type=FitType.Logarithmic, 
data_type=None, polynomial_order=0)}, # noqa - "FIT/Trans/Lin": {FitId.general: fit_general(start=None, stop=None, fit_type=FitType.Linear, + valid_settings = {"FIT/ trans / LIN 123 3556": {FitId.GENERAL: fit_general(start=123, stop=3556, + fit_type=FitType.LINEAR, data_type=None, polynomial_order=0)}, + "FIT/ tranS/linear 123 3556": {FitId.GENERAL: fit_general(start=123, stop=3556, + fit_type=FitType.LINEAR, data_type=None, polynomial_order=0)}, + "FIT/TRANS/Straight 123 3556": {FitId.GENERAL: fit_general(start=123, stop=3556, + fit_type=FitType.LINEAR, data_type=None, polynomial_order=0)}, + "FIT/ tranS/LoG 123 3556.6 ": {FitId.GENERAL: fit_general(start=123, stop=3556.6, + fit_type=FitType.LOGARITHMIC, data_type=None, polynomial_order=0)}, # noqa + "FIT/TRANS/ YlOG 123 3556": {FitId.GENERAL: fit_general(start=123, stop=3556, + fit_type=FitType.LOGARITHMIC, data_type=None, polynomial_order=0)}, # noqa + "FIT/Trans/Lin": {FitId.GENERAL: fit_general(start=None, stop=None, fit_type=FitType.LINEAR, data_type=None, polynomial_order=0)}, - "FIT/Trans/ Log": {FitId.general: fit_general(start=None, stop=None, fit_type=FitType.Logarithmic, # noqa + "FIT/Trans/ Log": {FitId.GENERAL: fit_general(start=None, stop=None, fit_type=FitType.LOGARITHMIC, # noqa data_type=None, polynomial_order=0)}, - "FIT/Trans/ polYnomial": {FitId.general: fit_general(start=None, stop=None, - fit_type=FitType.Polynomial, data_type=None, polynomial_order=2)}, - "FIT/Trans/ polYnomial 3": {FitId.general: fit_general(start=None, stop=None, - fit_type=FitType.Polynomial, + "FIT/Trans/ polYnomial": {FitId.GENERAL: fit_general(start=None, stop=None, + fit_type=FitType.POLYNOMIAL, data_type=None, polynomial_order=2)}, + "FIT/Trans/ polYnomial 3": {FitId.GENERAL: fit_general(start=None, stop=None, + fit_type=FitType.POLYNOMIAL, data_type=None, polynomial_order=3)}, - "FIT/Trans/Sample/Log 23.4 56.7": {FitId.general: fit_general(start=23.4, stop=56.7, - fit_type=FitType.Logarithmic, 
data_type=DataType.Sample, + "FIT/Trans/Sample/Log 23.4 56.7": {FitId.GENERAL: fit_general(start=23.4, stop=56.7, + fit_type=FitType.LOGARITHMIC, data_type=DataType.SAMPLE, polynomial_order=0)}, - "FIT/Trans/can/ lIn 23.4 56.7": {FitId.general: fit_general(start=23.4, stop=56.7, - fit_type=FitType.Linear, data_type=DataType.Can, + "FIT/Trans/can/ lIn 23.4 56.7": {FitId.GENERAL: fit_general(start=23.4, stop=56.7, + fit_type=FitType.LINEAR, data_type=DataType.CAN, polynomial_order=0)}, - "FIT/Trans / can/polynomiAL 5 23 45": {FitId.general: fit_general(start=23, stop=45, - fit_type=FitType.Polynomial, data_type=DataType.Can, + "FIT/Trans / can/polynomiAL 5 23 45": {FitId.GENERAL: fit_general(start=23, stop=45, + fit_type=FitType.POLYNOMIAL, data_type=DataType.CAN, polynomial_order=5)}, - "FIT/ trans / clear": {FitId.general: fit_general(start=None, stop=None, - fit_type=FitType.NoFit, data_type=None, polynomial_order=None)}, - "FIT/traNS /ofF": {FitId.general: fit_general(start=None, stop=None, - fit_type=FitType.NoFit, data_type=None, polynomial_order=None)} + "FIT/ trans / clear": {FitId.GENERAL: fit_general(start=None, stop=None, + fit_type=FitType.NO_FIT, data_type=None, polynomial_order=None)}, + "FIT/traNS /ofF": {FitId.GENERAL: fit_general(start=None, stop=None, + fit_type=FitType.NO_FIT, data_type=None, polynomial_order=None)} } invalid_settings = {"FIT/TRANS/ YlOG 123": RuntimeError, @@ -797,8 +799,8 @@ class FitParserTest(unittest.TestCase): do_test(fit_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_monitor_times_are_parsed_correctly(self): - valid_settings = {"FIT/monitor 12 34.5": {FitId.monitor_times: range_entry(start=12, stop=34.5)}, - "Fit / Monitor 12.6 34.5": {FitId.monitor_times: range_entry(start=12.6, stop=34.5)}} + valid_settings = {"FIT/monitor 12 34.5": {FitId.MONITOR_TIMES: range_entry(start=12, stop=34.5)}, + "Fit / Monitor 12.6 34.5": {FitId.MONITOR_TIMES: range_entry(start=12.6, stop=34.5)}} 
invalid_settings = {"Fit / Monitor 12.6 34 34": RuntimeError, "Fit / Monitor": RuntimeError} @@ -812,8 +814,8 @@ class GravityParserTest(unittest.TestCase): self.assertTrue(GravityParser.get_type(), "GRAVITY") def test_that_gravity_on_off_is_parsed_correctly(self): - valid_settings = {"Gravity on ": {GravityId.on_off: True}, - "Gravity OFF ": {GravityId.on_off: False}} + valid_settings = {"Gravity on ": {GravityId.ON_OFF: True}, + "Gravity OFF ": {GravityId.ON_OFF: False}} invalid_settings = {"Gravity ": RuntimeError, "Gravity ONN": RuntimeError} @@ -822,9 +824,9 @@ class GravityParserTest(unittest.TestCase): do_test(gravity_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_gravity_extra_length_is_parsed_correctly(self): - valid_settings = {"Gravity/LExtra =23.5": {GravityId.extra_length: 23.5}, - "Gravity / lExtra = 23.5": {GravityId.extra_length: 23.5}, - "Gravity / lExtra 23.5": {GravityId.extra_length: 23.5}} + valid_settings = {"Gravity/LExtra =23.5": {GravityId.EXTRA_LENGTH: 23.5}, + "Gravity / lExtra = 23.5": {GravityId.EXTRA_LENGTH: 23.5}, + "Gravity / lExtra 23.5": {GravityId.EXTRA_LENGTH: 23.5}} invalid_settings = {"Gravity/LExtra - 23.5": RuntimeError, "Gravity/LExtra =tw": RuntimeError} @@ -838,8 +840,8 @@ class CompatibilityParserTest(unittest.TestCase): self.assertTrue(CompatibilityParser.get_type(), "COMPATIBILITY") def test_that_compatibility_on_off_is_parsed_correctly(self): - valid_settings = {"COMPATIBILITY on ": {OtherId.use_compatibility_mode: True}, - "COMPATIBILITY OFF ": {OtherId.use_compatibility_mode: False}} + valid_settings = {"COMPATIBILITY on ": {OtherId.USE_COMPATIBILITY_MODE: True}, + "COMPATIBILITY OFF ": {OtherId.USE_COMPATIBILITY_MODE: False}} invalid_settings = {"COMPATIBILITY ": RuntimeError, "COMPATIBILITY ONN": RuntimeError} @@ -854,7 +856,7 @@ class MaskFileParserTest(unittest.TestCase): def test_that_gravity_on_off_is_parsed_correctly(self): valid_settings = {"MaskFile= test.xml, 
testKsdk2.xml,tesetlskd.xml": - {MaskId.file: ["test.xml", "testKsdk2.xml", "tesetlskd.xml"]}} + {MaskId.FILE: ["test.xml", "testKsdk2.xml", "tesetlskd.xml"]}} invalid_settings = {"MaskFile=": RuntimeError, "MaskFile=test.txt": RuntimeError, @@ -869,9 +871,9 @@ class MonParserTest(unittest.TestCase): self.assertTrue(MonParser.get_type(), "MON") def test_that_length_is_parsed_correctly(self): - valid_settings = {"MON/length= 23.5 34": {MonId.length: monitor_length(length=23.5, spectrum=34, + valid_settings = {"MON/length= 23.5 34": {MonId.LENGTH: monitor_length(length=23.5, spectrum=34, interpolate=False)}, - "MON/length= 23.5 34 / InterPolate": {MonId.length: monitor_length(length=23.5, spectrum=34, + "MON/length= 23.5 34 / InterPolate": {MonId.LENGTH: monitor_length(length=23.5, spectrum=34, interpolate=True)}} invalid_settings = {"MON/length= 23.5 34.7": RuntimeError, @@ -883,27 +885,27 @@ class MonParserTest(unittest.TestCase): do_test(mon_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_direct_files_are_parsed_correctly(self): - valid_settings = {"MON/DIRECT= C:\path1\Path2\file.ext ": {MonId.direct: [monitor_file( + valid_settings = {"MON/DIRECT= C:\path1\Path2\file.ext ": {MonId.DIRECT: [monitor_file( file_path="C:/path1/Path2/file.ext", detector_type=DetectorType.HAB), monitor_file(file_path="C:/path1/Path2/file.ext", detector_type=DetectorType.LAB)]}, - "MON/ direct = filE.Ext ": {MonId.direct: [monitor_file(file_path="filE.Ext", - detector_type=DetectorType.HAB), monitor_file( + "MON/ direct = filE.Ext ": {MonId.DIRECT: [monitor_file(file_path="filE.Ext", + detector_type=DetectorType.HAB), monitor_file( file_path="filE.Ext", detector_type=DetectorType.LAB) - ]}, - "MON/DIRECT= \path1\Path2\file.ext ": {MonId.direct: [monitor_file( + ]}, + "MON/DIRECT= \path1\Path2\file.ext ": {MonId.DIRECT: [monitor_file( file_path="/path1/Path2/file.ext", detector_type=DetectorType.HAB), 
monitor_file(file_path="/path1/Path2/file.ext", detector_type=DetectorType.LAB)]}, - "MON/DIRECT= /path1/Path2/file.ext ": {MonId.direct: [monitor_file( + "MON/DIRECT= /path1/Path2/file.ext ": {MonId.DIRECT: [monitor_file( file_path="/path1/Path2/file.ext", detector_type=DetectorType.HAB), - monitor_file(file_path="/path1/Path2/file.ext", - detector_type=DetectorType.LAB)]}, - "MON/DIRECT/ rear= /path1/Path2/file.ext ": {MonId.direct: [monitor_file( + monitor_file(file_path="/path1/Path2/file.ext", + detector_type=DetectorType.LAB)]}, + "MON/DIRECT/ rear= /path1/Path2/file.ext ": {MonId.DIRECT: [monitor_file( file_path="/path1/Path2/file.ext", detector_type=DetectorType.LAB)]}, - "MON/DIRECT/ frONT= path1/Path2/file.ext ": {MonId.direct: [monitor_file( + "MON/DIRECT/ frONT= path1/Path2/file.ext ": {MonId.DIRECT: [monitor_file( file_path="path1/Path2/file.ext", detector_type=DetectorType.HAB)]} } @@ -916,21 +918,21 @@ class MonParserTest(unittest.TestCase): do_test(mon_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_flat_files_are_parsed_correctly(self): - valid_settings = {"MON/FLat = C:\path1\Path2\file.ext ": {MonId.flat: monitor_file( + valid_settings = {"MON/FLat = C:\path1\Path2\file.ext ": {MonId.FLAT: monitor_file( file_path="C:/path1/Path2/file.ext", detector_type=DetectorType.LAB)}, - "MON/ flAt = filE.Ext ": {MonId.flat: monitor_file(file_path="filE.Ext", - detector_type=DetectorType.LAB)}, - "MON/flAT= \path1\Path2\file.ext ": {MonId.flat: monitor_file( + "MON/ flAt = filE.Ext ": {MonId.FLAT: monitor_file(file_path="filE.Ext", + detector_type=DetectorType.LAB)}, + "MON/flAT= \path1\Path2\file.ext ": {MonId.FLAT: monitor_file( file_path="/path1/Path2/file.ext", detector_type=DetectorType.LAB)}, - "MON/FLat= /path1/Path2/file.ext ": {MonId.flat: monitor_file( + "MON/FLat= /path1/Path2/file.ext ": {MonId.FLAT: monitor_file( file_path="/path1/Path2/file.ext", detector_type=DetectorType.LAB)}, - "MON/FLat/ rear= 
/path1/Path2/file.ext ": {MonId.flat: monitor_file( + "MON/FLat/ rear= /path1/Path2/file.ext ": {MonId.FLAT: monitor_file( file_path="/path1/Path2/file.ext", detector_type=DetectorType.LAB)}, - "MON/FLat/ frONT= path1/Path2/file.ext ": {MonId.flat: monitor_file( + "MON/FLat/ frONT= path1/Path2/file.ext ": {MonId.FLAT: monitor_file( file_path="path1/Path2/file.ext", detector_type=DetectorType.HAB)}} @@ -942,16 +944,16 @@ class MonParserTest(unittest.TestCase): do_test(mon_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_hab_files_are_parsed_correctly(self): - valid_settings = {"MON/HAB = C:\path1\Path2\file.ext ": {MonId.direct: [monitor_file( + valid_settings = {"MON/HAB = C:\path1\Path2\file.ext ": {MonId.DIRECT: [monitor_file( file_path="C:/path1/Path2/file.ext", detector_type=DetectorType.HAB)]}, - "MON/ hAB = filE.Ext ": {MonId.direct: [monitor_file( + "MON/ hAB = filE.Ext ": {MonId.DIRECT: [monitor_file( file_path="filE.Ext", detector_type=DetectorType.HAB)]}, - "MON/HAb= \path1\Path2\file.ext ": {MonId.direct: [monitor_file( + "MON/HAb= \path1\Path2\file.ext ": {MonId.DIRECT: [monitor_file( file_path="/path1/Path2/file.ext", detector_type=DetectorType.HAB)]}, - "MON/hAB= /path1/Path2/file.ext ": {MonId.direct: [monitor_file( + "MON/hAB= /path1/Path2/file.ext ": {MonId.DIRECT: [monitor_file( file_path="/path1/Path2/file.ext", detector_type=DetectorType.HAB)]}} invalid_settings = {"MON/HAB= /path1/ Path2/file.ext ": RuntimeError, @@ -962,13 +964,13 @@ class MonParserTest(unittest.TestCase): do_test(mon_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_hab_files_are_parsed_correctly2(self): - valid_settings = {"MON/Spectrum = 123 ": {MonId.spectrum: monitor_spectrum(spectrum=123, is_trans=False, + valid_settings = {"MON/Spectrum = 123 ": {MonId.SPECTRUM: monitor_spectrum(spectrum=123, is_trans=False, interpolate=False)}, - "MON/trans/Spectrum = 123 ": {MonId.spectrum: 
monitor_spectrum(spectrum=123, is_trans=True, + "MON/trans/Spectrum = 123 ": {MonId.SPECTRUM: monitor_spectrum(spectrum=123, is_trans=True, interpolate=False)}, - "MON/trans/Spectrum = 123 / interpolate": {MonId.spectrum: monitor_spectrum(spectrum=123, - is_trans=True, interpolate=True)}, - "MON/Spectrum = 123 / interpolate": {MonId.spectrum: monitor_spectrum(spectrum=123, + "MON/trans/Spectrum = 123 / interpolate": {MonId.SPECTRUM: monitor_spectrum(spectrum=123, + is_trans=True, interpolate=True)}, + "MON/Spectrum = 123 / interpolate": {MonId.SPECTRUM: monitor_spectrum(spectrum=123, is_trans=False, interpolate=True)}} invalid_settings = {} @@ -982,10 +984,10 @@ class PrintParserTest(unittest.TestCase): self.assertTrue(PrintParser.get_type(), "PRINT") def test_that_print_is_parsed_correctly(self): - valid_settings = {"PRINT OdlfP slsk 23lksdl2 34l": {PrintId.print_line: "OdlfP slsk 23lksdl2 34l"}, - "PRiNt OdlfP slsk 23lksdl2 34l": {PrintId.print_line: "OdlfP slsk 23lksdl2 34l"}, + valid_settings = {"PRINT OdlfP slsk 23lksdl2 34l": {PrintId.PRINT_LINE: "OdlfP slsk 23lksdl2 34l"}, + "PRiNt OdlfP slsk 23lksdl2 34l": {PrintId.PRINT_LINE: "OdlfP slsk 23lksdl2 34l"}, " PRINT Loaded: USER_LOQ_174J, 12/03/18, Xuzhi (Lu), 12mm, Sample Changer, Banjo cells": - {PrintId.print_line: "Loaded: USER_LOQ_174J, 12/03/18, Xuzhi (Lu), 12mm, Sample Changer, Banjo cells"} + {PrintId.PRINT_LINE: "Loaded: USER_LOQ_174J, 12/03/18, Xuzhi (Lu), 12mm, Sample Changer, Banjo cells"} } invalid_settings = {"j PRINT OdlfP slsk 23lksdl2 34l ": RuntimeError,} @@ -999,7 +1001,7 @@ class BackParserTest(unittest.TestCase): self.assertTrue(BackParser.get_type(), "BACK") def test_that_all_monitors_is_parsed_correctly(self): - valid_settings = {"BACK / MON /times 123 34": {BackId.all_monitors: range_entry(start=123, stop=34)}} + valid_settings = {"BACK / MON /times 123 34": {BackId.ALL_MONITORS: range_entry(start=123, stop=34)}} invalid_settings = {} @@ -1007,10 +1009,10 @@ class 
BackParserTest(unittest.TestCase): do_test(back_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_single_monitors_is_parsed_correctly(self): - valid_settings = {"BACK / M3 /times 123 34": {BackId.single_monitors: back_single_monitor_entry(monitor=3, + valid_settings = {"BACK / M3 /times 123 34": {BackId.SINGLE_MONITORS: back_single_monitor_entry(monitor=3, start=123, stop=34)}, - "BACK / M3 123 34": {BackId.single_monitors: back_single_monitor_entry(monitor=3, + "BACK / M3 123 34": {BackId.SINGLE_MONITORS: back_single_monitor_entry(monitor=3, start=123, stop=34)}} @@ -1021,7 +1023,7 @@ class BackParserTest(unittest.TestCase): do_test(back_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_off_is_parsed_correctly(self): - valid_settings = {"BACK / M3 /OFF": {BackId.monitor_off: 3}} + valid_settings = {"BACK / M3 /OFF": {BackId.MONITOR_OFF: 3}} invalid_settings = {"BACK / M /OFF": RuntimeError} @@ -1029,8 +1031,8 @@ class BackParserTest(unittest.TestCase): do_test(back_parser, valid_settings, invalid_settings, self.assertTrue, self.assertRaises) def test_that_trans_mon_is_parsed_correctly(self): - valid_settings = {"BACK / TRANS 123 344": {BackId.trans: range_entry(start=123, stop=344)}, - "BACK / tranS 123 34": {BackId.trans: range_entry(start=123, stop=34)}} + valid_settings = {"BACK / TRANS 123 344": {BackId.TRANS: range_entry(start=123, stop=344)}, + "BACK / tranS 123 34": {BackId.TRANS: range_entry(start=123, stop=34)}} invalid_settings = {"BACK / Trans / 123 34": RuntimeError, "BACK / trans 123": RuntimeError} @@ -1079,68 +1081,68 @@ class UserFileParserTest(unittest.TestCase): # DetParser result = user_file_parser.parse_line(" DET/CoRR/FRONT/ SidE -957") - assert_valid_result(result, {DetectorId.correction_translation: single_entry_with_detector(entry=-957, - detector_type=DetectorType.HAB)}, self.assertTrue) + assert_valid_result(result, {DetectorId.CORRECTION_TRANSLATION: 
single_entry_with_detector(entry=-957, + detector_type=DetectorType.HAB)}, self.assertTrue) # LimitParser result = user_file_parser.parse_line("l/Q/WCUT 234.4") - assert_valid_result(result, {LimitsId.wavelength_cut: 234.4}, self.assertTrue) + assert_valid_result(result, {LimitsId.WAVELENGTH_CUT: 234.4}, self.assertTrue) # MaskParser result = user_file_parser.parse_line("MASK S 12 ") - assert_valid_result(result, {MaskId.single_spectrum_mask: 12}, self.assertTrue) + assert_valid_result(result, {MaskId.SINGLE_SPECTRUM_MASK: 12}, self.assertTrue) # SampleParser result = user_file_parser.parse_line("SAMPLE /Offset 234.5") - assert_valid_result(result, {SampleId.offset: 234.5}, self.assertTrue) + assert_valid_result(result, {SampleId.OFFSET: 234.5}, self.assertTrue) # TransParser result = user_file_parser.parse_line("TRANS / radius =23") - assert_valid_result(result, {TransId.radius: 23}, self.assertTrue) + assert_valid_result(result, {TransId.RADIUS: 23}, self.assertTrue) # TubeCalibFileParser result = user_file_parser.parse_line("TUBECALIbfile= calib_file.nxs") - assert_valid_result(result, {TubeCalibrationFileId.file: "calib_file.nxs"}, self.assertTrue) + assert_valid_result(result, {TubeCalibrationFileId.FILE: "calib_file.nxs"}, self.assertTrue) # QResolutionParser result = user_file_parser.parse_line("QRESOL/ON") - assert_valid_result(result, {QResolutionId.on: True}, self.assertTrue) + assert_valid_result(result, {QResolutionId.ON: True}, self.assertTrue) # FitParser result = user_file_parser.parse_line("FIT/TRANS/Straight 123 3556") - assert_valid_result(result, {FitId.general: fit_general(start=123, stop=3556, fit_type=FitType.Linear, + assert_valid_result(result, {FitId.GENERAL: fit_general(start=123, stop=3556, fit_type=FitType.LINEAR, data_type=None, polynomial_order=0)}, self.assertTrue) # GravityParser result = user_file_parser.parse_line("Gravity/LExtra =23.5") - assert_valid_result(result, {GravityId.extra_length: 23.5}, self.assertTrue) + 
assert_valid_result(result, {GravityId.EXTRA_LENGTH: 23.5}, self.assertTrue) # MaskFileParser result = user_file_parser.parse_line("MaskFile= test.xml, testKsdk2.xml,tesetlskd.xml") - assert_valid_result(result, {MaskId.file: ["test.xml", "testKsdk2.xml", "tesetlskd.xml"]}, + assert_valid_result(result, {MaskId.FILE: ["test.xml", "testKsdk2.xml", "tesetlskd.xml"]}, self.assertTrue) # MonParser result = user_file_parser.parse_line("MON/length= 23.5 34") - assert_valid_result(result, {MonId.length: monitor_length(length=23.5, spectrum=34, interpolate=False)}, + assert_valid_result(result, {MonId.LENGTH: monitor_length(length=23.5, spectrum=34, interpolate=False)}, self.assertTrue) # PrintParser result = user_file_parser.parse_line("PRINT OdlfP slsk 23lksdl2 34l") - assert_valid_result(result, {PrintId.print_line: "OdlfP slsk 23lksdl2 34l"}, self.assertTrue) + assert_valid_result(result, {PrintId.PRINT_LINE: "OdlfP slsk 23lksdl2 34l"}, self.assertTrue) # BackParser result = user_file_parser.parse_line("BACK / M3 /OFF") - assert_valid_result(result, {BackId.monitor_off: 3}, self.assertTrue) + assert_valid_result(result, {BackId.MONITOR_OFF: 3}, self.assertTrue) # Instrument parser result = user_file_parser.parse_line("SANS2D") - assert_valid_result(result, {DetectorId.instrument: SANSInstrument.SANS2D}, self.assertTrue) + assert_valid_result(result, {DetectorId.INSTRUMENT: SANSInstrument.SANS2D}, self.assertTrue) # Instrument parser - whitespace result = user_file_parser.parse_line(" ZOOM ") - assert_valid_result(result, {DetectorId.instrument: SANSInstrument.ZOOM}, self.assertTrue) + assert_valid_result(result, {DetectorId.INSTRUMENT: SANSInstrument.ZOOM}, self.assertTrue) def test_that_non_existent_parser_throws(self): # Arrange diff --git a/scripts/test/SANS/user_file/user_file_reader_test.py b/scripts/test/SANS/user_file/user_file_reader_test.py index ab24655cac79f1a58ca55c26349be7fc97cad802..96dcd8ed88df576b13a47ed9f62c87587805694b 100644 --- 
a/scripts/test/SANS/user_file/user_file_reader_test.py +++ b/scripts/test/SANS/user_file/user_file_reader_test.py @@ -5,19 +5,19 @@ # & Institut Laue - Langevin # SPDX - License - Identifier: GPL - 3.0 + from __future__ import (absolute_import, division, print_function) -import unittest -import mantid + import os -from sans.common.enums import (ISISReductionMode, DetectorType, RangeStepType, FitType) -from sans.user_file.user_file_reader import UserFileReader +import unittest + +from sans.common.enums import (ReductionMode, DetectorType, RangeStepType, FitType) +from sans.test_helper.user_file_test_helper import create_user_file, sample_user_file from sans.user_file.settings_tags import (DetectorId, BackId, range_entry, back_single_monitor_entry, - single_entry_with_detector, mask_angle_entry, LimitsId, rebin_string_values, - simple_range, complex_range, MaskId, mask_block, mask_block_cross, - mask_line, range_entry_with_detector, SampleId, SetId, set_scales_entry, + single_entry_with_detector, LimitsId, simple_range, complex_range, MaskId, + range_entry_with_detector, SampleId, SetId, set_scales_entry, position_entry, TransId, TubeCalibrationFileId, QResolutionId, FitId, - fit_general, MonId, monitor_length, monitor_file, GravityId, + fit_general, MonId, monitor_file, GravityId, monitor_spectrum, PrintId, q_rebin_values) -from sans.test_helper.user_file_test_helper import create_user_file, sample_user_file +from sans.user_file.user_file_reader import UserFileReader # ----------------------------------------------------------------- @@ -33,63 +33,63 @@ class UserFileReaderTest(unittest.TestCase): output = reader.read_user_file() # Assert - expected_values = {LimitsId.wavelength: [simple_range(start=1.5, stop=12.5, step=0.125, - step_type=RangeStepType.Lin)], - LimitsId.q: [q_rebin_values(min=.001, max=.2, rebin_string="0.001,0.001,0.0126,-0.08,0.2")], - LimitsId.qxy: [simple_range(0, 0.05, 0.001, RangeStepType.Lin)], - BackId.single_monitors: 
[back_single_monitor_entry(1, 35000, 65000), + expected_values = {LimitsId.WAVELENGTH: [simple_range(start=1.5, stop=12.5, step=0.125, + step_type=RangeStepType.LIN)], + LimitsId.Q: [q_rebin_values(min=.001, max=.2, rebin_string="0.001,0.001,0.0126,-0.08,0.2")], + LimitsId.QXY: [simple_range(0, 0.05, 0.001, RangeStepType.LIN)], + BackId.SINGLE_MONITORS: [back_single_monitor_entry(1, 35000, 65000), back_single_monitor_entry(2, 85000, 98000)], - DetectorId.reduction_mode: [ISISReductionMode.LAB], - GravityId.on_off: [True], - FitId.general: [fit_general(start=1.5, stop=12.5, fit_type=FitType.Logarithmic, + DetectorId.REDUCTION_MODE: [ReductionMode.LAB], + GravityId.ON_OFF: [True], + FitId.GENERAL: [fit_general(start=1.5, stop=12.5, fit_type=FitType.LOGARITHMIC, data_type=None, polynomial_order=0)], - MaskId.vertical_single_strip_mask: [single_entry_with_detector(191, DetectorType.LAB), + MaskId.VERTICAL_SINGLE_STRIP_MASK: [single_entry_with_detector(191, DetectorType.LAB), single_entry_with_detector(191, DetectorType.HAB), single_entry_with_detector(0, DetectorType.LAB), single_entry_with_detector(0, DetectorType.HAB)], - MaskId.horizontal_single_strip_mask: [single_entry_with_detector(0, DetectorType.LAB), + MaskId.HORIZONTAL_SINGLE_STRIP_MASK: [single_entry_with_detector(0, DetectorType.LAB), single_entry_with_detector(0, DetectorType.HAB)], - MaskId.horizontal_range_strip_mask: [range_entry_with_detector(190, 191, DetectorType.LAB), + MaskId.HORIZONTAL_RANGE_STRIP_MASK: [range_entry_with_detector(190, 191, DetectorType.LAB), range_entry_with_detector(167, 172, DetectorType.LAB), range_entry_with_detector(190, 191, DetectorType.HAB), range_entry_with_detector(156, 159, DetectorType.HAB) ], - MaskId.time: [range_entry_with_detector(17500, 22000, None)], - MonId.direct: [monitor_file("DIRECTM1_15785_12m_31Oct12_v12.dat", DetectorType.LAB), + MaskId.TIME: [range_entry_with_detector(17500, 22000, None)], + MonId.DIRECT: 
[monitor_file("DIRECTM1_15785_12m_31Oct12_v12.dat", DetectorType.LAB), monitor_file("DIRECTM1_15785_12m_31Oct12_v12.dat", DetectorType.HAB)], - MonId.spectrum: [monitor_spectrum(1, True, True), monitor_spectrum(1, False, True)], - SetId.centre: [position_entry(155.45, -169.6, DetectorType.LAB)], - SetId.scales: [set_scales_entry(0.074, 1.0, 1.0, 1.0, 1.0)], - SampleId.offset: [53.0], - DetectorId.correction_x: [single_entry_with_detector(-16.0, DetectorType.LAB), + MonId.SPECTRUM: [monitor_spectrum(1, True, True), monitor_spectrum(1, False, True)], + SetId.CENTRE: [position_entry(155.45, -169.6, DetectorType.LAB)], + SetId.SCALES: [set_scales_entry(0.074, 1.0, 1.0, 1.0, 1.0)], + SampleId.OFFSET: [53.0], + DetectorId.CORRECTION_X: [single_entry_with_detector(-16.0, DetectorType.LAB), single_entry_with_detector(-44.0, DetectorType.HAB)], - DetectorId.correction_y: [single_entry_with_detector(-20.0, DetectorType.HAB)], - DetectorId.correction_z: [single_entry_with_detector(47.0, DetectorType.LAB), + DetectorId.CORRECTION_Y: [single_entry_with_detector(-20.0, DetectorType.HAB)], + DetectorId.CORRECTION_Z: [single_entry_with_detector(47.0, DetectorType.LAB), single_entry_with_detector(47.0, DetectorType.HAB)], - DetectorId.correction_rotation: [single_entry_with_detector(0.0, DetectorType.HAB)], - LimitsId.events_binning: ["7000.0,500.0,60000.0"], - MaskId.clear_detector_mask: [True], - MaskId.clear_time_mask: [True], - LimitsId.radius: [range_entry(12, 15)], - TransId.spec_4_shift: [-70.], - PrintId.print_line: ["for changer"], - BackId.all_monitors: [range_entry(start=3500, stop=4500)], - FitId.monitor_times: [range_entry(start=1000, stop=2000)], - TransId.spec: [4], - BackId.trans: [range_entry(start=123, stop=466)], - TransId.radius: [7.0], - TransId.roi: ["test.xml", "test2.xml"], - TransId.mask: ["test3.xml", "test4.xml"], - SampleId.path: [True], - LimitsId.radius_cut: [200.0], - LimitsId.wavelength_cut: [8.0], - QResolutionId.on: [True], - QResolutionId.delta_r: 
[11.], - QResolutionId.collimation_length: [12.], - QResolutionId.a1: [13.], - QResolutionId.a2: [14.], - QResolutionId.moderator: ["moderator_rkh_file.txt"], - TubeCalibrationFileId.file: ["TUBE_SANS2D_BOTH_31681_25Sept15.nxs"]} + DetectorId.CORRECTION_ROTATION: [single_entry_with_detector(0.0, DetectorType.HAB)], + LimitsId.EVENTS_BINNING: ["7000.0,500.0,60000.0"], + MaskId.CLEAR_DETECTOR_MASK: [True], + MaskId.CLEAR_TIME_MASK: [True], + LimitsId.RADIUS: [range_entry(12, 15)], + TransId.SPEC_4_SHIFT: [-70.], + PrintId.PRINT_LINE: ["for changer"], + BackId.ALL_MONITORS: [range_entry(start=3500, stop=4500)], + FitId.MONITOR_TIMES: [range_entry(start=1000, stop=2000)], + TransId.SPEC: [4], + BackId.TRANS: [range_entry(start=123, stop=466)], + TransId.RADIUS: [7.0], + TransId.ROI: ["test.xml", "test2.xml"], + TransId.MASK: ["test3.xml", "test4.xml"], + SampleId.PATH: [True], + LimitsId.RADIUS_CUT: [200.0], + LimitsId.WAVELENGTH_CUT: [8.0], + QResolutionId.ON: [True], + QResolutionId.DELTA_R: [11.], + QResolutionId.COLLIMATION_LENGTH: [12.], + QResolutionId.A1: [13.], + QResolutionId.A2: [14.], + QResolutionId.MODERATOR: ["moderator_rkh_file.txt"], + TubeCalibrationFileId.FILE: ["TUBE_SANS2D_BOTH_31681_25Sept15.nxs"]} self.assertEqual(len(expected_values), len(output)) for key, value in list(expected_values.items()): @@ -123,7 +123,7 @@ class UserFileReaderTest(unittest.TestCase): elif isinstance(elements[0], range_entry_with_detector): UserFileReaderTest._sort(elements, lambda x: x.start) elif isinstance(elements[0], monitor_file): - UserFileReaderTest._sort(elements, lambda x: (x.file_path, DetectorType.to_string(x.detector_type))) + UserFileReaderTest._sort(elements, lambda x: (x.file_path, x.detector_type.value)) elif isinstance(elements[0], monitor_spectrum): UserFileReaderTest._sort(elements, lambda x: x.spectrum) elif isinstance(elements[0], position_entry):