Commit a02b0748 authored by Roman Tolchenov's avatar Roman Tolchenov
Browse files

Re #4158. Made MultiDomainFunction initialization in FunctionFactory.

parent c542f000
......@@ -6,7 +6,9 @@
//----------------------------------------------------------------------
#include "MantidAPI/IFunction.h"
#include "MantidAPI/Jacobian.h"
#include <boost/shared_array.hpp>
#include <map>
namespace Mantid
{
......@@ -132,7 +134,6 @@ public:
/// Remove a constraint
void removeConstraint(const std::string& parName);
/* CompositeFunction own methods */
/// Add a function at the back of the internal function list
......@@ -158,6 +159,27 @@ public:
/// Check the function.
void checkFunction();
/// Returns the number of attributes associated with the function
virtual size_t nLocalAttributes()const {return 0;}
/// Returns a list of attribute names
virtual std::vector<std::string> getLocalAttributeNames()const {return std::vector<std::string>();}
/// Return a value of attribute attName
virtual Attribute getLocalAttribute(size_t i, const std::string& attName)const
{
(void)i;
throw std::invalid_argument("Attribute "+attName+" not found in function "+this->name());
}
/// Set a value to attribute attName
virtual void setLocalAttribute(size_t i, const std::string& attName,const Attribute& )
{
(void)i;
throw std::invalid_argument("Attribute "+attName+" not found in function "+this->name());
}
/// Check if attribute attName exists
virtual bool hasLocalAttribute(const std::string&)const {return false;}
template<typename T>
void setLocalAttributeValue(size_t i, const std::string& attName,const T& value){setLocalAttribute(i,attName,Attribute(value));}
protected:
/// Function initialization. Declare function parameters in this method.
void init();
......@@ -184,7 +206,6 @@ private:
size_t m_nParams;
/// Function counter to be used in nextConstraint
mutable size_t m_iConstraintFunction;
};
///shared pointer to the composite function base class
......
......@@ -98,9 +98,15 @@ namespace API
using Kernel::DynamicFactory<IFunction>::createUnwrapped;
/// Create a simple function
boost::shared_ptr<IFunction> createSimple(const Expression& expr)const;
boost::shared_ptr<IFunction> createSimple(
const Expression& expr,
std::map<std::string,std::string>& parentAttributes
)const;
/// Create a composite function
boost::shared_ptr<CompositeFunction> createComposite(const Expression& expr)const;
boost::shared_ptr<CompositeFunction> createComposite(
const Expression& expr,
std::map<std::string,std::string>& parentAttributes
)const;
///Creates an instance of a function
boost::shared_ptr<IFunction> createFitFunction(const Expression& expr) const;
......
......@@ -178,6 +178,8 @@ public:
class MANTID_API_DLL Attribute
{
public:
/// Create empty string attribute
explicit Attribute():m_data(""), m_quoteValue(false) {}
/// Create string attribute
explicit Attribute(const std::string& str, bool quoteValue=false):m_data(str), m_quoteValue(quoteValue) {}
/// Create int attribute
......
......@@ -46,6 +46,8 @@ public:
/// Constructor
MultiDomainFunction():m_nDomains(0),m_maxIndex(0){}
/// Returns the function's name
virtual std::string name()const {return "MultiDomainFunction";}
/// Function you want to fit to.
/// @param domain :: The buffer for writing the calculated values. Must be big enough to accept dataSize() values
virtual void function(const FunctionDomain& domain, FunctionValues& values)const;
......@@ -60,11 +62,25 @@ public:
/// Clear all domain indices
void clearDomainIndices();
/// Returns the number of attributes associated with the function
virtual size_t nLocalAttributes()const {return 1;}
/// Returns a list of attribute names
virtual std::vector<std::string> getLocalAttributeNames()const {return std::vector<std::string>(1,"domains");}
/// Return a value of attribute attName
virtual Attribute getLocalAttribute(size_t i, const std::string& attName)const;
/// Set a value to attribute attName
virtual void setLocalAttribute(size_t i, const std::string& attName,const Attribute& );
/// Check if attribute attName exists
virtual bool hasLocalAttribute(const std::string& attName)const {return attName == "domains";}
protected:
/// Counts number of the domains
void countNumberOfDomains();
void countValueOffsets(const CompositeDomain& domain)const;
void getFunctionDomains(size_t i, const CompositeDomain& cd, std::vector<size_t>& domains)const;
/// Domain index map: finction -> domain
std::map<size_t, std::vector<size_t> > m_domains;
/// Number of domains this MultiDomainFunction operates on. == number of different values in m_domains
......
......@@ -837,6 +837,5 @@ IFunction_sptr CompositeFunction::getContainingFunction(const ParameterReference
return IFunction_sptr();
}
} // namespace API
} // namespace Mantid
#include "MantidAPI/FunctionFactory.h"
#include "MantidAPI/IFunction.h"
#include "MantidAPI/CompositeFunction.h"
#include "MantidAPI/MultiDomainFunction.h"
#include "MantidAPI/Expression.h"
#include "MantidAPI/ConstraintFactory.h"
#include "MantidAPI/IConstraint.h"
......@@ -59,15 +60,15 @@ namespace Mantid
}
const Expression& e = expr.bracketsRemoved();
std::map<std::string,std::string> parentAttributes;
if (e.name() == ";")
{
IFunction_sptr fun = createComposite(e);
IFunction_sptr fun = createComposite(e,parentAttributes);
if (!fun) inputError();
return fun;
}
return createSimple(e);
return createSimple(e,parentAttributes);
}
......@@ -76,7 +77,7 @@ namespace Mantid
* @param expr :: The input expression
* @return A pointer to the created function
*/
IFunction_sptr FunctionFactoryImpl::createSimple(const Expression& expr)const
IFunction_sptr FunctionFactoryImpl::createSimple(const Expression& expr, std::map<std::string,std::string>& parentAttributes)const
{
if (expr.name() == "=" && expr.size() > 1)
{
......@@ -122,6 +123,11 @@ namespace Mantid
{
addTies(fun,(*term)[1]);
}
else if (!parName.empty() && parName[0] == '$')
{
parName.erase(0,1);
parentAttributes[parName] = parValue;
}
else
{// set initial parameter value
fun->setParameter(parName,atof(parValue.c_str()));
......@@ -136,7 +142,7 @@ namespace Mantid
* @param expr :: The input expression
* @return A pointer to the created function
*/
CompositeFunction_sptr FunctionFactoryImpl::createComposite(const Expression& expr)const
CompositeFunction_sptr FunctionFactoryImpl::createComposite(const Expression& expr, std::map<std::string,std::string>& parentAttributes)const
{
if (expr.name() != ";") inputError(expr.str());
......@@ -175,7 +181,7 @@ namespace Mantid
{
if (firstTerm->terms()[0].name() == "composite")
{
cfun = boost::dynamic_pointer_cast<CompositeFunction>(createSimple(term));
cfun = boost::dynamic_pointer_cast<CompositeFunction>(createSimple(term,parentAttributes));
if (!cfun) inputError(expr.str());
++it;
}
......@@ -204,9 +210,10 @@ namespace Mantid
{
const Expression& term = it->bracketsRemoved();
IFunction_sptr fun;
std::map<std::string,std::string> pAttributes;
if (term.name() == ";")
{
fun = createComposite(term);
fun = createComposite(term,pAttributes);
if (!fun) continue;
}
else
......@@ -225,10 +232,15 @@ namespace Mantid
}
else
{
fun = createSimple(term);
fun = createSimple(term,pAttributes);
}
}
cfun->addFunction(fun);
size_t i = cfun->nFunctions() - 1;
for(auto att = pAttributes.begin(); att != pAttributes.end(); ++att)
{
cfun->setLocalAttributeValue(i,att->first,att->second);
}
}
return cfun;
......
......@@ -3,6 +3,8 @@
//----------------------------------------------------------------------
#include "MantidAPI/MultiDomainFunction.h"
#include "MantidAPI/CompositeDomain.h"
#include "MantidAPI/FunctionFactory.h"
#include "MantidAPI/Expression.h"
#include <boost/lexical_cast.hpp>
#include <set>
......@@ -12,6 +14,8 @@ namespace Mantid
namespace API
{
DECLARE_FUNCTION(MultiDomainFunction)
/**
* Associate a member function and a domain. The function will only be applied
* to this domain.
......@@ -74,6 +78,24 @@ namespace API
countNumberOfDomains();
}
/**
* Populates a vector with domain indices assigned to function i.
*/
void MultiDomainFunction::getFunctionDomains(size_t i, const CompositeDomain& cd, std::vector<size_t>& domains)const
{
auto it = m_domains.find(i);
if (it == m_domains.end())
{// apply to all domains
domains.resize(cd.getNParts());
size_t i = 0;
std::generate(domains.begin(),domains.end(),[&i]()->size_t{return i++;});
}
else
{// apply to selected domains
domains.assign(it->second.begin(),it->second.end());
}
}
/// Function you want to fit to.
/// @param domain :: The buffer for writing the calculated values. Must be big enough to accept dataSize() values
void MultiDomainFunction::function(const FunctionDomain& domain, FunctionValues& values)const
......@@ -97,25 +119,15 @@ namespace API
throw std::invalid_argument("MultiDomainFunction: domain and values have different sizes.");
}
countValueOffsets(cd);
// evaluate member functions
values.zeroCalculated();
for(size_t iFun = 0; iFun < nFunctions(); ++iFun)
{
// find the domains member function must be applied to
std::vector<size_t> domains;
auto it = m_domains.find(iFun);
if (it == m_domains.end())
{// apply to all domains
domains.resize(cd.getNParts());
size_t i = 0;
std::generate(domains.begin(),domains.end(),[&i]()->size_t{return i++;});
}
else
{// apply to selected domains
domains.assign(it->second.begin(),it->second.end());
}
getFunctionDomains(iFun, cd, domains);
countValueOffsets(cd);
for(auto i = domains.begin(); i != domains.end(); ++i)
{
const FunctionDomain& d = cd.getDomain(*i);
......@@ -143,24 +155,14 @@ namespace API
") for MultiDomainFunction (max index " + boost::lexical_cast<std::string>(m_maxIndex) + ").");
}
countValueOffsets(cd);
// evaluate member functions derivatives
for(size_t iFun = 0; iFun < nFunctions(); ++iFun)
{
// find the domains member function must be applied to
std::vector<size_t> domains;
auto it = m_domains.find(iFun);
if (it == m_domains.end())
{// apply to all domains
domains.resize(cd.getNParts());
size_t i = 0;
std::generate(domains.begin(),domains.end(),[&i]()->size_t{return i++;});
}
else
{// apply to selected domains
domains.assign(it->second.begin(),it->second.end());
}
getFunctionDomains(iFun, cd, domains);
countValueOffsets(cd);
for(auto i = domains.begin(); i != domains.end(); ++i)
{
const FunctionDomain& d = cd.getDomain(*i);
......@@ -170,5 +172,83 @@ namespace API
}
}
/// Return a value of attribute attName
IFunction::Attribute MultiDomainFunction::getLocalAttribute(size_t i, const std::string& attName)const
{
if (attName != "domains")
{
throw std::invalid_argument("MultiDomainFunction does not have attribute " + attName);
}
if (i >= nFunctions())
{
throw std::out_of_range("Function index is out of range.");
}
auto it = m_domains.find(i);
if (it == m_domains.end())
{
return IFunction::Attribute("All");
}
else if (it->second.size() == 1 && it->second.front() == i)
{
return IFunction::Attribute("i");
}
else if ( !it->second.empty() )
{
std::string out(boost::lexical_cast<std::string>(it->second.front()));
for(auto i = it->second.begin() + 1; i != it->second.end(); ++it)
{
out += "," + boost::lexical_cast<std::string>(*i);
}
return IFunction::Attribute(out);
}
return IFunction::Attribute("");
}
/**
* Set a value to attribute attName
*/
void MultiDomainFunction::setLocalAttribute(size_t i, const std::string& attName,const IFunction::Attribute& att)
{
if (attName != "domains")
{
throw std::invalid_argument("MultiDomainFunction does not have attribute " + attName);
}
if (i >= nFunctions())
{
throw std::out_of_range("Function index is out of range.");
}
std::string value = att.asString();
auto it = m_domains.find(i);
if (value == "All")
{// fit to all domains
if (it != m_domains.end())
{
m_domains.erase(it);
}
return;
}
else if (value == "i")
{// fit to domain with the same index as the function
setDomainIndex(i,i);
return;
}
else if (value.empty())
{// do not fit to any domain
setDomainIndices(i,std::vector<size_t>());
}
// fit to a selection of domains
std::vector<size_t> indx;
Expression list;
list.parse(value);
list.toList();
for(size_t k = 0; k < list.size(); ++k)
{
indx.push_back(boost::lexical_cast<size_t>(list[k].name()));
}
setDomainIndices(i,indx);
}
} // namespace API
} // namespace Mantid
......@@ -1206,7 +1206,6 @@ public:
delete mfun;
}
};
#endif /*COMPOSITEFUNCTIONTEST_H_*/
......@@ -402,6 +402,22 @@ public:
// TS_ASSERT_EQUALS(gauss->getParameter("Sigma"),0.33);
//}
void test_MultiDomainFunction_creation()
{
std::string fnString = "composite=MultiDomainFunction;"
"name=FunctionFactoryTest_FunctA;"
"name=FunctionFactoryTest_FunctB";
IFunction_sptr fun = FunctionFactory::Instance().createInitialized(fnString);
if (!fun)
{
std::cerr << "\nFailed to create MultiDomainFunction\n";
}
else
{
std::cerr << "\n" << fun->asString() << std::endl;
}
}
};
......
......@@ -7,6 +7,7 @@
#include "MantidAPI/JointDomain.h"
#include "MantidAPI/IFunction1D.h"
#include "MantidAPI/ParamFunction.h"
#include "MantidAPI/FunctionFactory.h"
#include <cxxtest/TestSuite.h>
#include <boost/make_shared.hpp>
......@@ -48,6 +49,8 @@ protected:
}
};
DECLARE_FUNCTION(MultiDomainFunctionTest_Function);
class MultiDomainFunctionTest : public CxxTest::TestSuite
{
public:
......@@ -237,6 +240,97 @@ public:
}
void test_attribute()
{
multi.clearDomainIndices();
multi.setLocalAttributeValue(0,"domains","i");
multi.setLocalAttributeValue(1,"domains","0,1");
multi.setLocalAttributeValue(2,"domains","0,2");
FunctionValues values(domain);
multi.function(domain,values);
double A = multi.getFunction(0)->getParameter("A") +
multi.getFunction(1)->getParameter("A") +
multi.getFunction(2)->getParameter("A");
double B = multi.getFunction(0)->getParameter("B") +
multi.getFunction(1)->getParameter("B") +
multi.getFunction(2)->getParameter("B");
auto d0 = static_cast<const FunctionDomain1D&>(domain.getDomain(0));
for(size_t i = 0; i < 9; ++i)
{
TS_ASSERT_EQUALS(values.getCalculated(i), A + B * d0[i]);
}
A = multi.getFunction(1)->getParameter("A");
B = multi.getFunction(1)->getParameter("B");
auto d1 = static_cast<const FunctionDomain1D&>(domain.getDomain(1));
for(size_t i = 9; i < 19; ++i)
{
TS_ASSERT_EQUALS(values.getCalculated(i), A + B * d1[i-9]);
}
A = multi.getFunction(2)->getParameter("A");
B = multi.getFunction(2)->getParameter("B");
auto d2 = static_cast<const FunctionDomain1D&>(domain.getDomain(2));
for(size_t i = 19; i < 30; ++i)
{
TS_ASSERT_EQUALS(values.getCalculated(i), A + B * d2[i-19]);
}
}
void test_attribute_in_FunctionFactory()
{
std::string ini = "composite=MultiDomainFunction;"
"name=MultiDomainFunctionTest_Function,A=0,B=1,$domains=i;"
"name=MultiDomainFunctionTest_Function,A=1,B=2,$domains=(0,1);"
"name=MultiDomainFunctionTest_Function,A=2,B=3,$domains=(0,2)"
;
auto mfun = boost::dynamic_pointer_cast<CompositeFunction>(FunctionFactory::Instance().createInitialized(ini));
FunctionValues values(domain);
mfun->function(domain,values);
double A = mfun->getFunction(0)->getParameter("A") +
mfun->getFunction(1)->getParameter("A") +
mfun->getFunction(2)->getParameter("A");
double B = mfun->getFunction(0)->getParameter("B") +
mfun->getFunction(1)->getParameter("B") +
mfun->getFunction(2)->getParameter("B");
auto d0 = static_cast<const FunctionDomain1D&>(domain.getDomain(0));
double checksum = 0;
for(size_t i = 0; i < 9; ++i)
{
TS_ASSERT_EQUALS(values.getCalculated(i), A + B * d0[i]);
checksum += values.getCalculated(i);
}
TS_ASSERT_DIFFERS(checksum,0);
checksum = 0;
A = mfun->getFunction(1)->getParameter("A");
B = mfun->getFunction(1)->getParameter("B");
auto d1 = static_cast<const FunctionDomain1D&>(domain.getDomain(1));
for(size_t i = 9; i < 19; ++i)
{
TS_ASSERT_EQUALS(values.getCalculated(i), A + B * d1[i-9]);
checksum += values.getCalculated(i);
}
TS_ASSERT_DIFFERS(checksum,0);
checksum = 0;
A = mfun->getFunction(2)->getParameter("A");
B = mfun->getFunction(2)->getParameter("B");
auto d2 = static_cast<const FunctionDomain1D&>(domain.getDomain(2));
for(size_t i = 19; i < 30; ++i)
{
TS_ASSERT_EQUALS(values.getCalculated(i), A + B * d2[i-19]);
checksum += values.getCalculated(i);
}
TS_ASSERT_DIFFERS(checksum,0);
}
private:
MultiDomainFunction multi;
JointDomain domain;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment