Commit 9ca4fb75 authored by Alvarez, Gonzalo's avatar Alvarez, Gonzalo
Browse files

inverse for float and complex<float>

parent 53457b08
Loading
Loading
Loading
Loading
+47 −0
Original line number Diff line number Diff line
@@ -33,8 +33,12 @@ extern "C" void zgesv_(int*,int*,std::complex<double>*,int*,int*,std::complex<do
//MSS
extern "C" int  dgetrf_(int*, int*, double*, int*, int*, int*);

extern "C" int  sgetrf_(int*, int*, float*, int*, int*, int*);

extern "C" int  zgetrf_(int*, int*, std::complex<double>*, int*, int*, int*);

extern "C" int  cgetrf_(int*, int*, std::complex<float>*, int*, int*, int*);

extern "C" int  zgetri_(int*,
                        std::complex<double>*,
                        int*,
@@ -43,8 +47,18 @@ extern "C" int zgetri_(int*,
                        int*,
                        int*);

extern "C" int  cgetri_(int*,
                        std::complex<float>*,
                        int*,
                        int*,
                        std::complex<float>*,
                        int*,
                        int*);

extern "C" int  dgetri_(int*, double*, int*, int*,  double*, int*, int*);

extern "C" int  sgetri_(int*, float*, int*, int*,  float*, int*, int*);

extern "C" int  dgesdd_(char* jobz,
                        int* m,
                        int* n,
@@ -325,6 +339,16 @@ inline void GETRF(int ma,int na,std::complex<double>* a,int lda,int* pivot,int&
	zgetrf_(&ma,&na,a,&lda,pivot,&info);
}

inline void GETRF(int ma, int na, float* a, int lda, int* pivot, int& info)
{
	sgetrf_(&ma,&na,a,&lda,pivot,&info);
}

inline void GETRF(int ma,int na,std::complex<float>* a,int lda,int* pivot,int& info)
{
	cgetrf_(&ma,&na,a,&lda,pivot,&info);
}

inline void GETRI(int na,
                  double* a,
                  int lda,
@@ -347,6 +371,29 @@ inline void GETRI(int na,
	zgetri_(&na,a,&lda,pivot,work,&lwork,&info);
}

inline void GETRI(int na,
                  float* a,
                  int lda,
                  int* pivot,
                  float* work,
                  int lwork,
                  int& info)
{
	sgetri_(&na,a,&lda,pivot,work,&lwork,&info);
}

inline void GETRI(int na,
                  std::complex<float>* a,
                  int lda,
                  int* pivot,
                  std::complex<float>* work,
                  int lwork,
                  int& info)
{
	cgetri_(&na,a,&lda,pivot,work,&lwork,&info);
}


inline void GESDD(char* jobz,
                  int* m,
                  int* n,
+15 −11
Original line number Diff line number Diff line
@@ -210,7 +210,9 @@ void geev(char jobvl,
#endif
}

void inverse(Matrix<std::complex<double> > &m)
template<typename T>
typename std::enable_if<Loki::TypeTraits<T>::isArith, void>::type
inverse(Matrix<std::complex<T> > &m)
{
#ifdef NO_LAPACK
	throw RuntimeError("inverse: NO LAPACK!\n");
@@ -218,19 +220,21 @@ void inverse(Matrix<std::complex<double> > &m)
	int n = m.rows();
	int info = 0;
	Vector<int>::Type ipiv(n,0);
	psimag::LAPACK::zgetrf_(&n,&n,&(m(0,0)),&n,&(ipiv[0]),&info);
	psimag::LAPACK::GETRF(&n,&n,&(m(0,0)),&n,&(ipiv[0]),&info);
	int lwork = -1;
	Vector<std::complex<double> >::Type work(2);
	psimag::LAPACK::zgetri_(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
	typename Vector<std::complex<T> >::Type work(2);
	psimag::LAPACK::GETRI(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
	lwork = static_cast<int>(PsimagLite::real(work[0]));
	work.resize(lwork+2);
	psimag::LAPACK::zgetri_(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
	String s = "zgetri_ failed\n";
	psimag::LAPACK::GETRI(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
	String s = "[cz]getri_ failed\n";
	if (info!=0) throw RuntimeError(s.c_str());
#endif
}

void inverse(Matrix<double> &m)
template<typename T>
typename std::enable_if<Loki::TypeTraits<T>::isArith, void>::type
inverse(Matrix<T> &m)
{
#ifdef NO_LAPACK
	throw RuntimeError("inverse: NO LAPACK!\n");
@@ -238,13 +242,13 @@ void inverse(Matrix<double> &m)
	int n = m.rows();
	int info = 0;
	Vector<int>::Type ipiv(n,0);
	psimag::LAPACK::dgetrf_(&n,&n,&(m(0,0)),&n,&(ipiv[0]),&info);
	psimag::LAPACK::GETRF(&n,&n,&(m(0,0)),&n,&(ipiv[0]),&info);
	int lwork = -1;
	Vector<double>::Type work(2);
	psimag::LAPACK::dgetri_(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
	typename Vector<T>::Type work(2);
	psimag::LAPACK::GETRI(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
	lwork = static_cast<int>(work[0]);
	work.resize(lwork+2);
	psimag::LAPACK::dgetri_(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
	psimag::LAPACK::GETRI(&n,&(m(0,0)),&n,&(ipiv[0]),&(work[0]),&lwork,&info);
	String s = "dgetri_ failed\n";
	if (info!=0) throw RuntimeError(s.c_str());
#endif
+4 −0
Original line number Diff line number Diff line
@@ -531,6 +531,10 @@ void diag(Matrix<std::complex<float> > &m,Vector<float> ::Type& eigs,char option
void inverse(Matrix<std::complex<double> >& m);

void inverse(Matrix<double>& m);

void inverse(Matrix<std::complex<float> >& m);

void inverse(Matrix<float>& m);
// end in Matrix.cpp

template<typename T>