Skip to content
Snippets Groups Projects
Commit 2c9ea775 authored by Zhou, Wenduo's avatar Zhou, Wenduo
Browse files

Modifed to be inherited. Refs #7476.

Modify CostFuncLeastSquare such that CostFuncRwp can inherit from it.
parent 651d041f
No related merge requests found
......@@ -94,7 +94,14 @@ protected:
API::FunctionValues_sptr values,
bool evalFunction = true, bool evalDeriv = true, bool evalHessian = true) const;
private:
/// Get weight (1/sigma)
virtual double getWeight(API::FunctionValues_sptr values, size_t i, double sqrtW=1.0) const;
/// Calcualte sqrt(W). Final cost function = sum_i [ (obs_i - cal_i) / (sigma * sqrt(W))]**2
virtual double calSqrtW(API::FunctionValues_sptr values) const;
/// Flag to include constraint in cost function value
const bool m_includePenalty;
mutable double m_value;
mutable GSLVector m_der;
......
......@@ -25,7 +25,10 @@ DECLARE_COSTFUNCTION(CostFuncLeastSquares,Least squares)
/**
* Constructor
*/
CostFuncLeastSquares::CostFuncLeastSquares() : CostFuncFitting(),m_value(0),m_pushed(false),
CostFuncLeastSquares::CostFuncLeastSquares() : CostFuncFitting(),
m_includePenalty(true),
m_value(0),
m_pushed(false),
m_log(Kernel::Logger::get("CostFuncLeastSquares")) {}
/** Calculate value of cost function
......@@ -56,13 +59,16 @@ double CostFuncLeastSquares::val() const
}
// add penalty
for(size_t i=0;i<m_function->nParams();++i)
if (m_includePenalty)
{
if ( !m_function->isActive(i) ) continue;
API::IConstraint* c = m_function->getConstraint(i);
if (c)
for(size_t i=0;i<m_function->nParams();++i)
{
m_value += c->check();
if ( !m_function->isActive(i) ) continue;
API::IConstraint* c = m_function->getConstraint(i);
if (c)
{
m_value += c->check();
}
}
}
......@@ -82,9 +88,12 @@ void CostFuncLeastSquares::addVal(API::FunctionDomain_sptr domain, API::Function
double retVal = 0.0;
double sqrtw = calSqrtW(values);
for (size_t i = 0; i < ny; i++)
{
double val = ( values->getCalculated(i) - values->getFitData(i) ) * values->getFitWeight(i);
// double val = ( values->getCalculated(i) - values->getFitData(i) ) * values->getFitWeight(i);
double val = ( values->getCalculated(i) - values->getFitData(i) ) * getWeight(values, i, sqrtw);
retVal += val * val;
}
......@@ -273,6 +282,7 @@ void CostFuncLeastSquares::addValDerivHessian(
std::cerr << std::endl;
}
}
double sqrtw = calSqrtW(values);
for(size_t ip = 0; ip < np; ++ip)
{
if ( !function->isActive(ip) ) continue;
......@@ -281,7 +291,8 @@ void CostFuncLeastSquares::addValDerivHessian(
{
double calc = values->getCalculated(i);
double obs = values->getFitData(i);
double w = values->getFitWeight(i);
// double w = values->getFitWeight(i);
double w = getWeight(values, i, sqrtw);
double y = ( calc - obs ) * w;
d += y * jacobian.get(i,ip) * w;
if (iActiveP == 0 && evalFunction)
......@@ -319,7 +330,8 @@ void CostFuncLeastSquares::addValDerivHessian(
double d = 0.0;
for(size_t k = 0; k < ny; ++k) // over fitting data
{
double w = values->getFitWeight(k);
// double w = values->getFitWeight(k);
double w = getWeight(values, k, sqrtw);
d += jacobian.get(k,i) * jacobian.get(k,j) * w * w;
}
PARALLEL_CRITICAL(hessian_set)
......@@ -497,5 +509,23 @@ void CostFuncLeastSquares::calActiveCovarianceMatrix(GSLMatrix& covar, double ep
}
//----------------------------------------------------------------------------------------------
/** Get weight of data point i(1/sigma)
*/
double CostFuncLeastSquares::getWeight(API::FunctionValues_sptr values, size_t i, double sqrtW) const
{
return (values->getFitWeight(i) / sqrtW);
}
//----------------------------------------------------------------------------------------------
/** Get square root of normalization weight (W)
*/
double CostFuncLeastSquares::calSqrtW(API::FunctionValues_sptr values) const
{
UNUSED_ARG(values);
return 1.0;
}
} // namespace CurveFitting
} // namespace Mantid
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment