ReflectometryQuickAuxiliaryTest.py 5.39 KB
Newer Older
1
from __future__ import (absolute_import, division, print_function)
2
import unittest
3
import numpy
4
from mantid.simpleapi import *
5
from isis_reflectometry import quick
6
7

class ReflectometryQuickAuxiliaryTest(unittest.TestCase):
    """Exercise the helper functions and correction strategies in ``quick``.

    Every test runs against the ``POLREF00004699`` run, which ``setUp`` loads
    into the workspace named by ``__wsName`` and ``tearDown`` deletes again.
    The tests index that workspace (``mtd[name][0]``), i.e. they treat the
    loaded data as a workspace group.
    """

    # Name of the workspace loaded for each test; the real value is assigned
    # in __init__.
    __wsName = None

    def __init__(self, methodName='runTest'):
        """Set the fixture workspace name, then defer to ``unittest.TestCase``."""
        self.__wsName = "TestWorkspace"
        super(ReflectometryQuickAuxiliaryTest, self).__init__(methodName)

    def setUp(self):
        """Load the POLREF00004699 run into the fixture workspace."""
        LoadISISNexus(Filename='POLREF00004699', OutputWorkspace=self.__wsName)

    def tearDown(self):
        """Delete the fixture workspace loaded by ``setUp``."""
        DeleteWorkspace(mtd[self.__wsName])

    def test_cleanup(self):
        """``quick.cleanup`` drops '_toremove' but leaves 'tokeep' in place."""
        numObjectsOriginal = len(mtd.getObjectNames())
        CreateSingleValuedWorkspace(OutputWorkspace='_toremove', DataValue=1, ErrorValue=1)
        tokeep = CreateSingleValuedWorkspace(OutputWorkspace='tokeep', DataValue=1, ErrorValue=1)
        try:
            self.assertEqual(numObjectsOriginal + 2, len(mtd.getObjectNames()))

            # Should remove workspaces starting with _
            quick.cleanup()
            cleaned_object_names = mtd.getObjectNames()
            self.assertEqual(numObjectsOriginal + 1, len(cleaned_object_names))
            self.assertIn('tokeep', cleaned_object_names)
        finally:
            # Always tidy up, so a failed assertion cannot leave workspaces
            # behind that would skew the object counts of later tests.
            DeleteWorkspace(tokeep)
            if mtd.doesExist('_toremove'):
                DeleteWorkspace('_toremove')

    def test_groupGet_instrument(self):
        """``groupGet(..., 'inst')`` returns the POLREF instrument."""
        expectedInstrument = "POLREF"

        # Test with group workspace as input
        instrument = quick.groupGet(self.__wsName, 'inst')
        self.assertEqual(expectedInstrument, instrument.getName(),
                         "Did not fetch the instrument from ws group")

        # Test with single workspace as input
        instrument = quick.groupGet(mtd[self.__wsName][0].name(), 'inst')
        self.assertEqual(expectedInstrument, instrument.getName(),
                         "Did not fetch the instrument from ws")

    def test_groupGet_histogram_count(self):
        """``groupGet(..., 'wksp')`` returns the number of histograms."""
        expectedNHistograms = mtd[self.__wsName][0].getNumberHistograms()

        # Test with group workspace as input
        nHistograms = quick.groupGet(self.__wsName, 'wksp')
        self.assertEqual(expectedNHistograms, nHistograms,
                         "Did not fetch the n histograms from ws group")

        # Test with single workspace as input
        nHistograms = quick.groupGet(mtd[self.__wsName][0].name(), 'wksp')
        self.assertEqual(expectedNHistograms, nHistograms,
                         "Did not fetch the n histograms from ws")

    def test_groupGet_log_single_value(self):
        """``groupGet(..., 'samp', 'nperiods')`` returns 2 for this run."""
        expectedNPeriods = 2

        # Test with group workspace as input
        nPeriods = quick.groupGet(self.__wsName, 'samp', 'nperiods')
        self.assertEqual(expectedNPeriods, nPeriods,
                         "Did not fetch the number of periods from ws group")

        # Test with single workspace as input
        nPeriods = quick.groupGet(mtd[self.__wsName][0].name(), 'samp', 'nperiods')
        self.assertEqual(expectedNPeriods, nPeriods,
                         "Did not fetch the number of periods from ws")

    def test_groupGet_multi_value_log(self):
        """``groupGet(..., 'samp', 'stheta')`` yields 0.4903 to 4 decimal places."""
        # Expected start theta, taken from the last value of the time series log.
        expectedStartTheta = 0.4903

        # Test with group workspace as input
        stheta = quick.groupGet(self.__wsName, 'samp', 'stheta')
        self.assertEqual(expectedStartTheta, round(float(stheta), 4))

        # Test with single workspace as input
        stheta = quick.groupGet(mtd[self.__wsName][0].name(), 'samp', 'stheta')
        self.assertEqual(expectedStartTheta, round(float(stheta), 4))

    def test_groupGet_unknown_log_error_code(self):
        """``groupGet`` returns 0 when asked for a log name that is made up."""
        errorCode = 0

        # Test with group workspace as input
        self.assertEqual(errorCode, quick.groupGet(self.__wsName, 'samp', 'MADE-UP-LOG-NAME'))

        # Test with single workspace as input
        self.assertEqual(errorCode, quick.groupGet(mtd[self.__wsName][0].name(), 'samp', 'MADE-UP-LOG-NAME'))

    def _assert_correction_has_no_effect(self, correction):
        """Check ``correction`` is a CorrectionStrategy that leaves Y data unchanged.

        Applies ``correction`` to a small single-spectrum TOF workspace and
        compares the Y values of spectrum 0 before and after. Both temporary
        workspaces are deleted even when an assertion fails.
        """
        self.assertIsInstance(correction, quick.CorrectionStrategy,
                              msg="Should be of type Correction")

        # The variable name is kept as 'test_ws' on purpose: no OutputWorkspace
        # is passed, so the call is left exactly as the original tests wrote it.
        test_ws = CreateWorkspace(UnitX="TOF", DataX=[0, 1, 2, 3], DataY=[1, 1, 1], NSpec=1)
        corrected = None
        try:
            corrected = correction.apply(test_ws)
            self.assertTrue(all(test_ws.readY(0) == corrected.readY(0)),
                            msg="Input and outputs should be identical")
        finally:
            DeleteWorkspace(test_ws)
            if corrected is not None:
                DeleteWorkspace(corrected)

    def test_exponential_correction_strategy(self):
        """``ExponentialCorrectionStrategy(1, 0)`` should have no effect."""
        self._assert_correction_has_no_effect(quick.ExponentialCorrectionStrategy(1, 0))

    def test_polynomial_correction_strategy(self):
        """``PolynomialCorrectionStrategy("1, 0")`` should have no effect."""
        self._assert_correction_has_no_effect(quick.PolynomialCorrectionStrategy("1, 0"))

    def test_null_correction_strategy(self):
        """``NullCorrectionStrategy()`` should have no effect."""
        self._assert_correction_has_no_effect(quick.NullCorrectionStrategy())
131
132


133
if __name__ == '__main__':
    # Run every test in this module when the file is executed as a script.
    unittest.main()