Loading src/dmrg_magma.h +242 −1 Original line number Diff line number Diff line Loading @@ -20,6 +20,229 @@ #ifdef __cplusplus static void magmablas_Xgemm_vbatched_max_nocheck( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, double alpha, double const * const * dA_array, magma_int_t* ldda, double const * const * dB_array, magma_int_t* lddb, double beta, double **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_dgemm_vbatched_max_nocheck( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max_nocheck( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, float alpha, float const * const * dA_array, magma_int_t* ldda, float const * const * dB_array, magma_int_t* lddb, float beta, float **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_sgemm_vbatched_max_nocheck( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max_nocheck( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, magmaDoubleComplex alpha, magmaDoubleComplex const * const * dA_array, magma_int_t* ldda, magmaDoubleComplex const * const * dB_array, magma_int_t* lddb, magmaDoubleComplex beta, magmaDoubleComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_zgemm_vbatched_max_nocheck( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max_nocheck( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, magmaFloatComplex alpha, magmaFloatComplex const * const * dA_array, magma_int_t* ldda, magmaFloatComplex const * const * dB_array, magma_int_t* lddb, magmaFloatComplex beta, magmaFloatComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_cgemm_vbatched_max_nocheck( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, double alpha, double const * const * dA_array, magma_int_t* ldda, double const * const * dB_array, magma_int_t* lddb, double beta, double **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_dgemm_vbatched_max( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, float alpha, float const * const * dA_array, magma_int_t* ldda, float const * const * dB_array, magma_int_t* lddb, float beta, float **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_sgemm_vbatched_max( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, magmaDoubleComplex alpha, magmaDoubleComplex const * const * dA_array, magma_int_t* ldda, magmaDoubleComplex const * const * dB_array, magma_int_t* lddb, magmaDoubleComplex beta, magmaDoubleComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_zgemm_vbatched_max( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, magmaFloatComplex alpha, magmaFloatComplex const * const * dA_array, magma_int_t* ldda, magmaFloatComplex const * const * dB_array, magma_int_t* lddb, magmaFloatComplex beta, magmaFloatComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_cgemm_vbatched_max( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched( magma_trans_t transA, magma_trans_t transB, Loading @@ -29,7 +252,8 @@ void magmablas_Xgemm_vbatched( magmaFloatComplex const * const * dB_array, magma_int_t* lddb, magmaFloatComplex beta, magmaFloatComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_queue_t queue ) { magma_int_t batchCount, magma_queue_t queue ) { magmablas_cgemm_vbatched( transA, transB, Loading @@ -43,6 +267,11 @@ void magmablas_Xgemm_vbatched( ); } static void magmablas_Xgemm_vbatched( magma_trans_t transA, magma_trans_t transB, Loading @@ -66,6 +295,10 @@ void magmablas_Xgemm_vbatched( ); } static void magmablas_Xgemm_vbatched( magma_trans_t transA, magma_trans_t transB, Loading Loading @@ -326,6 +559,8 @@ void magma_Xsetvector( magma_int_t n, #define magma_Xgetvector magma_zgetvector #define magmablas_Xgemm_vbatched magmablas_zgemm_vbatched #define magmablas_Xgemm_vbatched_max_nocheck magmablas_zgemm_vbatched_max_nocheck #define magmablas_Xgemm_vbatched_max magmablas_zgemm_vbatched_max #elif defined(USE_COMPLEX_C) Loading @@ -336,6 +571,8 @@ void magma_Xsetvector( magma_int_t n, #define magma_Xgetvector magma_cgetvector #define magmablas_Xgemm_vbatched magmablas_cgemm_vbatched #define magmablas_Xgemm_vbatched_max_nocheck magmablas_cgemm_vbatched_max_nocheck #define magmablas_Xgemm_vbatched_max magmablas_cgemm_vbatched_max #elif defined(USE_FLOAT) #define magma_Xsetmatrix magma_ssetmatrix Loading @@ -345,6 +582,8 @@ void magma_Xsetvector( magma_int_t n, #define magma_Xgetvector magma_sgetvector #define magmablas_Xgemm_vbatched magmablas_sgemm_vbatched #define magmablas_Xgemm_vbatched_max_nocheck magmablas_sgemm_vbatched_max_nocheck #define magmablas_Xgemm_vbatched_max magmablas_sgemm_vbatched_max #else Loading @@ -355,6 +594,8 @@ void magma_Xsetvector( magma_int_t n, #define magma_Xgetvector magma_dgetvector #define magmablas_Xgemm_vbatched magmablas_dgemm_vbatched #define magmablas_Xgemm_vbatched_max_nocheck magmablas_dgemm_vbatched_max_nocheck #define magmablas_Xgemm_vbatched_max magmablas_dgemm_vbatched_max #endif Loading src/dmrg_vbatch.c +42 −0 Original line number Diff line number Diff line Loading @@ -3,6 +3,10 @@ #include "dmrg_vbatch.h" #include "dmrg_lapack.h" #ifndef MAX #define MAX(x,y) (((x) > (y)) ? (x) : (y)) #endif #define USE_MALLOC #ifdef USE_INTEL_MKL Loading Loading @@ -356,6 +360,43 @@ double elapsed_time = 0; #ifdef USE_MAX_NOCHECK { int max_m = 0; int max_n = 0; int max_k = 0; int i = 0; for(i=0; i < group_count; i++) { const int m = m_array[i]; const int n = n_array[i]; const int k = k_array[i]; max_m = MAX(m,max_m); max_n = MAX(n,max_n); max_k = MAX(k,max_k); }; if (idebug >= 1) { printf("max_m=%d, max_n=%d, max_k=%d\n", max_m, max_n, max_k ); }; magmablas_Xgemm_vbatched_max_nocheck( transA, transB, m_vbatch, n_vbatch, k_vbatch, alpha, (FpType const * const *) a_vbatch, lda_vbatch, (FpType const * const *) b_vbatch, ldb_vbatch, beta, c_vbatch, ldc_vbatch, batch_size, max_m, max_n, max_k, queue ); } #else magmablas_Xgemm_vbatched( transA, transB, m_vbatch, n_vbatch, k_vbatch, Loading @@ -366,6 +407,7 @@ double elapsed_time = 0; c_vbatch, ldc_vbatch, batch_size, queue ); #endif }; Loading Loading
src/dmrg_magma.h +242 −1 Original line number Diff line number Diff line Loading @@ -20,6 +20,229 @@ #ifdef __cplusplus static void magmablas_Xgemm_vbatched_max_nocheck( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, double alpha, double const * const * dA_array, magma_int_t* ldda, double const * const * dB_array, magma_int_t* lddb, double beta, double **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_dgemm_vbatched_max_nocheck( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max_nocheck( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, float alpha, float const * const * dA_array, magma_int_t* ldda, float const * const * dB_array, magma_int_t* lddb, float beta, float **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_sgemm_vbatched_max_nocheck( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max_nocheck( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, magmaDoubleComplex alpha, magmaDoubleComplex const * const * dA_array, magma_int_t* ldda, magmaDoubleComplex const * const * dB_array, magma_int_t* lddb, magmaDoubleComplex beta, magmaDoubleComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_zgemm_vbatched_max_nocheck( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max_nocheck( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, magmaFloatComplex alpha, magmaFloatComplex const * const * dA_array, magma_int_t* ldda, magmaFloatComplex const * const * dB_array, magma_int_t* lddb, magmaFloatComplex beta, magmaFloatComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_cgemm_vbatched_max_nocheck( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, double alpha, double const * const * dA_array, magma_int_t* ldda, double const * const * dB_array, magma_int_t* lddb, double beta, double **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_dgemm_vbatched_max( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, float alpha, float const * const * dA_array, magma_int_t* ldda, float const * const * dB_array, magma_int_t* lddb, float beta, float **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_sgemm_vbatched_max( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, magmaDoubleComplex alpha, magmaDoubleComplex const * const * dA_array, magma_int_t* ldda, magmaDoubleComplex const * const * dB_array, magma_int_t* lddb, magmaDoubleComplex beta, magmaDoubleComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_zgemm_vbatched_max( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched_max( magma_trans_t transA, magma_trans_t transB, magma_int_t* m, magma_int_t* n, magma_int_t* k, magmaFloatComplex alpha, magmaFloatComplex const * const * dA_array, magma_int_t* ldda, magmaFloatComplex const * const * dB_array, magma_int_t* lddb, magmaFloatComplex beta, magmaFloatComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, magma_queue_t queue ) { magmablas_cgemm_vbatched_max( transA, transB, m, n, k, alpha, dA_array, ldda, dB_array, lddb, beta, dC_array, lddc, batchCount, max_m, max_n, max_k, queue ); } static void magmablas_Xgemm_vbatched( magma_trans_t transA, magma_trans_t transB, Loading @@ -29,7 +252,8 @@ void magmablas_Xgemm_vbatched( magmaFloatComplex const * const * dB_array, magma_int_t* lddb, magmaFloatComplex beta, magmaFloatComplex **dC_array, magma_int_t* lddc, magma_int_t batchCount, magma_queue_t queue ) { magma_int_t batchCount, magma_queue_t queue ) { magmablas_cgemm_vbatched( transA, transB, Loading @@ -43,6 +267,11 @@ void magmablas_Xgemm_vbatched( ); } static void magmablas_Xgemm_vbatched( magma_trans_t transA, magma_trans_t transB, Loading @@ -66,6 +295,10 @@ void magmablas_Xgemm_vbatched( ); } static void magmablas_Xgemm_vbatched( magma_trans_t transA, magma_trans_t transB, Loading Loading @@ -326,6 +559,8 @@ void magma_Xsetvector( magma_int_t n, #define magma_Xgetvector magma_zgetvector #define magmablas_Xgemm_vbatched magmablas_zgemm_vbatched #define magmablas_Xgemm_vbatched_max_nocheck magmablas_zgemm_vbatched_max_nocheck #define magmablas_Xgemm_vbatched_max magmablas_zgemm_vbatched_max #elif defined(USE_COMPLEX_C) Loading @@ -336,6 +571,8 @@ void magma_Xsetvector( magma_int_t n, #define magma_Xgetvector magma_cgetvector #define magmablas_Xgemm_vbatched magmablas_cgemm_vbatched #define magmablas_Xgemm_vbatched_max_nocheck magmablas_cgemm_vbatched_max_nocheck #define magmablas_Xgemm_vbatched_max magmablas_cgemm_vbatched_max #elif defined(USE_FLOAT) #define magma_Xsetmatrix magma_ssetmatrix Loading @@ -345,6 +582,8 @@ void magma_Xsetvector( magma_int_t n, #define magma_Xgetvector magma_sgetvector #define magmablas_Xgemm_vbatched magmablas_sgemm_vbatched #define magmablas_Xgemm_vbatched_max_nocheck magmablas_sgemm_vbatched_max_nocheck #define magmablas_Xgemm_vbatched_max magmablas_sgemm_vbatched_max #else Loading @@ -355,6 +594,8 @@ void magma_Xsetvector( magma_int_t n, #define magma_Xgetvector magma_dgetvector #define magmablas_Xgemm_vbatched magmablas_dgemm_vbatched #define magmablas_Xgemm_vbatched_max_nocheck magmablas_dgemm_vbatched_max_nocheck #define magmablas_Xgemm_vbatched_max magmablas_dgemm_vbatched_max #endif Loading
src/dmrg_vbatch.c +42 −0 Original line number Diff line number Diff line Loading @@ -3,6 +3,10 @@ #include "dmrg_vbatch.h" #include "dmrg_lapack.h" #ifndef MAX #define MAX(x,y) (((x) > (y)) ? (x) : (y)) #endif #define USE_MALLOC #ifdef USE_INTEL_MKL Loading Loading @@ -356,6 +360,43 @@ double elapsed_time = 0; #ifdef USE_MAX_NOCHECK { int max_m = 0; int max_n = 0; int max_k = 0; int i = 0; for(i=0; i < group_count; i++) { const int m = m_array[i]; const int n = n_array[i]; const int k = k_array[i]; max_m = MAX(m,max_m); max_n = MAX(n,max_n); max_k = MAX(k,max_k); }; if (idebug >= 1) { printf("max_m=%d, max_n=%d, max_k=%d\n", max_m, max_n, max_k ); }; magmablas_Xgemm_vbatched_max_nocheck( transA, transB, m_vbatch, n_vbatch, k_vbatch, alpha, (FpType const * const *) a_vbatch, lda_vbatch, (FpType const * const *) b_vbatch, ldb_vbatch, beta, c_vbatch, ldc_vbatch, batch_size, max_m, max_n, max_k, queue ); } #else magmablas_Xgemm_vbatched( transA, transB, m_vbatch, n_vbatch, k_vbatch, Loading @@ -366,6 +407,7 @@ double elapsed_time = 0; c_vbatch, ldc_vbatch, batch_size, queue ); #endif }; Loading