Commit d4ba2dae authored by Stephen's avatar Stephen
Browse files

Add tests for replacing and removing functions

parent cbdc577b
......@@ -90,7 +90,7 @@ public:
size_t nParams() const override;
// Total number of attributes, which includes global and local function
// attributes
size_t nAttributes() const noexcept override;
size_t nAttributes() const override;
// Total number of global attributes, defined at the composite function level
size_t nGlobalAttributes() const noexcept { return IFunction::nAttributes(); }
/// Returns the index of parameter name
......@@ -160,7 +160,7 @@ public:
/// Remove a function
void removeFunction(size_t i);
/// Replace a function
void replaceFunction(size_t i, const IFunction_sptr &f);
void replaceFunction(size_t functionIndex, const IFunction_sptr &f);
/// Replace a function
void replaceFunctionPtr(const IFunction_sptr &f_old,
const IFunction_sptr &f_new);
......
......@@ -28,6 +28,27 @@ namespace {
/// static logger
Kernel::Logger g_log("CompositeFunction");
constexpr char *ATTNUMDERIV = "NumDeriv";
template <typename T>
void replaceVariableIndexRange(std::vector<T> &variableFunctionIndexList,
size_t oldSize, size_t newSize,
const T functionIndex) {
auto itFun = std::find(variableFunctionIndexList.begin(),
variableFunctionIndexList.end(), functionIndex);
if (itFun != variableFunctionIndexList.end()) {
if (oldSize > newSize) {
variableFunctionIndexList.erase(itFun, itFun + oldSize - newSize);
} else if (oldSize < newSize) {
variableFunctionIndexList.insert(itFun, newSize - oldSize, functionIndex);
}
} else if (newSize > 0) {
using std::placeholders::_1;
itFun = std::find_if(variableFunctionIndexList.begin(),
variableFunctionIndexList.end(),
std::bind(std::greater<size_t>(), _1, functionIndex));
variableFunctionIndexList.insert(itFun, newSize, functionIndex);
}
}
} // namespace
using std::size_t;
......@@ -569,7 +590,7 @@ void CompositeFunction::replaceFunctionPtr(const IFunction_sptr &f_old,
}
/** Replace a function with a new one. The old function is deleted.
* @param i :: The index of the function to replace
* @param functionIndex :: The index of the function to replace
* @param f :: A pointer to the new function
*/
void CompositeFunction::replaceFunction(size_t functionIndex,
......@@ -587,33 +608,15 @@ void CompositeFunction::replaceFunction(size_t functionIndex,
size_t at_new = f->nAttributes();
// Modify function parameter and attribute indices:
// The new function may have different number of
{
auto itFun =
std::find(m_IFunction.begin(), m_IFunction.end(), functionIndex);
if (itFun != m_IFunction.end()) // functions must have at least 1 parameter
{
if (np_old > np_new) {
m_IFunction.erase(itFun, itFun + np_old - np_new);
} else if (np_old < np_new) {
m_IFunction.insert(itFun, np_new - np_old, functionIndex);
}
} else if (np_new > 0) // it could happen if the old function is an empty
// CompositeFunction
{
using std::placeholders::_1;
itFun =
std::find_if(m_IFunction.begin(), m_IFunction.end(),
std::bind(std::greater<size_t>(), _1, functionIndex));
m_IFunction.insert(itFun, np_new, functionIndex);
}
}
replaceVariableIndexRange(m_IFunction, np_old, np_new, functionIndex);
replaceVariableIndexRange(m_attributeIndex, at_old, at_new,functionIndex);
// Decrement attribute and parameter counts
size_t dnp = np_new - np_old;
size_t dna = at_new - at_old;
m_nParams += dnp;
m_nAttributes += dna;
// Shift the parameter offsets down by the total number of i-th function's
// params
for (size_t j = functionIndex + 1; j < nFunctions(); j++) {
......@@ -665,8 +668,7 @@ size_t CompositeFunction::attributeFunctionIndex(std::size_t i) const {
/**
* @param varName :: The variable name which may contain function index (
* [f<index.>]name )
* @param index :: Receives function index or throws std::invalid_argument
* @param name :: Receives the variable name
* @return pair containing the unprefixed variable name and functionIndex
*/
std::pair<std::string, size_t>
CompositeFunction::parseName(const std::string &varName) {
......
......@@ -1379,4 +1379,49 @@ public:
IFunction::Attribute("NewCubicAttribute")),
std::invalid_argument &);
}
void test_remove_function_correctly_shifts_down_attributes() {
auto mfun = std::make_unique<CompositeFunction>();
auto gauss = std::make_shared<Gauss<true>>();
auto background = std::make_shared<Linear<true>>();
auto cubic = std::make_shared<Cubic<true>>();
mfun->addFunction(gauss);
mfun->addFunction(background);
mfun->addFunction(cubic);
mfun->removeFunction(1);
TS_ASSERT_EQUALS(mfun->nAttributes(), 3);
TS_ASSERT_EQUALS(mfun->attributeName(0), "NumDeriv");
TS_ASSERT_EQUALS(mfun->attributeName(1), "f0.GaussAttribute");
TS_ASSERT_EQUALS(mfun->attributeName(2), "f1.CubicAttribute");
}
void test_replace_function_correctly_adds_attributes() {
auto mfun = std::make_unique<CompositeFunction>();
auto gauss = std::make_shared<Gauss<false>>();
auto background = std::make_shared<Linear<true>>();
auto cubic = std::make_shared<Cubic<true>>();
auto gaussWithAttributes = std::make_shared<Gauss<true>>();
mfun->addFunction(background);
mfun->addFunction(gauss);
mfun->addFunction(cubic);
TS_ASSERT_EQUALS(mfun->nAttributes(), 3);
TS_ASSERT_EQUALS(mfun->attributeName(0), "NumDeriv");
TS_ASSERT_EQUALS(mfun->attributeName(1), "f0.LinearAttribute");
TS_ASSERT_EQUALS(mfun->attributeName(2), "f2.CubicAttribute");
mfun->replaceFunction(1, gaussWithAttributes);
TS_ASSERT_EQUALS(mfun->nAttributes(), 4);
TS_ASSERT_EQUALS(mfun->attributeName(0), "NumDeriv");
TS_ASSERT_EQUALS(mfun->attributeName(1), "f0.LinearAttribute");
TS_ASSERT_EQUALS(mfun->attributeName(2), "f1.GaussAttribute");
TS_ASSERT_EQUALS(mfun->attributeName(3), "f2.CubicAttribute");
}
};
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment