Commit 4c25372d authored by Owen Arnold's avatar Owen Arnold
Browse files

refs #8520. Swap in new functionality and add tests.

Should now have no more force IDF parameter reads. You can do everything via the quick_explicit.
parent d6af85dd
......@@ -10,9 +10,9 @@
#from ReflectometerCors import *
from l2q import *
from combineMulti import *
#from mantidsimple import * # Old API
from mantid.simpleapi import * # New API
from mantid.api import WorkspaceGroup
from mantid.kernel import logger
from convert_to_wavelength import ConvertToWavelength
import math
import re
......@@ -32,14 +32,24 @@ class ExponentialCorrectionStrategy(CorrectionStrategy):
self.__c1 = c1
def apply(self, to_correct):
return ExponentialCorrection(InputWorkspace=to_correct,C0=self.__c0, C1= self.__c1, Operation='Divide')
logger.information("Exponential Correction")
corrected = ExponentialCorrection(InputWorkspace=to_correct,C0=self.__c0, C1= self.__c1, Operation='Divide')
return corrected
class PolynomialCorrectionStrategy(CorrectionStrategy):
def __init__(self, poly_string):
self.__poly_string = poly_string
def apply(self, to_correct):
return PolynomialCorrection(InputWorkspace=to_correct, Coefficients=self.__poly_string, Operation='Divide')
logger.information("Polynomial Correction")
corrected = PolynomialCorrection(InputWorkspace=to_correct, Coefficients=self.__poly_string, Operation='Divide')
return corrected
class NullCorrectionStrategy(CorrectionStrategy):
def apply(self, to_correct):
logger.information("Null Correction")
out = to_correct.clone()
return out
def quick(run, theta=0, pointdet=True,roi=[0,0], db=[0,0], trans='', polcorr=0, usemon=-1,outputType='pd', debug=False):
......@@ -60,15 +70,17 @@ def quick(run, theta=0, pointdet=True,roi=[0,0], db=[0,0], trans='', polcorr=0,
background_max = idf_defaults['MonitorBackgroundMax']
int_min = idf_defaults['MonitorIntegralMin']
int_max = idf_defaults['MonitorIntegralMax']
correction_strategy = idf_defaults['AlgoritmicCorrection']
return quick_explicit(run=run, i0_monitor_index = i0_monitor_index, lambda_min = lambda_min, lambda_max = lambda_max,
point_detector_start = point_detector_start, point_detector_stop = point_detector_stop,
multi_detector_start = multi_detector_start, background_min = background_min, background_max = background_max,
int_min = int_min, int_max = int_max, theta = theta, pointdet = pointdet, roi = roi, db = db, trans = trans, debug = debug )
int_min = int_min, int_max = int_max, theta = theta, pointdet = pointdet, roi = roi, db = db, trans = trans,
debug = debug, correction_strategy = correction_strategy )
def quick_explicit(run, i0_monitor_index, lambda_min, lambda_max, background_min, background_max, int_min, int_max,
point_detector_start=0, point_detector_stop=0, multi_detector_start=0, theta=0, pointdet=True,roi=[0,0], db=[0,0], trans='', debug=False):
point_detector_start=0, point_detector_stop=0, multi_detector_start=0, theta=0, pointdet=True,roi=[0,0], db=[0,0], trans='', debug=False, correction_strategy=NullCorrectionStrategy):
'''
Version of quick where all parameters are explicitly provided.
'''
......@@ -125,21 +137,7 @@ def quick_explicit(run, i0_monitor_index, lambda_min, lambda_max, background_mi
RunNumber = groupGet(IvsLam.getName(),'samp','run_number')
if (trans==''):
print "No transmission file. Trying default exponential/polynomial correction..."
inst=groupGet(_detector_ws.getName(),'inst')
corrType=inst.getStringParameter('correction')[0]
if (corrType=='polynomial'):
pString=inst.getStringParameter('polystring')
print pString
if len(pString):
IvsLam = PolynomialCorrection(InputWorkspace=_detector_ws,Coefficients=pString[0],Operation='Divide')
else:
print "No polynomial coefficients in IDF. Using monitor spectrum with no corrections."
elif (corrType=='exponential'):
c0=inst.getNumberParameter('C0')
c1=inst.getNumberParameter('C1')
print "Exponential parameters: ", c0[0], c1[0]
if len(c0):
IvsLam = ExponentialCorrection(InputWorkspace=_detector_ws,C0=c0[0],C1=c1[0],Operation='Divide')
IvsLam = correction_strategy.apply(_detector_ws)
IvsLam = Divide(LHSWorkspace=IvsLam, RHSWorkspace=_I0P)
else: # we have a transmission run
_monInt = Integration(InputWorkspace=_I0P,RangeLower=int_min,RangeUpper=int_max)
......@@ -277,6 +275,17 @@ def get_defaults(run_ws):
defaults['MultiDetectorStart'] = int( instrument.getNumberParameter('MultiDetectorStart')[0] )
defaults['I0MonitorIndex'] = int( instrument.getNumberParameter('I0MonitorIndex')[0] )
correction = NullCorrectionStrategy()
corrType=instrument.getStringParameter('correction')[0]
if corrType == 'polynomial':
poly_string = instrument.getStringParameter('polystring')[0]
correction = PolynomialCorrectionStrategy(poly_string)
elif corrType == 'exponential':
c0=instrument.getNumberParameter('C0')[0]
c1=instrument.getNumberParameter('C1')[0]
correction = ExponentialCorrectionStrategy(c0, c1)
defaults['AlgoritmicCorrection'] = correction
return defaults
def groupGet(wksp,whattoget,field=''):
......
......@@ -88,6 +88,45 @@ class ReflectometryQuickAuxiliaryTest(unittest.TestCase):
# Test with group workspace as input
self.assertEquals(errorCode, quick.groupGet(mtd[self.__wsName][0].name(), 'samp','MADE-UP-LOG-NAME'))
def test_exponential_correction_strategy(self):
test_ws = CreateWorkspace(UnitX="TOF", DataX=[0,1,2,3], DataY=[1,1,1], NSpec=1)
correction = quick.ExponentialCorrectionStrategy(1, 0) # Should have no effect.
self.assertTrue(isinstance(correction, quick.CorrectionStrategy), msg="Should be of type Correction")
corrected = correction.apply(test_ws)
self.assertTrue( all( test_ws.readY(0) == corrected.readY(0) ), msg="Input and outputs should be identical" )
DeleteWorkspace(test_ws)
DeleteWorkspace(corrected)
def test_polynomial_correction_strategy(self):
test_ws = CreateWorkspace(UnitX="TOF", DataX=[0,1,2,3], DataY=[1,1,1], NSpec=1)
correction = quick.PolynomialCorrectionStrategy("1, 0") # Should have no effect.
self.assertTrue(isinstance(correction, quick.CorrectionStrategy), msg="Should be of type Correction")
corrected = correction.apply(test_ws)
self.assertTrue( all( test_ws.readY(0) == corrected.readY(0) ), msg="Input and outputs should be identical" )
DeleteWorkspace(test_ws)
DeleteWorkspace(corrected)
def test_null_correction_strategy(self):
test_ws = CreateWorkspace(UnitX="TOF", DataX=[0,1,2,3], DataY=[1,1,1], NSpec=1)
correction = quick.NullCorrectionStrategy() # Should have no effect.
self.assertTrue(isinstance(correction, quick.CorrectionStrategy), msg="Should be of type Correction")
corrected = correction.apply(test_ws)
self.assertTrue( all( test_ws.readY(0) == corrected.readY(0) ), msg="Input and outputs should be identical" )
DeleteWorkspace(test_ws)
DeleteWorkspace(corrected)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment