Skip to content
Snippets Groups Projects
Commit d89eba08 authored by Keith Brown's avatar Keith Brown
Browse files

Merge remote-tracking branch 'origin/feature/9440_stitch1d_many'

parents 9faa5b72 c02bb4aa
No related merge requests found
#from mantid.simpleapi import *
from mantid.simpleapi import * from mantid.simpleapi import *
from mantid.api import * from mantid.api import *
from mantid.kernel import * from mantid.kernel import *
import numpy as np import numpy as np
class Stitch1DMany(PythonAlgorithm): class Stitch1DMany(DataProcessorAlgorithm):
def category(self): def category(self):
return "Reflectometry\\ISIS;PythonAlgorithms" return "Reflectometry\\ISIS;"
def name(self): def name(self):
return "Stitch1D" return "Stitch1D"
...@@ -20,13 +18,13 @@ class Stitch1DMany(PythonAlgorithm): ...@@ -20,13 +18,13 @@ class Stitch1DMany(PythonAlgorithm):
input_validator = StringMandatoryValidator() input_validator = StringMandatoryValidator()
self.declareProperty(name="InputWorkspaces", defaultValue="", direction=Direction.Input, validator=input_validator, doc="Input workspaces") self.declareProperty(name="InputWorkspaces", defaultValue="", direction=Direction.Input, validator=input_validator, doc="Input workspaces")
self.declareProperty(WorkspaceProperty("OutputWorkspace", "", Direction.Output), "Output stitched workspace") self.declareProperty(WorkspaceProperty("OutputWorkspace", "", Direction.Output), "Output stitched workspace")
self.declareProperty(FloatArrayProperty(name="StartOverlaps", values=[]), doc="Overlap in Q.") self.declareProperty(FloatArrayProperty(name="StartOverlaps", values=[]), doc="Start overlaps for stitched workspaces.")
self.declareProperty(FloatArrayProperty(name="EndOverlaps", values=[]), doc="End overlap in Q.") self.declareProperty(FloatArrayProperty(name="EndOverlaps", values=[]), doc="End overlap for stitched workspaces.")
self.declareProperty(FloatArrayProperty(name="Params", validator=FloatArrayMandatoryValidator()), doc="Rebinning Parameters. See Rebin for format.") self.declareProperty(FloatArrayProperty(name="Params", validator=FloatArrayMandatoryValidator()), doc="Rebinning Parameters. See Rebin for format.")
self.declareProperty(name="ScaleRHSWorkspace", defaultValue=True, doc="Scaling either with respect to workspace 1 or workspace 2.") self.declareProperty(name="ScaleRHSWorkspace", defaultValue=True, doc="Scaling either with respect to workspace 1 or workspace 2.")
self.declareProperty(name="UseManualScaleFactor", defaultValue=False, doc="True to use a provided value for the scale factor.") self.declareProperty(name="UseManualScaleFactor", defaultValue=False, doc="True to use a provided value for the scale factor.")
self.declareProperty(name="ManualScaleFactor", defaultValue=1.0, doc="Provided value for the scale factor.") self.declareProperty(name="ManualScaleFactor", defaultValue=1.0, doc="Provided value for the scale factor.")
self.declareProperty(name="OutScaleFactor", defaultValue=-2.0, direction = Direction.Output, doc="The actual used value for the scaling factor.") self.declareProperty(FloatArrayProperty(name="OutScaleFactors", direction = Direction.Output), doc="The actual used values for the scaling factors at each stitch step.")
def __workspace_from_split_name(self, list_of_names, index): def __workspace_from_split_name(self, list_of_names, index):
return mtd[list_of_names[index].strip()] return mtd[list_of_names[index].strip()]
...@@ -115,7 +113,7 @@ class Stitch1DMany(PythonAlgorithm): ...@@ -115,7 +113,7 @@ class Stitch1DMany(PythonAlgorithm):
raise ValueError("Wrong number of StartOverlaps, should be %i not %i" % (numberOfWorkspaces - 1, startOverlaps)) raise ValueError("Wrong number of StartOverlaps, should be %i not %i" % (numberOfWorkspaces - 1, startOverlaps))
self.__check_workspaces_are_common(inputWorkspaces) self.__check_workspaces_are_common(inputWorkspaces)
scaleFactor = None scaleFactors = list()
comma_separator = "," comma_separator = ","
no_separator = str() no_separator = str()
...@@ -142,10 +140,13 @@ class Stitch1DMany(PythonAlgorithm): ...@@ -142,10 +140,13 @@ class Stitch1DMany(PythonAlgorithm):
startOverlaps = self.getProperty("StartOverlaps").value startOverlaps = self.getProperty("StartOverlaps").value
endOverlaps = self.getProperty("EndOverlaps").value endOverlaps = self.getProperty("EndOverlaps").value
stitched, scaleFactor = Stitch1DMany(InputWorkspaces=to_process, OutputWorkspace=out_name, StartOverlaps=startOverlaps, EndOverlaps=endOverlaps, stitched, scaleFactorsOfIndex = Stitch1DMany(InputWorkspaces=to_process, OutputWorkspace=out_name, StartOverlaps=startOverlaps, EndOverlaps=endOverlaps,
Params=params, ScaleRHSWorkspace=scaleRHSWorkspace, UseManualScaleFactor=useManualScaleFactor, Params=params, ScaleRHSWorkspace=scaleRHSWorkspace, UseManualScaleFactor=useManualScaleFactor,
ManualScaleFactor=manualScaleFactor) ManualScaleFactor=manualScaleFactor)
# Flatten out scale factors.
for sf in scaleFactorsOfIndex:
scaleFactors.append(sf)
out_group_workspaces += out_group_separator + out_name out_group_workspaces += out_group_separator + out_name
out_group_separator = comma_separator out_group_separator = comma_separator
...@@ -162,6 +163,7 @@ class Stitch1DMany(PythonAlgorithm): ...@@ -162,6 +163,7 @@ class Stitch1DMany(PythonAlgorithm):
for i in range(1, numberOfWorkspaces, 1): for i in range(1, numberOfWorkspaces, 1):
rhsWS = self.__workspace_from_split_name(inputWorkspaces, i) rhsWS = self.__workspace_from_split_name(inputWorkspaces, i)
lhsWS, scaleFactor = self.__do_stitch_workspace(lhsWS, rhsWS, startOverlaps[i-1], endOverlaps[i-1], params, scaleRHSWorkspace, useManualScaleFactor, manualScaleFactor) lhsWS, scaleFactor = self.__do_stitch_workspace(lhsWS, rhsWS, startOverlaps[i-1], endOverlaps[i-1], params, scaleRHSWorkspace, useManualScaleFactor, manualScaleFactor)
scaleFactors.append(scaleFactor)
self.setProperty('OutputWorkspace', lhsWS) self.setProperty('OutputWorkspace', lhsWS)
# Iterate backwards through the workspaces. # Iterate backwards through the workspaces.
...@@ -170,9 +172,10 @@ class Stitch1DMany(PythonAlgorithm): ...@@ -170,9 +172,10 @@ class Stitch1DMany(PythonAlgorithm):
for i in range(0, numberOfWorkspaces-1, 1): for i in range(0, numberOfWorkspaces-1, 1):
lhsWS = self.__workspace_from_split_name(inputWorkspaces, i) lhsWS = self.__workspace_from_split_name(inputWorkspaces, i)
rhsWS, scaleFactor = Stitch1D(LHSWorkspace=lhsWS, RHSWorkspace=rhsWS, StartOverlap=startOverlaps[i-1], EndOverlap=endOverlaps[i-1], Params=params, ScaleRHSWorkspace=scaleRHSWorkspace, UseManualScaleFactor=useManualScaleFactor, ManualScaleFactor=manualScaleFactor) rhsWS, scaleFactor = Stitch1D(LHSWorkspace=lhsWS, RHSWorkspace=rhsWS, StartOverlap=startOverlaps[i-1], EndOverlap=endOverlaps[i-1], Params=params, ScaleRHSWorkspace=scaleRHSWorkspace, UseManualScaleFactor=useManualScaleFactor, ManualScaleFactor=manualScaleFactor)
scaleFactors.append(scaleFactor)
self.setProperty('OutputWorkspace', rhsWS) self.setProperty('OutputWorkspace', rhsWS)
self.setProperty('OutScaleFactor', scaleFactor) self.setProperty('OutScaleFactors', scaleFactors)
return None return None
......
...@@ -96,8 +96,10 @@ class Stitch1DManyTest(unittest.TestCase): ...@@ -96,8 +96,10 @@ class Stitch1DManyTest(unittest.TestCase):
def test_stitches_two(self): def test_stitches_two(self):
stitchedViaStitchMany, scaleFactorMany = Stitch1DMany(InputWorkspaces='a, b', StartOverlaps=[-0.4], EndOverlaps=[0.4], Params=[0.2]) stitchedViaStitchMany, scaleFactorMany = Stitch1DMany(InputWorkspaces='a, b', StartOverlaps=[-0.4], EndOverlaps=[0.4], Params=[0.2])
stitchedViaStitchTwo, scaleFactorTwo = Stitch1D(LHSWorkspace=self.a, RHSWorkspace=self.b, StartOverlap=-0.4, EndOverlap=0.4, Params=[0.2]) stitchedViaStitchTwo, scaleFactorTwo = Stitch1D(LHSWorkspace=self.a, RHSWorkspace=self.b, StartOverlap=-0.4, EndOverlap=0.4, Params=[0.2])
self.assertEquals(scaleFactorTwo, scaleFactorMany) self.assertTrue(isinstance(scaleFactorMany, numpy.ndarray), "Should be returning a list of scale factors")
self.assertEqual(1, scaleFactorMany.size)
self.assertEquals(scaleFactorTwo, scaleFactorMany[0])
expectedYData = [0,0,0,3,3,3,3,0,0,0] expectedYData = [0,0,0,3,3,3,3,0,0,0]
self.do_check_ydata(expectedYData, stitchedViaStitchMany) self.do_check_ydata(expectedYData, stitchedViaStitchMany)
...@@ -113,11 +115,13 @@ class Stitch1DManyTest(unittest.TestCase): ...@@ -113,11 +115,13 @@ class Stitch1DManyTest(unittest.TestCase):
ws1 = CreateWorkspace(UnitX="1/q", DataX=self.x, DataY=[3.0, 3.0, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], NSpec=1, DataE=self.e) ws1 = CreateWorkspace(UnitX="1/q", DataX=self.x, DataY=[3.0, 3.0, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], NSpec=1, DataE=self.e)
ws2 = CreateWorkspace(UnitX="1/q", DataX=self.x, DataY=[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0], NSpec=1, DataE=self.e) ws2 = CreateWorkspace(UnitX="1/q", DataX=self.x, DataY=[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0], NSpec=1, DataE=self.e)
ws3 = CreateWorkspace(UnitX="1/q", DataX=self.x, DataY=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], NSpec=1, DataE=self.e) ws3 = CreateWorkspace(UnitX="1/q", DataX=self.x, DataY=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], NSpec=1, DataE=self.e)
stitchedViaStitchMany, sf = Stitch1DMany(InputWorkspaces='ws1, ws2, ws3', StartOverlaps=[-0.4,0.2], EndOverlaps=[-0.2,0.4], Params=0.2) stitchedViaStitchMany, sfs = Stitch1DMany(InputWorkspaces='ws1, ws2, ws3', StartOverlaps=[-0.4,0.2], EndOverlaps=[-0.2,0.4], Params=0.2)
expectedYData = [3,3,3,3,3,3,3,3,3,3] expectedYData = [3,3,3,3,3,3,3,3,3,3]
self.do_check_ydata(expectedYData, stitchedViaStitchMany) self.do_check_ydata(expectedYData, stitchedViaStitchMany)
self.assertEquals(3.0, round(sf, 6)) self.assertEqual(2, sfs.size)
self.assertEquals(1.5, round(sfs[0], 6))
self.assertEquals(3.0, round(sfs[1], 6))
DeleteWorkspace(ws1) DeleteWorkspace(ws1)
DeleteWorkspace(ws2) DeleteWorkspace(ws2)
...@@ -143,10 +147,11 @@ class Stitch1DManyTest(unittest.TestCase): ...@@ -143,10 +147,11 @@ class Stitch1DManyTest(unittest.TestCase):
ws3 = CreateWorkspace(UnitX="1/q", DataX=self.x, DataY=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], NSpec=1, DataE=self.e) ws3 = CreateWorkspace(UnitX="1/q", DataX=self.x, DataY=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], NSpec=1, DataE=self.e)
input_group_1 = GroupWorkspaces(InputWorkspaces="%s,%s,%s" % (ws1.name(), ws2.name(), ws3.name())) input_group_1 = GroupWorkspaces(InputWorkspaces="%s,%s,%s" % (ws1.name(), ws2.name(), ws3.name()))
input_group_2 = GroupWorkspaces(InputWorkspaces="%s,%s,%s" % (ws1.name(), ws2.name(), ws3.name())) input_group_2 = GroupWorkspaces(InputWorkspaces="%s,%s,%s" % (ws1.name(), ws2.name(), ws3.name()))
stitched, sf = Stitch1DMany(InputWorkspaces='%s,%s' % (input_group_1.name(), input_group_2.name()), Params=0.2) stitched, sfs = Stitch1DMany(InputWorkspaces='%s,%s' % (input_group_1.name(), input_group_2.name()), Params=0.2)
self.assertTrue(isinstance(stitched, WorkspaceGroup), "Output should be a group workspace") self.assertTrue(isinstance(stitched, WorkspaceGroup), "Output should be a group workspace")
self.assertEqual(stitched.size(), 3, "Output should contain 3 workspaces") self.assertEqual(stitched.size(), 3, "Output should contain 3 workspaces")
self.assertEqual(stitched.name(), "stitched", "Output not named correctly") self.assertEqual(stitched.name(), "stitched", "Output not named correctly")
self.assertEquals(input_group_1.size(), sfs.size)
DeleteWorkspace(input_group_1) DeleteWorkspace(input_group_1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment