Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
mantidproject
mantid
Commits
7c738fc1
Unverified
Commit
7c738fc1
authored
Sep 22, 2021
by
Gigg, Martyn Anthony
Committed by
GitHub
Sep 22, 2021
Browse files
Merge pull request #32489 from gemmaguest/support_kwargs_in_createChildAlgorithm
Support kwargs in create child algorithm
parents
4730bd33
cce097cb
Changes
9
Hide whitespace changes
Inline
Side-by-side
Framework/PythonInterface/core/inc/MantidPythonInterface/core/Converters/PyNativeTypeExtractor.h
0 → 100644
View file @
7c738fc1
// Mantid Repository : https://github.com/mantidproject/mantid
//
// Copyright © 2021 ISIS Rutherford Appleton Laboratory UKRI,
// NScD Oak Ridge National Laboratory, European Spallation Source,
// Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
// SPDX - License - Identifier: GPL - 3.0 +
#pragma once
#include
"MantidAPI/Workspace.h"
#include
"MantidKernel/Logger.h"
#include
<boost/python/extract.hpp>
#include
<boost/python/list.hpp>
#include
<boost/python/object.hpp>
#include
<boost/variant.hpp>
#include
<exception>
#include
<string>
#include
<variant>
#include
<vector>
namespace
{
Mantid
::
Kernel
::
Logger
g_log
(
"Python Type Extractor"
);
}
namespace
Mantid
::
PythonInterface
{
struct
PyNativeTypeExtractor
{
using
PythonOutputT
=
boost
::
make_recursive_variant
<
bool
,
long
,
double
,
std
::
string
,
Mantid
::
API
::
Workspace_sptr
,
std
::
vector
<
boost
::
recursive_variant_
>>::
type
;
static
PythonOutputT
convert
(
const
boost
::
python
::
object
&
obj
)
{
using
namespace
boost
::
python
;
PyObject
*
rawptr
=
obj
.
ptr
();
PythonOutputT
out
;
if
(
PyList_Check
(
rawptr
))
{
out
=
handleList
(
obj
);
}
else
if
(
PyBool_Check
(
rawptr
))
{
out
=
extract
<
bool
>
(
obj
);
}
else
if
(
PyFloat_Check
(
rawptr
))
{
out
=
extract
<
double
>
(
obj
);
}
else
if
(
PyLong_Check
(
rawptr
))
{
out
=
extract
<
long
>
(
obj
);
}
else
if
(
PyUnicode_Check
(
rawptr
))
{
out
=
extract
<
std
::
string
>
(
obj
);
}
else
if
(
auto
extractor
=
extract
<
Mantid
::
API
::
Workspace_sptr
>
(
obj
);
extractor
.
check
())
{
out
=
extractor
();
}
else
{
throw
std
::
invalid_argument
(
"Unrecognised Python type"
);
}
return
out
;
}
private:
static
PythonOutputT
handleList
(
const
boost
::
python
::
object
&
obj
)
{
auto
rawptr
=
obj
.
ptr
();
auto
n
=
PyList_Size
(
rawptr
);
auto
vec
=
std
::
vector
<
PythonOutputT
>
();
vec
.
reserve
(
n
);
for
(
Py_ssize_t
i
=
0
;
i
<
n
;
++
i
)
{
vec
.
emplace_back
(
convert
(
obj
[
i
]));
}
return
vec
;
}
};
class
IPyTypeVisitor
:
public
boost
::
static_visitor
<>
{
public:
/**
* Dynamically dispatches to overloaded operator depending on the underlying type.
* This also handles cases such as std::vector<T>, or nested lists, which will be flattened
* by invoking the operator each time for each elem.
* This assumes all element in a list matches the first element type found in that list,
* where this is not true, the element will be cast to the first element type.
*
* Note: you will need to include
* using Mantid::PythonInterface::IPyTypeVisitor::operator();
*/
virtual
~
IPyTypeVisitor
()
=
default
;
virtual
void
operator
()(
bool
value
)
const
=
0
;
virtual
void
operator
()(
long
value
)
const
=
0
;
virtual
void
operator
()(
double
value
)
const
=
0
;
virtual
void
operator
()(
std
::
string
)
const
=
0
;
virtual
void
operator
()(
Mantid
::
API
::
Workspace_sptr
)
const
=
0
;
virtual
void
operator
()(
std
::
vector
<
bool
>
)
const
=
0
;
virtual
void
operator
()(
std
::
vector
<
long
>
)
const
=
0
;
virtual
void
operator
()(
std
::
vector
<
double
>
)
const
=
0
;
virtual
void
operator
()(
std
::
vector
<
std
::
string
>
)
const
=
0
;
void
operator
()(
std
::
vector
<
PyNativeTypeExtractor
::
PythonOutputT
>
const
&
values
)
const
{
if
(
values
.
size
()
==
0
)
return
;
const
auto
&
elemType
=
values
[
0
].
type
();
// We must manually dispatch for container types, as boost will try
// to recurse down to scalar values.
if
(
elemType
==
typeid
(
bool
))
{
applyVectorProp
<
bool
>
(
values
);
}
else
if
(
elemType
==
typeid
(
double
))
{
applyVectorProp
<
double
>
(
values
);
}
else
if
(
elemType
==
typeid
(
long
))
{
applyVectorProp
<
long
>
(
values
);
}
else
if
(
elemType
==
typeid
(
std
::
string
))
{
applyVectorProp
<
std
::
string
>
(
values
);
}
else
{
// Recurse down
for
(
const
auto
&
val
:
values
)
{
boost
::
apply_visitor
(
*
this
,
val
);
}
}
}
private:
template
<
typename
ScalarT
>
void
applyVectorProp
(
const
std
::
vector
<
Mantid
::
PythonInterface
::
PyNativeTypeExtractor
::
PythonOutputT
>
&
values
)
const
{
std
::
vector
<
ScalarT
>
propVals
;
propVals
.
reserve
(
values
.
size
());
// Explicitly copy so we don't have to think about Python lifetimes with refs
try
{
std
::
transform
(
values
.
cbegin
(),
values
.
cend
(),
std
::
back_inserter
(
propVals
),
[](
const
Mantid
::
PythonInterface
::
PyNativeTypeExtractor
::
PythonOutputT
&
varadicVal
)
{
return
boost
::
get
<
ScalarT
>
(
varadicVal
);
});
}
catch
(
boost
::
bad_get
&
e
)
{
std
::
string
err
{
"A list with mixed types is unsupported as precision loss can occur trying to determine a common type."
"
\n
Original exception: "
};
// Boost will convert bad_get into runtime_error anyway....
throw
std
::
runtime_error
(
err
+
e
.
what
());
}
this
->
operator
()(
std
::
move
(
propVals
));
}
};
}
// namespace Mantid::PythonInterface
Framework/PythonInterface/mantid/api/src/Exports/Algorithm.cpp
View file @
7c738fc1
...
...
@@ -17,14 +17,35 @@
#ifdef _MSC_VER
#pragma warning(default : 4250)
#endif
#include
"MantidPythonInterface/core/Converters/PyNativeTypeExtractor.h"
#include
"MantidPythonInterface/core/GetPointer.h"
#include
<boost/optional.hpp>
#include
<boost/python/bases.hpp>
#include
<boost/python/class.hpp>
#include
<boost/python/dict.hpp>
#include
<boost/python/exception_translator.hpp>
#include
<boost/python/overloads.hpp>
// As of boost 1.67 raw_function.hpp tries to pass
// through size_t types through to make_function
// which accepts int type, emitting a warning
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4267)
#endif
#include
<boost/python/raw_function.hpp>
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#include
<boost/python/register_ptr_to_python.hpp>
#include
<boost/python/scope.hpp>
#include
<boost/variant.hpp>
#include
<boost/variant/static_visitor.hpp>
#include
<cstddef>
#include
<string>
#include
<variant>
#include
<vector>
using
Mantid
::
API
::
Algorithm
;
using
Mantid
::
API
::
DistributedAlgorithm
;
...
...
@@ -56,6 +77,7 @@ using declarePropertyType3 = void (*)(boost::python::object &, const std::string
// declarePyAlgProperty(name, defaultValue, direction)
using
declarePropertyType4
=
void
(
*
)(
boost
::
python
::
object
&
,
const
std
::
string
&
,
const
boost
::
python
::
object
&
,
const
int
);
GNU_DIAG_OFF
(
"unused-local-typedef"
)
// Ignore -Wconversion warnings coming from boost::python
// Seen with GCC 7.1.1 and Boost 1.63.0
...
...
@@ -64,6 +86,7 @@ GNU_DIAG_OFF("conversion")
BOOST_PYTHON_FUNCTION_OVERLOADS
(
declarePropertyType1_Overload
,
PythonAlgorithm
::
declarePyAlgProperty
,
2
,
3
)
BOOST_PYTHON_FUNCTION_OVERLOADS
(
declarePropertyType2_Overload
,
PythonAlgorithm
::
declarePyAlgProperty
,
3
,
6
)
BOOST_PYTHON_FUNCTION_OVERLOADS
(
declarePropertyType3_Overload
,
PythonAlgorithm
::
declarePyAlgProperty
,
4
,
5
)
GNU_DIAG_ON
(
"conversion"
)
GNU_DIAG_ON
(
"unused-local-typedef"
)
...
...
@@ -76,6 +99,93 @@ void translateCancel(const Algorithm::CancelException &exc) {
UNUSED_ARG
(
exc
);
PyErr_SetString
(
PyExc_KeyboardInterrupt
,
""
);
}
template
<
typename
T
>
boost
::
optional
<
T
>
extractArg
(
ssize_t
index
,
const
tuple
&
args
)
{
if
(
index
<
len
(
args
))
{
return
boost
::
optional
<
T
>
(
extract
<
T
>
(
args
[
index
]));
}
return
boost
::
none
;
}
template
<
typename
T
>
void
extractKwargs
(
const
dict
&
kwargs
,
const
std
::
string
&
keyName
,
boost
::
optional
<
T
>
&
out
)
{
if
(
!
kwargs
.
has_key
(
keyName
))
{
return
;
}
if
(
out
!=
boost
::
none
)
{
throw
std
::
invalid_argument
(
"Parameter called '"
+
keyName
+
"' was specified twice."
" This must be either positional or a kwarg, but not both."
);
}
out
=
boost
::
optional
<
T
>
(
extract
<
T
>
(
kwargs
.
get
(
keyName
)));
}
class
SetPropertyVisitor
final
:
public
Mantid
::
PythonInterface
::
IPyTypeVisitor
{
public:
SetPropertyVisitor
(
Mantid
::
API
::
Algorithm_sptr
&
alg
,
std
::
string
const
&
propName
)
:
m_alg
(
alg
),
m_propName
(
propName
)
{}
void
operator
()(
bool
value
)
const
override
{
setProp
(
value
);
}
void
operator
()(
long
value
)
const
override
{
setProp
(
static_cast
<
int
>
(
value
));
}
void
operator
()(
double
value
)
const
override
{
setProp
(
value
);
}
void
operator
()(
std
::
string
value
)
const
override
{
m_alg
->
setPropertyValue
(
m_propName
,
value
);
}
void
operator
()(
Mantid
::
API
::
Workspace_sptr
ws
)
const
override
{
m_alg
->
setProperty
(
m_propName
,
std
::
move
(
ws
));
}
void
operator
()(
std
::
vector
<
bool
>
value
)
const
override
{
setProp
(
value
);
}
void
operator
()(
std
::
vector
<
long
>
value
)
const
override
{
setProp
(
value
);
}
void
operator
()(
std
::
vector
<
double
>
value
)
const
override
{
setProp
(
value
);
}
void
operator
()(
std
::
vector
<
std
::
string
>
value
)
const
override
{
setProp
(
value
);
}
using
Mantid
::
PythonInterface
::
IPyTypeVisitor
::
operator
();
private:
template
<
typename
T
>
void
setProp
(
const
T
&
val
)
const
{
m_alg
->
setProperty
(
m_propName
,
val
);
}
Mantid
::
API
::
Algorithm_sptr
&
m_alg
;
std
::
string
const
&
m_propName
;
};
// Signature createChildWithProps(self, name, startProgress, endProgress, enableLogging, version, **kwargs)
object
createChildWithProps
(
tuple
args
,
dict
kwargs
)
{
Mantid
::
API
::
Algorithm_sptr
parentAlg
=
extract
<
Mantid
::
API
::
Algorithm_sptr
>
(
args
[
0
]);
auto
name
=
extractArg
<
std
::
string
>
(
1
,
args
);
auto
startProgress
=
extractArg
<
double
>
(
2
,
args
);
auto
endProgress
=
extractArg
<
double
>
(
3
,
args
);
auto
enableLogging
=
extractArg
<
bool
>
(
4
,
args
);
auto
version
=
extractArg
<
int
>
(
5
,
args
);
const
std
::
array
<
std
::
string
,
5
>
reservedNames
=
{
"name"
,
"startProgress"
,
"endProgress"
,
"enableLogging"
,
"version"
};
extractKwargs
<
std
::
string
>
(
kwargs
,
reservedNames
[
0
],
name
);
extractKwargs
<
double
>
(
kwargs
,
reservedNames
[
1
],
startProgress
);
extractKwargs
<
double
>
(
kwargs
,
reservedNames
[
2
],
endProgress
);
extractKwargs
<
bool
>
(
kwargs
,
reservedNames
[
3
],
enableLogging
);
extractKwargs
<
int
>
(
kwargs
,
reservedNames
[
4
],
version
);
if
(
!
name
.
is_initialized
())
{
throw
std
::
invalid_argument
(
"Please specify the algorithm name"
);
}
auto
childAlg
=
parentAlg
->
createChildAlgorithm
(
name
.
value
(),
startProgress
.
value_or
(
-
1
),
endProgress
.
value_or
(
-
1
),
enableLogging
.
value_or
(
true
),
version
.
value_or
(
-
1
));
const
list
keys
=
kwargs
.
keys
();
for
(
int
i
=
0
;
i
<
len
(
keys
);
++
i
)
{
const
std
::
string
propName
=
extract
<
std
::
string
>
(
keys
[
i
]);
if
(
std
::
find
(
reservedNames
.
cbegin
(),
reservedNames
.
cend
(),
propName
)
!=
reservedNames
.
cend
())
continue
;
object
curArg
=
kwargs
[
keys
[
i
]];
if
(
!
curArg
)
continue
;
using
Mantid
::
PythonInterface
::
PyNativeTypeExtractor
;
auto
nativeObj
=
PyNativeTypeExtractor
::
convert
(
curArg
);
boost
::
apply_visitor
(
SetPropertyVisitor
(
childAlg
,
propName
),
nativeObj
);
}
return
object
(
childAlg
);
}
}
// namespace
void
export_leaf_classes
()
{
...
...
@@ -86,14 +196,11 @@ void export_leaf_classes() {
// std::shared_ptr<AlgorithmAdapter>
// See
// http://wiki.python.org/moin/boost.python/HowTo#ownership_of_C.2B-.2B-_object_extended_in_Python
class_
<
Algorithm
,
bases
<
Mantid
::
API
::
IAlgorithm
>
,
std
::
shared_ptr
<
PythonAlgorithm
>
,
boost
::
noncopyable
>
(
"Algorithm"
,
"Base class for all algorithms"
)
.
def
(
"fromString"
,
&
Algorithm
::
fromString
,
"Initialize the algorithm from a string representation"
)
.
staticmethod
(
"fromString"
)
.
def
(
"createChildAlgorithm"
,
&
Algorithm
::
createChildAlgorithm
,
(
arg
(
"self"
),
arg
(
"name"
),
arg
(
"startProgress"
)
=
-
1.0
,
arg
(
"endProgress"
)
=
-
1.0
,
arg
(
"enableLogging"
)
=
true
,
arg
(
"version"
)
=
-
1
),
.
def
(
"createChildAlgorithm"
,
raw_function
(
&
createChildWithProps
,
std
::
size_t
(
1
)),
"Creates and intializes a named child algorithm. Output workspaces "
"are given a dummy name."
)
.
def
(
"declareProperty"
,
(
declarePropertyType1
)
&
PythonAlgorithm
::
declarePyAlgProperty
,
...
...
Framework/PythonInterface/mantid/api/src/FitFunctions/IFunctionAdapter.cpp
View file @
7c738fc1
...
...
@@ -6,10 +6,12 @@
// SPDX - License - Identifier: GPL - 3.0 +
#include
"MantidPythonInterface/api/FitFunctions/IFunctionAdapter.h"
#include
"MantidPythonInterface/core/CallMethod.h"
#include
"MantidPythonInterface/core/Converters/PyNativeTypeExtractor.h"
#include
"MantidPythonInterface/core/Converters/WrapWithNDArray.h"
#include
<boost/python/class.hpp>
#include
<boost/python/list.hpp>
#include
<boost/variant/apply_visitor.hpp>
#include
<utility>
#define PY_ARRAY_UNIQUE_SYMBOL API_ARRAY_API
...
...
@@ -25,6 +27,34 @@ using namespace boost::python;
namespace
{
class
AttrVisitor
:
Mantid
::
PythonInterface
::
IPyTypeVisitor
{
public:
AttrVisitor
(
IFunction
::
Attribute
&
attrToUpdate
)
:
m_attr
(
attrToUpdate
)
{}
void
operator
()(
bool
value
)
const
override
{
m_attr
.
setValue
(
value
);
}
void
operator
()(
long
value
)
const
override
{
m_attr
.
setValue
(
static_cast
<
int
>
(
value
));
}
void
operator
()(
double
value
)
const
override
{
m_attr
.
setValue
(
value
);
}
void
operator
()(
std
::
string
value
)
const
override
{
m_attr
.
setValue
(
std
::
move
(
value
));
}
void
operator
()(
Mantid
::
API
::
Workspace_sptr
)
const
override
{
throw
std
::
invalid_argument
(
m_errorMsg
);
}
void
operator
()(
std
::
vector
<
bool
>
)
const
override
{
throw
std
::
invalid_argument
(
m_errorMsg
);
}
void
operator
()(
std
::
vector
<
long
>
value
)
const
override
{
// Previous existing code blindly converted any list type into a list of doubles.
// We now have to preserve this behaviour to maintain API compatibility as
// setValue only takes std::vector<double>.
std
::
vector
<
double
>
doubleVals
(
value
.
cbegin
(),
value
.
cend
());
m_attr
.
setValue
(
std
::
move
(
doubleVals
));
}
void
operator
()(
std
::
vector
<
double
>
value
)
const
override
{
m_attr
.
setValue
(
std
::
move
(
value
));
}
void
operator
()(
std
::
vector
<
std
::
string
>
)
const
override
{
throw
std
::
invalid_argument
(
m_errorMsg
);
}
using
Mantid
::
PythonInterface
::
IPyTypeVisitor
::
operator
();
private:
IFunction
::
Attribute
&
m_attr
;
const
std
::
string
m_errorMsg
=
"Invalid attribute. Allowed types=float,int,str,bool,list(float),list(int)"
;
};
/**
* Create an Attribute from a python value.
* @param value :: A value python object. Allowed python types:
...
...
@@ -33,38 +63,10 @@ namespace {
*/
IFunction
::
Attribute
createAttributeFromPythonValue
(
IFunction
::
Attribute
attrToUpdate
,
const
object
&
value
)
{
PyObject
*
rawptr
=
value
.
ptr
();
using
Mantid
::
PythonInterface
::
PyNativeTypeExtractor
;
auto
variantObj
=
PyNativeTypeExtractor
::
convert
(
value
);
if
(
PyBool_Check
(
rawptr
)
==
1
)
{
attrToUpdate
.
setValue
(
extract
<
bool
>
(
rawptr
)());
}
#if PY_MAJOR_VERSION >= 3
else
if
(
PyLong_Check
(
rawptr
)
==
1
)
{
#else
else
if
(
PyInt_Check
(
rawptr
)
==
1
)
{
#endif
attrToUpdate
.
setValue
(
extract
<
int
>
(
rawptr
)());
}
else
if
(
PyFloat_Check
(
rawptr
)
==
1
)
{
attrToUpdate
.
setValue
(
extract
<
double
>
(
rawptr
)());
}
#if PY_MAJOR_VERSION >= 3
else
if
(
PyUnicode_Check
(
rawptr
)
==
1
)
{
#else
else
if
(
PyBytes_Check
(
rawptr
)
==
1
)
{
#endif
attrToUpdate
.
setValue
(
extract
<
std
::
string
>
(
rawptr
)());
}
else
if
(
PyList_Check
(
rawptr
)
==
1
)
{
auto
n
=
PyList_Size
(
rawptr
);
std
::
vector
<
double
>
vec
;
for
(
Py_ssize_t
i
=
0
;
i
<
n
;
++
i
)
{
auto
v
=
extract
<
double
>
(
PyList_GetItem
(
rawptr
,
i
))();
vec
.
emplace_back
(
v
);
}
attrToUpdate
.
setValue
(
vec
);
}
else
{
throw
std
::
invalid_argument
(
"Invalid attribute type. Allowed "
"types=float,int,str,bool,list(float)"
);
}
boost
::
apply_visitor
(
AttrVisitor
(
attrToUpdate
),
variantObj
);
return
attrToUpdate
;
}
...
...
Framework/PythonInterface/test/python/mantid/api/AlgorithmTest.py
View file @
7c738fc1
...
...
@@ -7,16 +7,34 @@
import
unittest
import
json
from
mantid.api
import
AlgorithmID
,
AlgorithmManager
,
FrameworkManagerImpl
from
mantid.kernel
import
Direction
,
FloatArrayProperty
,
IntArrayProperty
,
StringArrayProperty
,
\
IntArrayMandatoryValidator
,
FloatArrayMandatoryValidator
,
StringArrayMandatoryValidator
from
mantid.simpleapi
import
CreateSampleWorkspace
from
mantid.api
import
AlgorithmID
,
AlgorithmManager
,
AlgorithmFactory
,
FrameworkManagerImpl
,
PythonAlgorithm
,
Workspace
from
testhelpers
import
run_algorithm
class
AlgorithmTest
(
unittest
.
TestCase
):
class
_ParamTester
(
PythonAlgorithm
):
def
category
(
self
):
return
"Examples"
def
PyInit
(
self
):
self
.
declareProperty
(
FloatArrayProperty
(
"FloatInput"
,
FloatArrayMandatoryValidator
()))
self
.
declareProperty
(
IntArrayProperty
(
"IntInput"
,
IntArrayMandatoryValidator
()))
self
.
declareProperty
(
StringArrayProperty
(
"StringInput"
,
StringArrayMandatoryValidator
()))
def
PyExec
(
self
):
pass
class
AlgorithmTest
(
unittest
.
TestCase
):
_load
=
None
def
setUp
(
self
):
FrameworkManagerImpl
.
Instance
()
self
.
_alg_factory
=
AlgorithmFactory
.
Instance
()
self
.
_alg_factory
.
subscribe
(
_ParamTester
)
if
self
.
_load
is
None
:
self
.
__class__
.
_load
=
AlgorithmManager
.
createUnmanaged
(
'Load'
)
self
.
_load
.
initialize
()
...
...
@@ -48,42 +66,41 @@ class AlgorithmTest(unittest.TestCase):
self
.
assertRaises
(
RuntimeError
,
alg
.
execute
)
def
test_execute_succeeds_with_valid_props
(
self
):
data
=
[
1.5
,
2.5
,
3.5
]
alg
=
run_algorithm
(
'CreateWorkspace'
,
DataX
=
data
,
DataY
=
data
,
NSpec
=
1
,
UnitX
=
'Wavelength'
,
child
=
True
)
data
=
[
1.5
,
2.5
,
3.5
]
alg
=
run_algorithm
(
'CreateWorkspace'
,
DataX
=
data
,
DataY
=
data
,
NSpec
=
1
,
UnitX
=
'Wavelength'
,
child
=
True
)
self
.
assertEqual
(
alg
.
isExecuted
(),
True
)
self
.
assertEqual
(
alg
.
isRunning
(),
False
)
self
.
assertEqual
(
alg
.
getProperty
(
'NSpec'
).
value
,
1
)
self
.
assertEqual
(
type
(
alg
.
getProperty
(
'NSpec'
).
value
),
int
)
self
.
assertEqual
(
alg
.
getProperty
(
'NSpec'
).
name
,
'NSpec'
)
ws
=
alg
.
getProperty
(
'OutputWorkspace'
).
value
self
.
assertTrue
(
ws
.
getMemorySize
()
>
0.0
)
self
.
assertTrue
(
ws
.
getMemorySize
()
>
0.0
)
as_str
=
str
(
alg
)
self
.
assertEqual
(
as_str
,
'{"name":"CreateWorkspace","properties":{"DataX":[1.5,2.5,3.5],"DataY":[1.5,2.5,3.5],'
'"OutputWorkspace":"UNUSED_NAME_FOR_CHILD","UnitX":"Wavelength"},"version":1}'
)
'"OutputWorkspace":"UNUSED_NAME_FOR_CHILD","UnitX":"Wavelength"},"version":1}'
)
def
test_execute_succeeds_with_unicode_props
(
self
):
data
=
[
1.5
,
2.5
,
3.5
]
kwargs
=
{
'child'
:
True
}
data
=
[
1.5
,
2.5
,
3.5
]
kwargs
=
{
'child'
:
True
}
unitx
=
'Wavelength'
kwargs
[
'UnitX'
]
=
unitx
alg
=
run_algorithm
(
'CreateWorkspace'
,
DataX
=
data
,
DataY
=
data
,
NSpec
=
1
,
**
kwargs
)
alg
=
run_algorithm
(
'CreateWorkspace'
,
DataX
=
data
,
DataY
=
data
,
NSpec
=
1
,
**
kwargs
)
self
.
assertEqual
(
alg
.
isExecuted
(),
True
)
self
.
assertEqual
(
alg
.
isRunning
(),
False
)
self
.
assertEqual
(
alg
.
getProperty
(
'NSpec'
).
value
,
1
)
self
.
assertEqual
(
type
(
alg
.
getProperty
(
'NSpec'
).
value
),
int
)
self
.
assertEqual
(
alg
.
getProperty
(
'NSpec'
).
name
,
'NSpec'
)
ws
=
alg
.
getProperty
(
'OutputWorkspace'
).
value
self
.
assertTrue
(
ws
.
getMemorySize
()
>
0.0
)
self
.
assertTrue
(
ws
.
getMemorySize
()
>
0.0
)
as_str
=
str
(
alg
)
self
.
assertEqual
(
as_str
,
'{"name":"CreateWorkspace","properties":{"DataX":[1.5,2.5,3.5],"DataY":[1.5,2.5,3.5],'
'"OutputWorkspace":"UNUSED_NAME_FOR_CHILD","UnitX":"Wavelength"},"version":1}'
)
'"OutputWorkspace":"UNUSED_NAME_FOR_CHILD","UnitX":"Wavelength"},"version":1}'
)
def
test_execute_succeeds_with_unicode_kwargs
(
self
):
props
=
json
.
loads
(
'{"DryRun":true}'
)
# this is always unicode
props
=
json
.
loads
(
'{"DryRun":true}'
)
# this is always unicode
alg
=
run_algorithm
(
'Segfault'
,
**
props
)
def
test_getAlgorithmID_returns_AlgorithmID_object
(
self
):
...
...
@@ -93,14 +110,14 @@ class AlgorithmTest(unittest.TestCase):
def
test_AlgorithmID_compares_by_value
(
self
):
alg
=
AlgorithmManager
.
createUnmanaged
(
'Load'
)
id
=
alg
.
getAlgorithmID
()
self
.
assertEqual
(
id
,
id
)
# equals itself
self
.
assertEqual
(
id
,
id
)
# equals itself
alg2
=
AlgorithmManager
.
createUnmanaged
(
'Load'
)
id2
=
alg2
.
getAlgorithmID
()
self
.
assertNotEqual
(
id2
,
id
)
def
test_cancel_does_nothing_to_executed_algorithm
(
self
):
data
=
[
1.0
]
alg
=
run_algorithm
(
'CreateWorkspace'
,
DataX
=
data
,
DataY
=
data
,
NSpec
=
1
,
UnitX
=
'Wavelength'
,
child
=
True
)
alg
=
run_algorithm
(
'CreateWorkspace'
,
DataX
=
data
,
DataY
=
data
,
NSpec
=
1
,
UnitX
=
'Wavelength'
,
child
=
True
)
self
.
assertEqual
(
alg
.
isExecuted
(),
True
)
self
.
assertEqual
(
alg
.
isRunning
(),
False
)
alg
.
cancel
()
...
...
@@ -116,14 +133,151 @@ class AlgorithmTest(unittest.TestCase):
def
test_createChildAlgorithm_respects_keyword_arguments
(
self
):
parent_alg
=
AlgorithmManager
.
createUnmanaged
(
'Load'
)
try
:
child_alg
=
parent_alg
.
createChildAlgorithm
(
name
=
'Rebin'
,
version
=
1
,
startProgress
=
0.5
,
endProgress
=
0.9
,
enableLogging
=
True
)
parent_alg
.
createChildAlgorithm
(
name
=
'Rebin'
,
version
=
1
,
startProgress
=
0.5
,
endProgress
=
0.9
,
enableLogging
=
True
)
except
Exception
as
exc
:
self
.
fail
(
"Expected createChildAlgorithm not to throw but it did: %s"
%
(
str
(
exc
)))
# Unknown keyword
self
.
assertRaises
(
Exception
,
parent_alg
.
createChildAlgorithm
,
name
=
'Rebin'
,
version
=
1
,
startProgress
=
0.5
,
endProgress
=
0.9
,
enableLogging
=
True
,
unknownKW
=
1
)
self
.
assertRaises
(
Exception
,
parent_alg
.
createChildAlgorithm
,
name
=
'Rebin'
,
version
=
1
,
startProgress
=
0.5
,
endProgress
=
0.9
,
enableLogging
=
True
,
unknownKW
=
1
)
def
test_createChildAlgorithm_with_kwargs
(
self
):
parent_alg
=
AlgorithmManager
.
createUnmanaged
(
'Load'
)
child_alg
=
parent_alg
.
createChildAlgorithm
(
'CreateSampleWorkspace'
,
**
{
"XUnit"
:
"Wavelength"
})
self
.
assertTrue
(
child_alg
.
isChild
())
child_alg
.
execute
()
ws
=
child_alg
.
getProperty
(
"OutputWorkspace"
).
value
self
.
assertEqual
(
"Wavelength"
,
ws
.
getAxis
(
0
).
getUnit
().
unitID
())
def
test_createChildAlgorithm_with_named_args
(
self
):
parent_alg
=
AlgorithmManager
.
createUnmanaged
(
'Load'
)
child_alg
=
parent_alg
.
createChildAlgorithm
(
'CreateSampleWorkspace'
,
XUnit
=
"Wavelength"
)
self
.
assertTrue
(
child_alg
.
isChild
())
child_alg
.
execute
()
ws
=
child_alg
.
getProperty
(
"OutputWorkspace"
).
value
self
.
assertEqual
(
"Wavelength"
,
ws
.
getAxis
(
0
).
getUnit
().
unitID
())
def
test_createChildAlgorithm_with_version_and_kwargs
(
self
):
parent_alg
=
AlgorithmManager
.
createUnmanaged
(
'Load'
)
child_alg
=
parent_alg
.
createChildAlgorithm
(
'CreateSampleWorkspace'
,
version
=
1
,
**
{
"XUnit"
:
"Wavelength"
})
self
.
assertTrue
(
child_alg
.
isChild
())
child_alg
.
execute
()
ws
=
child_alg
.
getProperty
(
"OutputWorkspace"
).
value
self
.
assertEqual
(
"Wavelength"
,
ws
.
getAxis
(
0
).
getUnit
().
unitID
())
def
test_createChildAlgorithm_with_all_args
(
self
):
parent_alg
=
AlgorithmManager
.
createUnmanaged
(
'Load'
)
child_alg
=
parent_alg
.
createChildAlgorithm
(
'CreateSampleWorkspace'
,
startProgress
=
0.0
,
endProgress
=
1.0
,
enableLogging
=
False
,
version
=
1
,
**
{
"XUnit"
:
"Wavelength"
})
self
.
assertTrue
(
child_alg
.
isChild
())
child_alg
.
execute
()
ws
=
child_alg
.
getProperty
(
"OutputWorkspace"
).
value
self
.
assertEqual
(
"Wavelength"
,
ws
.
getAxis
(
0
).
getUnit
().
unitID
())
def
test_with_workspace_types
(
self
):
ws
=
CreateSampleWorkspace
(
Function
=
"User Defined"
,
UserDefinedFunction
=
"name=LinearBackground, A0=0.3;name=Gaussian, "
"PeakCentre=5, Height=10, Sigma=0.7"
,
NumBanks
=
1
,
BankPixelWidth
=
1
,
XMin
=
0
,
XMax
=
10
,
BinWidth
=
0.1
)
# Setup the model, here a Gaussian, to fit to data
tryCentre
=
'4'
# A start guess on peak centre
sigma
=
'1'
# A start guess on peak width
height
=
'8'
# A start guess on peak height