Commit 0a12a4a4 authored by David Fairbrother's avatar David Fairbrother Committed by Gemma Guest
Browse files

Fix createChildAlg not accepting workspace types

Workspace types were not accepted, this is particularly noticable in Fit
functions. Using the already-existing Workspace registry we can
trivially add a conversion that covers all of these types
parent 9fc46cef
......@@ -5,6 +5,7 @@
// Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
// SPDX - License - Identifier: GPL - 3.0 +
#pragma once
#include "MantidAPI/Workspace.h"
#include "MantidKernel/Logger.h"
#include <boost/python/extract.hpp>
......@@ -24,8 +25,8 @@ Mantid::Kernel::Logger g_log("Python Type Extractor");
namespace Mantid::PythonInterface {
struct PyNativeTypeExtractor {
using PythonOutputT =
boost::make_recursive_variant<bool, long, double, std::string, std::vector<boost::recursive_variant_>>::type;
using PythonOutputT = boost::make_recursive_variant<bool, long, double, std::string, Mantid::API::Workspace_sptr,
std::vector<boost::recursive_variant_>>::type;
static PythonOutputT convert(const boost::python::object &obj) {
using namespace boost::python;
......@@ -42,6 +43,8 @@ struct PyNativeTypeExtractor {
out = extract<long>(obj);
} else if (PyUnicode_Check(rawptr)) {
out = extract<std::string>(obj);
} else if (auto extractor = extract<Mantid::API::Workspace_sptr>(obj); extractor.check()) {
out = extractor();
} else {
throw std::invalid_argument("Unrecognised Python type");
}
......@@ -79,6 +82,7 @@ public:
virtual void operator()(long value) const = 0;
virtual void operator()(double value) const = 0;
virtual void operator()(std::string) const = 0;
virtual void operator()(Mantid::API::Workspace_sptr) const = 0;
virtual void operator()(std::vector<bool>) const = 0;
virtual void operator()(std::vector<long>) const = 0;
......
......@@ -112,6 +112,7 @@ public:
void operator()(long value) const override { setProp(static_cast<int>(value)); }
void operator()(double value) const override { setProp(value); }
void operator()(std::string value) const override { m_alg->setPropertyValue(m_propName, value); }
void operator()(Mantid::API::Workspace_sptr ws) const override { m_alg->setProperty(m_propName, std::move(ws)); }
void operator()(std::vector<bool> value) const override { setProp(value); }
void operator()(std::vector<long> value) const override { setProp(value); }
......@@ -147,7 +148,6 @@ object createChildWithProps(tuple args, dict kwargs) {
if (!name.is_initialized()) {
throw std::invalid_argument("Please specify the algorithm name");
}
auto childAlg = parentAlg->createChildAlgorithm(name.value(), startProgress.value_or(-1), endProgress.value_or(-1),
enableLogging.value_or(true), version.value_or(-1));
......
......@@ -35,6 +35,7 @@ public:
void operator()(long value) const override { m_attr.setValue(static_cast<int>(value)); }
void operator()(double value) const override { m_attr.setValue(value); }
void operator()(std::string value) const override { m_attr.setValue(std::move(value)); }
void operator()(Mantid::API::Workspace_sptr) const override { throw std::invalid_argument(m_errorMsg); }
void operator()(std::vector<bool>) const override { throw std::invalid_argument(m_errorMsg); }
void operator()(std::vector<long> value) const override {
......
......@@ -9,7 +9,8 @@ import unittest
import json
from mantid.kernel import Direction, FloatArrayProperty, IntArrayProperty, StringArrayProperty, \
IntArrayMandatoryValidator, FloatArrayMandatoryValidator, StringArrayMandatoryValidator
from mantid.api import AlgorithmID, AlgorithmManager, AlgorithmFactory, FrameworkManagerImpl, PythonAlgorithm
from mantid.simpleapi import CreateSampleWorkspace
from mantid.api import AlgorithmID, AlgorithmManager, AlgorithmFactory, FrameworkManagerImpl, PythonAlgorithm, Workspace
from testhelpers import run_algorithm
......@@ -133,7 +134,7 @@ class AlgorithmTest(unittest.TestCase):
parent_alg = AlgorithmManager.createUnmanaged('Load')
try:
parent_alg.createChildAlgorithm(name='Rebin', version=1, startProgress=0.5,
endProgress=0.9, enableLogging=True)
endProgress=0.9, enableLogging=True)
except Exception as exc:
self.fail("Expected createChildAlgorithm not to throw but it did: %s" % (str(exc)))
......@@ -182,6 +183,24 @@ class AlgorithmTest(unittest.TestCase):
self.assertEqual("Wavelength", ws.getAxis(0).getUnit().unitID())
def test_with_workspace_types(self):
ws = CreateSampleWorkspace(Function="User Defined",
UserDefinedFunction="name=LinearBackground, A0=0.3;name=Gaussian, "
"PeakCentre=5, Height=10, Sigma=0.7",
NumBanks=1, BankPixelWidth=1, XMin=0, XMax=10, BinWidth=0.1)
# Setup the model, here a Gaussian, to fit to data
tryCentre = '4' # A start guess on peak centre
sigma = '1' # A start guess on peak width
height = '8' # A start guess on peak height
myFunc = 'name=Gaussian, Height=' + height + ', PeakCentre=' + tryCentre + ', Sigma=' + sigma
args = {"Function": myFunc, "InputWorkspace": ws, "Output": 'fit'}
parent_alg = AlgorithmManager.createUnmanaged('Load')
child_alg = parent_alg.createChildAlgorithm('Fit', 0, 0, True, version=1, **args)
child_alg.execute()
out_ws = child_alg.getProperty("OutputWorkspace").value
self.assertIsInstance(out_ws, Workspace)
def test_createChildAlgorithm_without_name(self):
parent_alg = AlgorithmManager.createUnmanaged('Load')
with self.assertRaisesRegex(ValueError, "algorithm name"):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment