Commit a290283b authored by D'azevedo, Ed's avatar D'azevedo, Ed
Browse files

add max_nocheck option to magmablas vbatched call

parent 713db885
Loading
Loading
Loading
Loading
+242 −1
Original line number Diff line number Diff line
@@ -20,6 +20,229 @@

#ifdef __cplusplus

static 
void magmablas_Xgemm_vbatched_max_nocheck(
    magma_trans_t transA, magma_trans_t transB, 
    magma_int_t* m, magma_int_t* n, magma_int_t* k,
    double alpha,
    double const * const * dA_array, magma_int_t* ldda,
    double const * const * dB_array, magma_int_t* lddb,
    double beta,
    double **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, 
    magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, 
    magma_queue_t queue ) {


 magmablas_dgemm_vbatched_max_nocheck(
    transA, transB, 
    m,  n,  k,
    alpha,
    dA_array, ldda,
    dB_array, lddb,
    beta,
    dC_array, lddc, 
    batchCount, 
    max_m, max_n, max_k, 
    queue );
}


static 
void magmablas_Xgemm_vbatched_max_nocheck(
    magma_trans_t transA, magma_trans_t transB, 
    magma_int_t* m, magma_int_t* n, magma_int_t* k,
    float alpha,
    float const * const * dA_array, magma_int_t* ldda,
    float const * const * dB_array, magma_int_t* lddb,
    float beta,
    float **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, 
    magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, 
    magma_queue_t queue ) {


 magmablas_sgemm_vbatched_max_nocheck(
    transA, transB, 
    m,  n,  k,
    alpha,
    dA_array, ldda,
    dB_array, lddb,
    beta,
    dC_array, lddc, 
    batchCount, 
    max_m, max_n, max_k, 
    queue );
}


static 
void magmablas_Xgemm_vbatched_max_nocheck(
    magma_trans_t transA, magma_trans_t transB, 
    magma_int_t* m, magma_int_t* n, magma_int_t* k,
    magmaDoubleComplex alpha,
    magmaDoubleComplex const * const * dA_array, magma_int_t* ldda,
    magmaDoubleComplex const * const * dB_array, magma_int_t* lddb,
    magmaDoubleComplex beta,
    magmaDoubleComplex **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, 
    magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, 
    magma_queue_t queue ) {


 magmablas_zgemm_vbatched_max_nocheck(
    transA, transB, 
    m,  n,  k,
    alpha,
    dA_array, ldda,
    dB_array, lddb,
    beta,
    dC_array, lddc, 
    batchCount, 
    max_m, max_n, max_k, 
    queue );
}


static 
void magmablas_Xgemm_vbatched_max_nocheck(
    magma_trans_t transA, magma_trans_t transB, 
    magma_int_t* m, magma_int_t* n, magma_int_t* k,
    magmaFloatComplex alpha,
    magmaFloatComplex const * const * dA_array, magma_int_t* ldda,
    magmaFloatComplex const * const * dB_array, magma_int_t* lddb,
    magmaFloatComplex beta,
    magmaFloatComplex **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, 
    magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, 
    magma_queue_t queue ) {


 magmablas_cgemm_vbatched_max_nocheck(
    transA, transB, 
    m,  n,  k,
    alpha,
    dA_array, ldda,
    dB_array, lddb,
    beta,
    dC_array, lddc, 
    batchCount, 
    max_m, max_n, max_k, 
    queue );
}


static 
void magmablas_Xgemm_vbatched_max(
    magma_trans_t transA, magma_trans_t transB, 
    magma_int_t* m, magma_int_t* n, magma_int_t* k,
    double alpha,
    double const * const * dA_array, magma_int_t* ldda,
    double const * const * dB_array, magma_int_t* lddb,
    double beta,
    double **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, 
    magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, 
    magma_queue_t queue ) {


 magmablas_dgemm_vbatched_max(
    transA, transB, 
    m,  n,  k,
    alpha,
    dA_array, ldda,
    dB_array, lddb,
    beta,
    dC_array, lddc, 
    batchCount, 
    max_m, max_n, max_k, 
    queue );
}


static 
void magmablas_Xgemm_vbatched_max(
    magma_trans_t transA, magma_trans_t transB, 
    magma_int_t* m, magma_int_t* n, magma_int_t* k,
    float alpha,
    float const * const * dA_array, magma_int_t* ldda,
    float const * const * dB_array, magma_int_t* lddb,
    float beta,
    float **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, 
    magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, 
    magma_queue_t queue ) {


 magmablas_sgemm_vbatched_max(
    transA, transB, 
    m,  n,  k,
    alpha,
    dA_array, ldda,
    dB_array, lddb,
    beta,
    dC_array, lddc, 
    batchCount, 
    max_m, max_n, max_k, 
    queue );
}


static 
void magmablas_Xgemm_vbatched_max(
    magma_trans_t transA, magma_trans_t transB, 
    magma_int_t* m, magma_int_t* n, magma_int_t* k,
    magmaDoubleComplex alpha,
    magmaDoubleComplex const * const * dA_array, magma_int_t* ldda,
    magmaDoubleComplex const * const * dB_array, magma_int_t* lddb,
    magmaDoubleComplex beta,
    magmaDoubleComplex **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, 
    magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, 
    magma_queue_t queue ) {


 magmablas_zgemm_vbatched_max(
    transA, transB, 
    m,  n,  k,
    alpha,
    dA_array, ldda,
    dB_array, lddb,
    beta,
    dC_array, lddc, 
    batchCount, 
    max_m, max_n, max_k, 
    queue );
}


static 
void magmablas_Xgemm_vbatched_max(
    magma_trans_t transA, magma_trans_t transB, 
    magma_int_t* m, magma_int_t* n, magma_int_t* k,
    magmaFloatComplex alpha,
    magmaFloatComplex const * const * dA_array, magma_int_t* ldda,
    magmaFloatComplex const * const * dB_array, magma_int_t* lddb,
    magmaFloatComplex beta,
    magmaFloatComplex **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, 
    magma_int_t max_m, magma_int_t max_n, magma_int_t max_k, 
    magma_queue_t queue ) {


 magmablas_cgemm_vbatched_max(
    transA, transB, 
    m,  n,  k,
    alpha,
    dA_array, ldda,
    dB_array, lddb,
    beta,
    dC_array, lddc, 
    batchCount, 
    max_m, max_n, max_k, 
    queue );
}

static
void magmablas_Xgemm_vbatched(
    magma_trans_t transA, magma_trans_t transB, 
@@ -29,7 +252,8 @@ void magmablas_Xgemm_vbatched(
    magmaFloatComplex const * const * dB_array, magma_int_t* lddb,
    magmaFloatComplex beta,
    magmaFloatComplex **dC_array, magma_int_t* lddc, 
    magma_int_t batchCount, magma_queue_t queue ) {
    magma_int_t batchCount, 
    magma_queue_t queue ) {

        magmablas_cgemm_vbatched(
            transA, transB,
@@ -43,6 +267,11 @@ void magmablas_Xgemm_vbatched(
            );
    }






static
void magmablas_Xgemm_vbatched(
    magma_trans_t transA, magma_trans_t transB, 
@@ -66,6 +295,10 @@ void magmablas_Xgemm_vbatched(
            );
    }





static
void magmablas_Xgemm_vbatched(
    magma_trans_t transA, magma_trans_t transB, 
@@ -326,6 +559,8 @@ void magma_Xsetvector( magma_int_t n,
    #define magma_Xgetvector magma_zgetvector
  
    #define magmablas_Xgemm_vbatched magmablas_zgemm_vbatched
    #define magmablas_Xgemm_vbatched_max_nocheck magmablas_zgemm_vbatched_max_nocheck
    #define magmablas_Xgemm_vbatched_max magmablas_zgemm_vbatched_max
  
  #elif defined(USE_COMPLEX_C)

@@ -336,6 +571,8 @@ void magma_Xsetvector( magma_int_t n,
    #define magma_Xgetvector magma_cgetvector
  
    #define magmablas_Xgemm_vbatched magmablas_cgemm_vbatched
    #define magmablas_Xgemm_vbatched_max_nocheck magmablas_cgemm_vbatched_max_nocheck
    #define magmablas_Xgemm_vbatched_max magmablas_cgemm_vbatched_max

  #elif defined(USE_FLOAT)
    #define magma_Xsetmatrix magma_ssetmatrix
@@ -345,6 +582,8 @@ void magma_Xsetvector( magma_int_t n,
    #define magma_Xgetvector magma_sgetvector
  
    #define magmablas_Xgemm_vbatched magmablas_sgemm_vbatched
    #define magmablas_Xgemm_vbatched_max_nocheck magmablas_sgemm_vbatched_max_nocheck
    #define magmablas_Xgemm_vbatched_max magmablas_sgemm_vbatched_max

  #else
  
@@ -355,6 +594,8 @@ void magma_Xsetvector( magma_int_t n,
    #define magma_Xgetvector magma_dgetvector
  
    #define magmablas_Xgemm_vbatched magmablas_dgemm_vbatched
    #define magmablas_Xgemm_vbatched_max_nocheck magmablas_dgemm_vbatched_max_nocheck
    #define magmablas_Xgemm_vbatched_max magmablas_dgemm_vbatched_max
  
  #endif

+42 −0
Original line number Diff line number Diff line
@@ -3,6 +3,10 @@
#include "dmrg_vbatch.h"
#include "dmrg_lapack.h"

#ifndef MAX
#define MAX(x,y)  (((x) > (y)) ? (x) : (y))
#endif

#define USE_MALLOC

#ifdef USE_INTEL_MKL
@@ -356,6 +360,43 @@ double elapsed_time = 0;


     
#ifdef USE_MAX_NOCHECK
   {
    int max_m = 0;
    int max_n = 0;
    int max_k = 0;

    int i = 0;
    for(i=0; i < group_count; i++) {
            const int m = m_array[i];
            const int n = n_array[i];
            const int k = k_array[i];

            max_m = MAX(m,max_m);
            max_n = MAX(n,max_n);
            max_k = MAX(k,max_k);

    };
    if (idebug >= 1) {
      printf("max_m=%d, max_n=%d, max_k=%d\n",
              max_m,    max_n,    max_k );
    };


    magmablas_Xgemm_vbatched_max_nocheck( transA, transB,
           m_vbatch, n_vbatch, k_vbatch,
           alpha,
           (FpType const * const *) a_vbatch, lda_vbatch, 
           (FpType const * const *) b_vbatch, ldb_vbatch,
           beta,
           c_vbatch, ldc_vbatch,
           batch_size, 
           max_m, max_n, max_k,
           queue );


   }
#else

   magmablas_Xgemm_vbatched( transA, transB,
           m_vbatch, n_vbatch, k_vbatch,
@@ -366,6 +407,7 @@ double elapsed_time = 0;
           c_vbatch, ldc_vbatch,
           batch_size, queue );

#endif


  };