Commit cb01a173 authored by Roman Tolchenov's avatar Roman Tolchenov
Browse files

Added FunctionFactory, a few changes to the Fit algorithm. re #965

parent e76e2e78
......@@ -220,6 +220,10 @@
RelativePath=".\src\FrameworkManager.cpp"
>
</File>
<File
RelativePath=".\src\FunctionFactory.cpp"
>
</File>
<File
RelativePath=".\src\IFunction.cpp"
>
......@@ -402,6 +406,10 @@
RelativePath=".\test\FrameworkManagerTest.h"
>
</File>
<File
RelativePath=".\inc\MantidAPI\FunctionFactory.h"
>
</File>
<File
RelativePath=".\test\GaussianErrorHelperTest.h"
>
......
#ifndef MANTID_API_FUNCTIONFACTORY_H_
#define MANTID_API_FUNCTIONFACTORY_H_
//----------------------------------------------------------------------
// Includes
//----------------------------------------------------------------------
#include <vector>
#include "MantidAPI/DllExport.h"
#include "MantidKernel/DynamicFactory.h"
#include "MantidKernel/SingletonHolder.h"
namespace Mantid
{
//----------------------------------------------------------------------
// Forward declarations
//----------------------------------------------------------------------
namespace Kernel
{
class Logger;
}
namespace API
{
class IFunction;
}
namespace API
{
/** @class FunctionFactoryImpl
The FunctionFactory class is in charge of the creation of concrete
instances of fitting functions. It inherits most of its implementation from
the Dynamic Factory base class.
It is implemented as a singleton class.
@author Roman Tolchenov, Tessella Support Services plc
@date 27/10/2009
Copyright &copy; 2007 STFC Rutherford Appleton Laboratories
This file is part of Mantid.
Mantid is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3 of the License, or
(at your option) any later version.
Mantid is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
File change history is stored at: <https://svn.mantidproject.org/mantid/trunk/Code/Mantid>
*/
class EXPORT_OPT_MANTID_API FunctionFactoryImpl : public Kernel::DynamicFactory<IFunction>
{
public:
/**Creates an instance of a function
* @param type The function's type
*/
IFunction* createFunction(const std::string& type) const
{
return createUnwrapped(type);
}
private:
friend struct Mantid::Kernel::CreateUsingNew<FunctionFactoryImpl>;
/// Private Constructor for singleton class
FunctionFactoryImpl();
/// Private copy constructor - NO COPY ALLOWED
FunctionFactoryImpl(const FunctionFactoryImpl&);
/// Private assignment operator - NO ASSIGNMENT ALLOWED
FunctionFactoryImpl& operator = (const FunctionFactoryImpl&);
///Private Destructor
virtual ~FunctionFactoryImpl();
///static reference to the logger class
Kernel::Logger& g_log;
};
///Forward declaration of a specialisation of SingletonHolder for AlgorithmFactoryImpl (needed for dllexport/dllimport) and a typedef for it.
#ifdef _WIN32
// this breaks new namespace declaraion rules; need to find a better fix
template class EXPORT_OPT_MANTID_API Mantid::Kernel::SingletonHolder<FunctionFactoryImpl>;
#endif /* _WIN32 */
typedef EXPORT_OPT_MANTID_API Mantid::Kernel::SingletonHolder<FunctionFactoryImpl> FunctionFactory;
} // namespace API
} // namespace Mantid
#endif /*MANTID_API_FUNCTIONFACTORY_H_*/
......@@ -6,6 +6,7 @@
//----------------------------------------------------------------------
#include "MantidKernel/System.h"
#include "MantidKernel/Unit.h"
#include "MantidAPI/FunctionFactory.h"
#include "boost/shared_ptr.hpp"
#include <string>
#include <vector>
......@@ -132,4 +133,14 @@ protected:
} // namespace API
} // namespace Mantid
/**
* Macro for declaring a new type of function to be used with the FunctionFactory
*/
#define DECLARE_FUNCTION(classname) \
namespace { \
Mantid::Kernel::RegistrationHelper register_alg_##classname( \
((Mantid::API::FunctionFactory::Instance().subscribe<classname>(#classname)) \
, 0)); \
}
#endif /*MANTID_API_IFUNCTION_H_*/
......@@ -13,6 +13,8 @@ namespace Mantid
namespace API
{
DECLARE_FUNCTION(CompositeFunction)
/// Copy contructor
CompositeFunction::CompositeFunction(const CompositeFunction& f)
:m_nActive(f.m_nParams),m_nParams(f.m_nParams)
......
#include "MantidAPI/FunctionFactory.h"
#include "MantidAPI/IFunction.h"
#include "MantidKernel/Logger.h"
namespace Mantid
{
namespace API
{
FunctionFactoryImpl::FunctionFactoryImpl() : Kernel::DynamicFactory<IFunction>(), g_log(Kernel::Logger::get("FunctionFactory"))
{
}
FunctionFactoryImpl::~FunctionFactoryImpl()
{
}
} // namespace DataObjects
} // namespace Mantid
......@@ -80,12 +80,14 @@ double IFunction::parameter(int i)const
*/
double& IFunction::getParameter(const std::string& name)
{
std::string ucName(name);
//std::transform(name.begin(), name.end(), ucName.begin(), toupper);
std::vector<std::string>::const_iterator it =
std::find(m_parameterNames.begin(),m_parameterNames.end(),name);
std::find(m_parameterNames.begin(),m_parameterNames.end(),ucName);
if (it == m_parameterNames.end())
{
std::ostringstream msg;
msg << "Function parameter ("<<name<<") does not exist.";
msg << "Function parameter ("<<ucName<<") does not exist.";
throw std::invalid_argument(msg.str());
}
return m_parameters[it - m_parameterNames.begin()];
......@@ -97,12 +99,14 @@ double& IFunction::getParameter(const std::string& name)
*/
double IFunction::getParameter(const std::string& name)const
{
std::string ucName(name);
//std::transform(name.begin(), name.end(), ucName.begin(), toupper);
std::vector<std::string>::const_iterator it =
std::find(m_parameterNames.begin(),m_parameterNames.end(),name);
std::find(m_parameterNames.begin(),m_parameterNames.end(),ucName);
if (it == m_parameterNames.end())
{
std::ostringstream msg;
msg << "Function parameter ("<<name<<") does not exist.";
msg << "Function parameter ("<<ucName<<") does not exist.";
throw std::invalid_argument(msg.str());
}
return m_parameters[it - m_parameterNames.begin()];
......@@ -124,16 +128,18 @@ std::string IFunction::parameterName(int i)const
*/
void IFunction::declareParameter(const std::string& name,double initValue )
{
std::string ucName(name);
//std::transform(name.begin(), name.end(), ucName.begin(), toupper);
std::vector<std::string>::const_iterator it =
std::find(m_parameterNames.begin(),m_parameterNames.end(),name);
std::find(m_parameterNames.begin(),m_parameterNames.end(),ucName);
if (it != m_parameterNames.end())
{
std::ostringstream msg;
msg << "Function parameter ("<<name<<") already exists.";
msg << "Function parameter ("<<ucName<<") already exists.";
throw std::invalid_argument(msg.str());
}
m_parameterNames.push_back(name);
m_parameterNames.push_back(ucName);
m_parameters.push_back(initValue);
}
......
......@@ -95,11 +95,13 @@ namespace Mantid
*/
virtual void afterDataRangedDetermined(const int& m_minX, const int& m_maxX){};
void processParameters();
/// Holds a copy of the value of the parameters that are actually least-squared fitted.
std::vector<double> m_fittedParameter;
//std::vector<double> m_fittedParameter;
/// Number of parameters.
size_t nParams()const{return m_fittedParameter.size();}
size_t nParams()const{return m_function->nActive();}
/// Pointer to the fitting function
API::IFunction* m_function;
......
......@@ -8,6 +8,7 @@
#include <iomanip>
#include "MantidKernel/Exception.h"
#include "MantidAPI/TableRow.h"
#include "MantidAPI/CompositeFunction.h"
#include "MantidDataObjects/Workspace2D.h"
#include "MantidKernel/UnitFactory.h"
......@@ -191,6 +192,8 @@ namespace CurveFitting
"A value in, or on the high x boundary of, the last bin the fitting range\n"
"(default the highest value of x)" );
declareProperty("InputParameters","","The name of a TableWorkspace holding fit parameters" );
declareProperty("MaxIterations", 500, mustBePositive->clone(),
"Stop after this number of iterations if a good fit is not found" );
declareProperty("Output Status","", Direction::Output);
......@@ -210,6 +213,8 @@ namespace CurveFitting
void Fit::exec()
{
processParameters();
if (m_function == NULL)
throw std::runtime_error("Function has not been set.");
......@@ -224,7 +229,7 @@ namespace CurveFitting
boost::shared_ptr<gsl_matrix> M( gsl_matrix_alloc(nParams(),1) );
J.setJ(M.get());
// note nData set to zero (last argument) hence this should avoid further memory problems
functionDeriv(&(inTest.front()), &J, &xValuesTest, 0);
functionDeriv(NULL, &J, &xValuesTest, 0);
}
catch (Exception::NotImplementedError&)
{
......@@ -348,7 +353,8 @@ namespace CurveFitting
for (size_t i = 0; i < nParams(); i++)
{
gsl_vector_set(initFuncArg, i, m_fittedParameter[i]);
//gsl_vector_set(initFuncArg, i, m_fittedParameter[i]);
gsl_vector_set(initFuncArg, i, m_function->activeParameter(i));
}
......@@ -436,7 +442,7 @@ namespace CurveFitting
// put final converged fitting values back into m_fittedParameter
for (size_t i = 0; i < nParams(); i++)
m_fittedParameter[i] = gsl_vector_get(s->x,i);
m_function->setActiveParameter(i,gsl_vector_get(s->x,i));
}
else
{
......@@ -457,8 +463,9 @@ namespace CurveFitting
finalCostFuncVal = simplexMinimizer->fval / dof;
// put final converged fitting values back into m_fittedParameter
for (unsigned int i = 0; i < m_fittedParameter.size(); i++)
m_fittedParameter[i] = gsl_vector_get(simplexMinimizer->x,i);
for (unsigned int i = 0; i < nParams(); i++)
//m_fittedParameter[i] = gsl_vector_get(simplexMinimizer->x,i);
m_function->setActiveParameter(i, gsl_vector_get(simplexMinimizer->x,i));
}
// Output summary to log file
......@@ -468,8 +475,8 @@ namespace CurveFitting
g_log.information() << "Iteration = " << iter << "\n" <<
"Status = " << reportOfFit << "\n" <<
"Chi^2/DoF = " << finalCostFuncVal << "\n";
for (size_t i = 0; i < m_fittedParameter.size(); i++)
g_log.information() << m_function->nameOfActive(i) << " = " << m_fittedParameter[i] << " \n";
for (size_t i = 0; i < nParams(); i++)
g_log.information() << m_function->nameOfActive(i) << " = " << m_function->activeParameter(i) << " \n";
// also output summary to properties
......@@ -508,7 +515,7 @@ namespace CurveFitting
for(size_t i=0;i<nParams();i++)
{
Mantid::API::TableRow row = m_result->appendRow();
row << m_function->nameOfActive(i) << m_fittedParameter[i];
row << m_function->nameOfActive(i) << m_function->activeParameter(i);
}
setProperty("OutputParameters",m_result);
......@@ -540,7 +547,7 @@ namespace CurveFitting
double* lOut = new double[l_data.n]; // to capture output from call to function()
function(&m_fittedParameter[0], lOut, l_data.X, l_data.n);
function(NULL, lOut, l_data.X, l_data.n);
for(unsigned int i=0; i<l_data.n; i++)
{
......@@ -567,11 +574,6 @@ namespace CurveFitting
void Fit::setFunction(API::IFunction* fun)
{
m_function = fun;
m_fittedParameter.resize(fun->nActive());
for(int i=0;i<fun->nActive();i++)
{
m_fittedParameter[i] = fun->activeParameter(i);
}
}
/** Calculate the fitting function.
......@@ -584,7 +586,7 @@ namespace CurveFitting
void Fit::function(const double* in, double* out, const double* xValues, const int& nData)
{
m_function->updateActive(in);
if (in) m_function->updateActive(in);
m_function->function(out,xValues,nData);
}
......@@ -600,9 +602,65 @@ namespace CurveFitting
*/
void Fit::functionDeriv(const double* in, Jacobian* out, const double* xValues, const int& nData)
{
m_function->updateActive(in);
if (in) m_function->updateActive(in);
m_function->functionDeriv(out,xValues,nData);
}
void Fit::processParameters()
{
std::string input = getProperty("InputParameters");
if (input.empty()) return;
typedef Poco::StringTokenizer tokenizer;
tokenizer functions(input, ";", tokenizer::TOK_IGNORE_EMPTY | tokenizer::TOK_TRIM);
bool isComposite = functions.count() > 1;
API::IFunction* function;
if (isComposite)
{
function = API::FunctionFactory::Instance().createFunction("CompositeFunction");
setFunction(function);
}
for (tokenizer::Iterator ifun = functions.begin(); ifun != functions.end(); ++ifun)
{
tokenizer params(*ifun, ",", tokenizer::TOK_IGNORE_EMPTY | tokenizer::TOK_TRIM);
std::map<std::string,std::string> param;
for (tokenizer::Iterator par = params.begin(); par != params.end(); ++par)
{
tokenizer name_value(*par, "=", tokenizer::TOK_IGNORE_EMPTY | tokenizer::TOK_TRIM);
if (name_value.count() > 1)
{
std::string name = name_value[0];
//std::transform(name.begin(), name.end(), name.begin(), toupper);
param[name] = name_value[1];
}
}
std::string functionName = param["function"];
if (functionName.empty())
throw std::runtime_error("Function is not defined");
API::IFunction* fun = API::FunctionFactory::Instance().createFunction(functionName);
fun->init();
if (isComposite)
static_cast<API::CompositeFunction*>(function)->addFunction(fun);
else
setFunction(fun);
std::map<std::string,std::string>::const_iterator par = param.begin();
for(;par!=param.end();++par)
{
if (par->first != "function")
{
//fun->getParameter(par->first) = boost::lexical_cast<double>(par->second);
fun->getParameter(par->first) = atof(par->second.c_str());
}
}
}
}
} // namespace Algorithm
} // namespace Mantid
......@@ -12,6 +12,8 @@ namespace CurveFitting
using namespace Kernel;
using namespace API;
DECLARE_FUNCTION(Gaussian)
void Gaussian::init()
{
declareParameter("Height", 0.0);
......
......@@ -11,6 +11,8 @@ namespace CurveFitting
using namespace Kernel;
using namespace API;
DECLARE_FUNCTION(LinearBackground)
void LinearBackground::init()
{
declareParameter("A0", 0.0);
......
......@@ -10,9 +10,6 @@
#include "MantidDataObjects/TableWorkspace.h"
#include "MantidAPI/TableRow.h"
#include <iostream>
#include <fstream>
using namespace Mantid;
using namespace Mantid::API;
using namespace Mantid::DataObjects;
......@@ -163,7 +160,6 @@ public:
WS_type ws = mkWS(1,0,10,0.1);
addNoise(ws,0.1);
storeWS("mfun",ws);
plotWS("mfun.txt",ws);
Fit alg;
alg.initialize();
......@@ -322,21 +318,6 @@ private:
}
}
void plotWS(const std::string& fname,WS_type ws)
{
std::string fn = "C:\\Documents and Settings\\hqs74821\\Desktop\\tmp\\" + fname;
std::ofstream fil(fn.c_str());
char sep = '\t';
for(int i=0;i<ws->blocksize();i++)
{
fil << ws->readX(0)[i];
for(int j=0;j<ws->getNumberHistograms();j++)
fil << sep << ws->readY(j)[i] << sep << ws->readE(j)[i];
fil << '\n';
}
fil.close();
}
};
#endif /*COMPOSITEFUNCTIONTEST_H_*/
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment