Commit 9ca4fb75 authored by Alvarez, Gonzalo's avatar Alvarez, Gonzalo
Browse files

inverse for float and complex<float>

parent 53457b08
......@@ -33,8 +33,12 @@ extern "C" void zgesv_(int*,int*,std::complex<double>*,int*,int*,std::complex<do
//MSS
extern "C" int dgetrf_(int*, int*, double*, int*, int*, int*);
extern "C" int sgetrf_(int*, int*, float*, int*, int*, int*);
extern "C" int zgetrf_(int*, int*, std::complex<double>*, int*, int*, int*);
extern "C" int cgetrf_(int*, int*, std::complex<float>*, int*, int*, int*);
extern "C" int zgetri_(int*,
std::complex<double>*,
int*,
......@@ -43,8 +47,18 @@ extern "C" int zgetri_(int*,
int*,
int*);
extern "C" int cgetri_(int*,
std::complex<float>*,
int*,
int*,
std::complex<float>*,
int*,
int*);
extern "C" int dgetri_(int*, double*, int*, int*, double*, int*, int*);
extern "C" int sgetri_(int*, float*, int*, int*, float*, int*, int*);
extern "C" int dgesdd_(char* jobz,
int* m,
int* n,
......@@ -325,6 +339,16 @@ inline void GETRF(int ma,int na,std::complex<double>* a,int lda,int* pivot,int&
zgetrf_(&ma,&na,a,&lda,pivot,&info);
}
inline void GETRF(int ma, int na, float* a, int lda, int* pivot, int& info)
{
sgetrf_(&ma,&na,a,&lda,pivot,&info);
}
inline void GETRF(int ma,int na,std::complex<float>* a,int lda,int* pivot,int& info)
{
cgetrf_(&ma,&na,a,&lda,pivot,&info);
}
inline void GETRI(int na,
double* a,
int lda,
......@@ -347,6 +371,29 @@ inline void GETRI(int na,
zgetri_(&na,a,&lda,pivot,work,&lwork,&info);
}
inline void GETRI(int na,
float* a,
int lda,
int* pivot,
float* work,
int lwork,
int& info)
{
sgetri_(&na,a,&lda,pivot,work,&lwork,&info);
}
inline void GETRI(int na,
std::complex<float>* a,
int lda,
int* pivot,
std::complex<float>* work,
int lwork,
int& info)
{
cgetri_(&na,a,&lda,pivot,work,&lwork,&info);
}
inline void GESDD(char* jobz,
int* m,
int* n,
......
......@@ -210,7 +210,9 @@ void geev(char jobvl,
#endif
}
void inverse(Matrix<std::complex<double> > &m)
template<typename T>
typename std::enable_if<Loki::TypeTraits<T>::isArith, void>::type
inverse(Matrix<std::complex<T> > &m)
{
#ifdef NO_LAPACK
throw RuntimeError("inverse: NO LAPACK!\n");
......@@ -218,19 +220,21 @@ void inverse(Matrix<std::complex<double> > &m)
int n = m.rows();
int info = 0;
Vector<int>::Type ipiv(n,0);
psimag::LAPACK::zgetrf_(&n,&n,&(m(0,0)),&n,&(ipiv[0]),&info);
psimag::LAPACK::GETRF(&n,&n,&(m(0,0)),&n,&(ipiv[0]),&info);
int lwork = -1;
Vector<std::complex<double> >::Type work(2);
psimag::LAPACK::zgetri_(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
typename Vector<std::complex<T> >::Type work(2);
psimag::LAPACK::GETRI(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
lwork = static_cast<int>(PsimagLite::real(work[0]));
work.resize(lwork+2);
psimag::LAPACK::zgetri_(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
String s = "zgetri_ failed\n";
psimag::LAPACK::GETRI(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
String s = "[cz]getri_ failed\n";
if (info!=0) throw RuntimeError(s.c_str());
#endif
}
void inverse(Matrix<double> &m)
template<typename T>
typename std::enable_if<Loki::TypeTraits<T>::isArith, void>::type
inverse(Matrix<T> &m)
{
#ifdef NO_LAPACK
throw RuntimeError("inverse: NO LAPACK!\n");
......@@ -238,13 +242,13 @@ void inverse(Matrix<double> &m)
int n = m.rows();
int info = 0;
Vector<int>::Type ipiv(n,0);
psimag::LAPACK::dgetrf_(&n,&n,&(m(0,0)),&n,&(ipiv[0]),&info);
psimag::LAPACK::GETRF(&n,&n,&(m(0,0)),&n,&(ipiv[0]),&info);
int lwork = -1;
Vector<double>::Type work(2);
psimag::LAPACK::dgetri_(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
typename Vector<T>::Type work(2);
psimag::LAPACK::GETRI(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
lwork = static_cast<int>(work[0]);
work.resize(lwork+2);
psimag::LAPACK::dgetri_(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
psimag::LAPACK::GETRI(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
String s = "dgetri_ failed\n";
if (info!=0) throw RuntimeError(s.c_str());
#endif
......
......@@ -531,6 +531,10 @@ void diag(Matrix<std::complex<float> > &m,Vector<float> ::Type& eigs,char option
void inverse(Matrix<std::complex<double> >& m);
void inverse(Matrix<double>& m);
void inverse(Matrix<std::complex<float> >& m);
void inverse(Matrix<float>& m);
// end in Matrix.cpp
template<typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment