Commit 18ac15a3 authored by David Fairbrother's avatar David Fairbrother Committed by Gemma Guest
Browse files

Add overload to createChildAlgorithm with properties

parent b2f45324
......@@ -19,13 +19,18 @@
#endif
#include "MantidPythonInterface/core/GetPointer.h"
#include <boost/optional.hpp>
#include <boost/python/bases.hpp>
#include <boost/python/class.hpp>
#include <boost/python/dict.hpp>
#include <boost/python/exception_translator.hpp>
#include <boost/python/overloads.hpp>
#include <boost/python/raw_function.hpp>
#include <boost/python/register_ptr_to_python.hpp>
#include <boost/python/scope.hpp>
#include <cstddef>
using Mantid::API::Algorithm;
using Mantid::API::DistributedAlgorithm;
using Mantid::API::ParallelAlgorithm;
......@@ -56,6 +61,7 @@ using declarePropertyType3 = void (*)(boost::python::object &, const std::string
// declarePyAlgProperty(name, defaultValue, direction)
using declarePropertyType4 = void (*)(boost::python::object &, const std::string &, const boost::python::object &,
const int);
GNU_DIAG_OFF("unused-local-typedef")
// Ignore -Wconversion warnings coming from boost::python
// Seen with GCC 7.1.1 and Boost 1.63.0
......@@ -64,6 +70,7 @@ GNU_DIAG_OFF("conversion")
BOOST_PYTHON_FUNCTION_OVERLOADS(declarePropertyType1_Overload, PythonAlgorithm::declarePyAlgProperty, 2, 3)
BOOST_PYTHON_FUNCTION_OVERLOADS(declarePropertyType2_Overload, PythonAlgorithm::declarePyAlgProperty, 3, 6)
BOOST_PYTHON_FUNCTION_OVERLOADS(declarePropertyType3_Overload, PythonAlgorithm::declarePyAlgProperty, 4, 5)
GNU_DIAG_ON("conversion")
GNU_DIAG_ON("unused-local-typedef")
......@@ -76,6 +83,64 @@ void translateCancel(const Algorithm::CancelException &exc) {
UNUSED_ARG(exc);
PyErr_SetString(PyExc_KeyboardInterrupt, "");
}
template <typename T> boost::optional<T> extractArg(ssize_t index, const tuple &args) {
if (index < len(args)) {
return boost::optional<T>(extract<T>(args[index]));
}
return boost::none;
}
template <typename T> void extractKwargs(const dict &kwargs, const std::string &keyName, boost::optional<T> &out) {
if (kwargs.has_key(keyName)) {
out = boost::optional<T>(extract<T>(kwargs.get(keyName)));
}
}
// Signature createChildWithProps(self, name, startProgress, endProgress, enableLogging, version, **kwargs)
object createChildWithProps(tuple args, dict kwargs) {
std::shared_ptr<Algorithm> parentAlg = extract<std::shared_ptr<Algorithm>>(args[0]);
auto name = extractArg<std::string>(1, args);
auto startProgress = extractArg<double>(2, args);
auto endProgress = extractArg<double>(3, args);
auto enableLogging = extractArg<bool>(4, args);
auto version = extractArg<int>(5, args);
const std::array<std::string, 5> reservedNames = {"name", "startProgress", "endProgress", "enableLogging", "version"};
extractKwargs<std::string>(kwargs, reservedNames[0], name);
extractKwargs<double>(kwargs, reservedNames[1], startProgress);
extractKwargs<double>(kwargs, reservedNames[2], endProgress);
extractKwargs<bool>(kwargs, reservedNames[3], enableLogging);
extractKwargs<int>(kwargs, reservedNames[4], version);
if (!name.is_initialized()) {
throw std::invalid_argument("Please specify the algorithm name");
}
auto childAlg = parentAlg->createChildAlgorithm(name.value(), startProgress.value_or(-1), endProgress.value_or(-1),
enableLogging.value_or(true), version.value_or(-1));
const list keys = kwargs.keys();
for (int i = 0; i < len(keys); ++i) {
const std::string propName = extract<std::string>(keys[i]);
if (std::find(reservedNames.cbegin(), reservedNames.cend(), propName) != reservedNames.cend())
continue;
object curArg = kwargs[keys[i]];
if (curArg) {
// Rather than trying to figure out which type were getting from Python
// we will "ab"use setPropertyValue to simply use strings. This currently
// doesn't handle lists, but this could be retrofitted in future work
std::string propValue = extract<std::string>(curArg);
childAlg->setPropertyValue(propName, propValue);
}
}
return object(childAlg);
}
} // namespace
void export_leaf_classes() {
......@@ -86,14 +151,11 @@ void export_leaf_classes() {
// std::shared_ptr<AlgorithmAdapter>
// See
// http://wiki.python.org/moin/boost.python/HowTo#ownership_of_C.2B-.2B-_object_extended_in_Python
class_<Algorithm, bases<Mantid::API::IAlgorithm>, std::shared_ptr<PythonAlgorithm>, boost::noncopyable>(
"Algorithm", "Base class for all algorithms")
.def("fromString", &Algorithm::fromString, "Initialize the algorithm from a string representation")
.staticmethod("fromString")
.def("createChildAlgorithm", &Algorithm::createChildAlgorithm,
(arg("self"), arg("name"), arg("startProgress") = -1.0, arg("endProgress") = -1.0,
arg("enableLogging") = true, arg("version") = -1),
.def("createChildAlgorithm", raw_function(&createChildWithProps, 1),
"Creates and intializes a named child algorithm. Output workspaces "
"are given a dummy name.")
.def("declareProperty", (declarePropertyType1)&PythonAlgorithm::declarePyAlgProperty,
......
......@@ -125,5 +125,74 @@ class AlgorithmTest(unittest.TestCase):
self.assertRaises(Exception, parent_alg.createChildAlgorithm, name='Rebin',version=1,startProgress=0.5,
endProgress=0.9,enableLogging=True, unknownKW=1)
def test_createChildAlgorithm_with_kwargs(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
child_alg = parent_alg.createChildAlgorithm('CreateSampleWorkspace', **{"XUnit": "Wavelength"})
self.assertTrue(child_alg.isChild())
child_alg.execute()
ws = child_alg.getProperty("OutputWorkspace").value
self.assertEqual("Wavelength", ws.getAxis(0).getUnit().unitID())
def test_createChildAlgorithm_with_named_args(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
child_alg = parent_alg.createChildAlgorithm('CreateSampleWorkspace', XUnit="Wavelength")
self.assertTrue(child_alg.isChild())
child_alg.execute()
ws = child_alg.getProperty("OutputWorkspace").value
self.assertEqual("Wavelength", ws.getAxis(0).getUnit().unitID())
def test_createChildAlgorithm_with_version_and_kwargs(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
child_alg = parent_alg.createChildAlgorithm('CreateSampleWorkspace', version=1, **{"XUnit": "Wavelength"})
self.assertTrue(child_alg.isChild())
child_alg.execute()
ws = child_alg.getProperty("OutputWorkspace").value
self.assertEqual("Wavelength", ws.getAxis(0).getUnit().unitID())
def test_createChildAlgorithm_with_all_args(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
child_alg = parent_alg.createChildAlgorithm('CreateSampleWorkspace', startProgress=0.0, endProgress=1.0,
enableLogging=False, version=1, **{"XUnit": "Wavelength"})
self.assertTrue(child_alg.isChild())
child_alg.execute()
ws = child_alg.getProperty("OutputWorkspace").value
self.assertEqual("Wavelength", ws.getAxis(0).getUnit().unitID())
def test_createChildAlgorithm_without_name(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
with self.assertRaisesRegex(ValueError, "algorithm name"):
parent_alg.createChildAlgorithm(startProgress=0.0, endProgress=1.0,
enableLogging=False, version=1, **{"XUnit": "Wavelength"})
def test_createChildAlgorithm_without_parameters(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
with self.assertRaisesRegex(ValueError, "algorithm name"):
parent_alg.createChildAlgorithm()
def test_createChildAlgorithm_with_incorrect_types(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
with self.assertRaises(TypeError):
parent_alg.createChildAlgorithm("CreateSampleWorkspace", startProgress="0.0", endProgress=1.0,
enableLogging=False, version=1, **{"XUnit": "Wavelength"})
def test_createChildAlgorithm_with_mixed_args_and_kwargs(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
child_alg = parent_alg.createChildAlgorithm("CreateSampleWorkspace", 0.0, 1.0, version=1,
enableLogging=False, **{"XUnit": "Wavelength"})
self.assertTrue(child_alg.isChild())
child_alg.execute()
ws = child_alg.getProperty("OutputWorkspace").value
self.assertEqual("Wavelength", ws.getAxis(0).getUnit().unitID())
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment