Commit c861cd6e authored by Roman Tolchenov's avatar Roman Tolchenov
Browse files

Order ties. Check for circular dependency. Re #22925

parent 90a6b027
......@@ -476,6 +476,8 @@ public:
virtual bool removeTie(size_t i);
/// Get the tie of i-th parameter
virtual ParameterTie *getTie(size_t i) const;
/// Put all ties in order in which they will be applied correctly.
void makeOrderedTies();
/// Write a parameter tie to a string
std::string writeTies() const;
//@}
......@@ -593,6 +595,8 @@ protected:
const API::IFunction::Attribute &value) const;
/// Add a new tie. Derived classes must provide storage for ties
virtual void addTie(std::unique_ptr<ParameterTie> tie);
bool hasOrderedTies() const;
void applyOrderedTies();
/// Writes itself into a string
virtual std::string
writeToString(const std::string &parentLocalAttributesStr = "") const;
......@@ -618,10 +622,12 @@ private:
boost::shared_ptr<Kernel::Matrix<double>> m_covar;
/// The chi-squared of the last fit
double m_chiSquared;
/// Holds parameter ties as <parameter index,tie pointer>
/// Holds parameter ties
std::vector<std::unique_ptr<ParameterTie>> m_ties;
/// Holds the constraints added to function
std::vector<std::unique_ptr<IConstraint>> m_constraints;
/// Ties ordered in order of correct application
std::vector<ParameterTie*> m_orderedTies;
};
/// shared pointer to the function base class
......
......@@ -63,6 +63,8 @@ public:
bool findParametersOf(const IFunction *fun) const;
/// Check if the tie is a constant
bool isConstant() const;
/// Get a list of parameters on the right-hand side of the equation
std::vector<ParameterReference> getRHSParameters() const;
protected:
mu::Parser *m_parser; ///< math parser
......
......@@ -614,10 +614,14 @@ std::string CompositeFunction::parameterLocalName(size_t i,
* Apply the ties.
*/
void CompositeFunction::applyTies() {
for (size_t i = 0; i < nFunctions(); i++) {
getFunction(i)->applyTies();
if (hasOrderedTies()) {
applyOrderedTies();
} else {
for (size_t i = 0; i < nFunctions(); i++) {
getFunction(i)->applyTies();
}
IFunction::applyTies();
}
IFunction::applyTies();
}
/**
......
......@@ -41,6 +41,20 @@ using namespace Geometry;
namespace {
/// static logger
Kernel::Logger g_log("IFunction");
/// Struct that helps sort ties in correct order of application.
struct TieNode {
// Iindex of the tied parameter
size_t left;
// Indices of parameters on the right-hand-side of the expression
std::vector<size_t> right;
// This tie must be applied before the other if the RHS of the other
// contains this (left) parameter.
bool operator<(TieNode const &other) {
return std::find(other.right.begin(), other.right.end(), left) !=
other.right.end();
}
};
}
/**
......@@ -252,12 +266,28 @@ void IFunction::addTie(std::unique_ptr<ParameterTie> tie) {
}
}
bool IFunction::hasOrderedTies() const
{
return !m_orderedTies.empty();
}
void IFunction::applyOrderedTies()
{
for (auto &&tie : m_orderedTies) {
tie->eval();
}
}
/**
* Apply the ties.
*/
void IFunction::applyTies() {
for (auto &m_tie : m_ties) {
m_tie->eval();
if (hasOrderedTies()) {
applyOrderedTies();
} else {
for (auto &tie : m_ties) {
tie->eval();
}
}
}
......@@ -1451,6 +1481,46 @@ std::vector<IFunction_sptr> IFunction::createEquivalentFunctions() const {
1, FunctionFactory::Instance().createInitialized(asString()));
}
/// Put all ties in order in which they will be applied correctly.
void IFunction::makeOrderedTies() {
m_orderedTies.clear();
std::list<TieNode> orderedTieNodes;
for (size_t i = 0; i < nParams(); ++i) {
auto tie = getTie(i);
if (!tie) {
continue;
}
std::vector<size_t> right;
auto rhsParameters = tie->getRHSParameters();
for (auto &&p : rhsParameters) {
right.push_back(this->getParameterIndex(p));
}
TieNode newNode{getParameterIndex(*tie), right};
bool before(false), after(false);
for (auto &&node : orderedTieNodes) {
if (newNode < node) {
before = true;
}
if (node < newNode) {
after = true;
}
}
if (before) {
if (after) {
throw std::runtime_error("Circular dependency in ties: " +
tie->asString(this));
}
orderedTieNodes.push_front(newNode);
} else {
orderedTieNodes.push_back(newNode);
}
}
for (auto &&node : orderedTieNodes) {
auto tie = getTie(node.left);
m_orderedTies.push_back(tie);
}
}
} // namespace API
} // namespace Mantid
......
......@@ -194,5 +194,15 @@ bool ParameterTie::findParametersOf(const IFunction *fun) const {
*/
bool ParameterTie::isConstant() const { return m_varMap.empty(); }
/** Get a list of parameters on the right-hand side of the equation
*/
std::vector<ParameterReference> ParameterTie::getRHSParameters() const {
std::vector<ParameterReference> out;
for (auto &&varPair : m_varMap) {
out.emplace_back(varPair.second);
}
return out;
}
} // namespace CurveFitting
} // namespace Mantid
......@@ -256,6 +256,44 @@ public:
TS_ASSERT(!mf->isFixed(3));
}
void test_circular_dependency() {
auto mf = makeFunction();
mf->tie("f0.a", "f3.f1.hi");
mf->tie("f0.b", "f2.sig + f0.a");
mf->tie("f2.sig", "f3.f1.hi");
mf->tie("f3.f1.hi", "f0.b");
TS_ASSERT_THROWS_EQUALS(mf->makeOrderedTies(), std::runtime_error & e,
std::string(e.what()),
"Circular dependency in ties: f3.f1.hi=f0.b");
}
void test_ties_order() {
auto mf = makeFunction();
mf->tie("f0.a", "f3.f1.hi");
mf->tie("f0.b", "f2.sig + f0.a");
mf->tie("f1.hi", "f1.cen*2");
mf->tie("f2.sig", "f3.f1.hi");
mf->tie("f3.f1.hi", "f1.sig");
mf->applyTies();
// Unordered ties applied wrongly
TS_ASSERT(fabs(mf->getParameter("f0.a") - mf->getParameter("f3.f1.hi")) > 1);
TS_ASSERT(fabs(mf->getParameter("f0.b") - (mf->getParameter("f2.sig") + mf->getParameter("f0.a"))) > 1);
TS_ASSERT(fabs(mf->getParameter("f1.hi") - mf->getParameter("f1.cen")*2.0) < 1e-5);
TS_ASSERT(fabs(mf->getParameter("f2.sig") - mf->getParameter("f3.f1.hi")) > 1);
TS_ASSERT(fabs(mf->getParameter("f3.f1.hi") - mf->getParameter("f2.sig")) > 1);
TS_ASSERT_THROWS_NOTHING(mf->makeOrderedTies());
mf->applyTies();
// After ordering apply correctly
TS_ASSERT_DELTA(mf->getParameter("f0.a"), mf->getParameter("f3.f1.hi"), 1e-5);
TS_ASSERT_DELTA(mf->getParameter("f0.b"), mf->getParameter("f2.sig") + mf->getParameter("f0.a"), 1e-5);
TS_ASSERT_DELTA(mf->getParameter("f1.hi"), mf->getParameter("f1.cen")*2.0, 1e-5);
TS_ASSERT_DELTA(mf->getParameter("f2.sig"), mf->getParameter("f3.f1.hi"), 1e-5);
TS_ASSERT_DELTA(mf->getParameter("f3.f1.hi"), mf->getParameter("f2.sig"), 1e-5);
}
private:
void mustThrow1(CompositeFunction *fun) { ParameterTie tie(fun, "sig", "0"); }
void mustThrow2(CompositeFunction *fun) {
......@@ -266,6 +304,27 @@ private:
}
void mustThrow4(IFunction *fun) { ParameterTie tie(fun, "f1.a", "0"); }
void mustThrow5(IFunction *fun) { ParameterTie tie(fun, "cen", "0"); }
IFunction_sptr makeFunction() {
CompositeFunction_sptr mf = CompositeFunction_sptr(new CompositeFunction);
IFunction_sptr bk = IFunction_sptr(new ParameterTieTest_Linear());
IFunction_sptr g1 = IFunction_sptr(new ParameterTieTest_Gauss());
IFunction_sptr g2 = IFunction_sptr(new ParameterTieTest_Gauss());
CompositeFunction_sptr cf = CompositeFunction_sptr(new CompositeFunction);
IFunction_sptr g3 = IFunction_sptr(new ParameterTieTest_Gauss());
IFunction_sptr g4 = IFunction_sptr(new ParameterTieTest_Gauss());
cf->addFunction(g3);
cf->addFunction(g4);
mf->addFunction(bk);
mf->addFunction(g1);
mf->addFunction(g2);
mf->addFunction(cf);
for(size_t i = 0; i < mf->nParams(); ++i) {
mf->setParameter(i, double(i + 1));
}
return mf;
}
};
#endif /*PARAMETERTIETEST_H_*/
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment