Commit 30bb897e authored by gbalduzz's avatar gbalduzz
Browse files

Better CPU and GPU code unification with stream wrappers.

Better move assignment and constructor for stream and queue.
parent a779f570
Loading
Loading
Loading
Loading
+15 −15
Original line number Diff line number Diff line
@@ -124,7 +124,7 @@ inline magma_trans_t toMagmaTrans(const char x) {
inline void magmablas_gemm_vbatched(const char transa, const char transb, int* m, int* n, int* k,
                                    const float alpha, const float* const* a, int* lda,
                                    const float* const* b, int* ldb, const float beta, float** c,
                                    int* ldc, const int batch_count, magma_queue_t& queue) {
                                    int* ldc, const int batch_count, magma_queue_t queue) {
  magmablas_sgemm_vbatched(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k, alpha, a, lda, b,
                           ldb, beta, c, ldc, batch_count, queue);
  checkErrorsCudaDebug();
@@ -132,7 +132,7 @@ inline void magmablas_gemm_vbatched(const char transa, const char transb, int* m
inline void magmablas_gemm_vbatched(const char transa, const char transb, int* m, int* n, int* k,
                                    const double alpha, const double* const* a, int* lda,
                                    const double* const* b, int* ldb, const double beta, double** c,
                                    int* ldc, const int batch_count, const magma_queue_t& queue) {
                                    int* ldc, const int batch_count, const magma_queue_t queue) {
  magmablas_dgemm_vbatched(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k, alpha, a, lda, b,
                           ldb, beta, c, ldc, batch_count, queue);
  checkErrorsCudaDebug();
@@ -142,7 +142,7 @@ inline void magmablas_gemm_vbatched(const char transa, const char transb, int* m
                                    const std::complex<float>* const* a, int* lda,
                                    const std::complex<float>* const* b, int* ldb,
                                    const std::complex<float> beta, std::complex<float>** c,
                                    int* ldc, const int batch_count, const magma_queue_t& queue) {
                                    int* ldc, const int batch_count, const magma_queue_t queue) {
  using util::castCudaComplex;
  magmablas_cgemm_vbatched(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k,
                           *castCudaComplex(alpha), castCudaComplex(a), lda, castCudaComplex(b),
@@ -154,7 +154,7 @@ inline void magmablas_gemm_vbatched(const char transa, const char transb, int* m
                                    const std::complex<double>* const* a, int* lda,
                                    const std::complex<double>* const* b, int* ldb,
                                    const std::complex<double> beta, std::complex<double>** c,
                                    int* ldc, const int batch_count, const magma_queue_t& queue) {
                                    int* ldc, const int batch_count, const magma_queue_t queue) {
  using util::castCudaComplex;
  magmablas_zgemm_vbatched(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k,
                           *castCudaComplex(alpha), castCudaComplex(a), lda, castCudaComplex(b),
@@ -168,7 +168,7 @@ inline void magmablas_gemm_vbatched_max_nocheck(const char transa, const char tr
                                                const float* const* b, int* ldb, const float beta,
                                                float** c, int* ldc, const int batch_count,
                                                const int m_max, const int n_max, const int k_max,
                                                magma_queue_t& queue) {
                                                magma_queue_t queue) {
  magmablas_sgemm_vbatched_max_nocheck(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k, alpha,
                                       a, lda, b, ldb, beta, c, ldc, batch_count, m_max, n_max,
                                       k_max, queue);
@@ -181,7 +181,7 @@ inline void magmablas_gemm_vbatched_max_nocheck(const char transa, const char tr
                                                const double* const* b, int* ldb, const double beta,
                                                double** c, int* ldc, const int batch_count,
                                                const int m_max, const int n_max, const int k_max,
                                                magma_queue_t& queue) {
                                                magma_queue_t queue) {
  magmablas_dgemm_vbatched_max_nocheck(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k, alpha,
                                       a, lda, b, ldb, beta, c, ldc, batch_count, m_max, n_max,
                                       k_max, queue);
@@ -192,7 +192,7 @@ inline void magmablas_gemm_vbatched_max_nocheck(
    const char transa, const char transb, int* m, int* n, int* k, const std::complex<float> alpha,
    const std::complex<float>* const* a, int* lda, const std::complex<float>* const* b, int* ldb,
    const std::complex<float> beta, std::complex<float>** c, int* ldc, const int batch_count,
    const int m_max, const int n_max, const int k_max, magma_queue_t& queue) {
    const int m_max, const int n_max, const int k_max, magma_queue_t queue) {
  using util::castCudaComplex;
  magmablas_cgemm_vbatched_max_nocheck(
      toMagmaTrans(transa), toMagmaTrans(transb), m, n, k, *castCudaComplex(alpha),
@@ -205,7 +205,7 @@ inline void magmablas_gemm_vbatched_max_nocheck(
    const char transa, const char transb, int* m, int* n, int* k, const std::complex<double> alpha,
    const std::complex<double>* const* a, int* lda, const std::complex<double>* const* b, int* ldb,
    const std::complex<double> beta, std::complex<double>** c, int* ldc, const int batch_count,
    const int m_max, const int n_max, const int k_max, magma_queue_t& queue) {
    const int m_max, const int n_max, const int k_max, magma_queue_t queue) {
  using util::castCudaComplex;
  magmablas_zgemm_vbatched_max_nocheck(
      toMagmaTrans(transa), toMagmaTrans(transb), m, n, k, *castCudaComplex(alpha),
@@ -218,7 +218,7 @@ inline void magmablas_gemm_batched(const char transa, const char transb, const i
                                   const int k, const float alpha, const float* const* a,
                                   const int lda, const float* const* b, const int ldb,
                                   const float beta, float** c, const int ldc,
                                   const int batch_count, magma_queue_t& queue) {
                                   const int batch_count, magma_queue_t queue) {
  magmablas_sgemm_batched(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k, alpha, a, lda, b,
                          ldb, beta, c, ldc, batch_count, queue);
  checkErrorsCudaDebug();
@@ -227,7 +227,7 @@ inline void magmablas_gemm_batched(const char transa, const char transb, const i
                                   const int k, const double alpha, const double* const* a,
                                   const int lda, const double* const* b, const int ldb,
                                   const double beta, double** c, const int ldc,
                                   const int batch_count, const magma_queue_t& queue) {
                                   const int batch_count, const magma_queue_t queue) {
  magmablas_dgemm_batched(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k, alpha, a, lda, b,
                          ldb, beta, c, ldc, batch_count, queue);
  checkErrorsCudaDebug();
@@ -237,7 +237,7 @@ inline void magmablas_gemm_batched(const char transa, const char transb, const i
                                   const std::complex<float>* const* a, const int lda,
                                   const std::complex<float>* const* b, const int ldb,
                                   const std::complex<float> beta, std::complex<float>** c,
                                   const int ldc, const int batch_count, const magma_queue_t& queue) {
                                   const int ldc, const int batch_count, const magma_queue_t queue) {
  using util::castCudaComplex;
  magmablas_cgemm_batched(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k,
                          *castCudaComplex(alpha), castCudaComplex(a), lda, castCudaComplex(b), ldb,
@@ -249,7 +249,7 @@ inline void magmablas_gemm_batched(const char transa, const char transb, const i
                                   const std::complex<double>* const* a, const int lda,
                                   const std::complex<double>* const* b, const int ldb,
                                   const std::complex<double> beta, std::complex<double>** c,
                                   const int ldc, const int batch_count, const magma_queue_t& queue) {
                                   const int ldc, const int batch_count, const magma_queue_t queue) {
  using util::castCudaComplex;
  magmablas_zgemm_batched(toMagmaTrans(transa), toMagmaTrans(transb), m, n, k,
                          *castCudaComplex(alpha), castCudaComplex(a), lda, castCudaComplex(b), ldb,
@@ -276,8 +276,8 @@ inline int get_getri_nb<std::complex<double>>(int n) {
  return magma_get_zgetri_nb(n);
}

}  // magma
}  // linalg
}  // dca
}  // namespace magma
}  // namespace linalg
}  // namespace dca

#endif  // DCA_LINALG_LAPACK_MAGMA_HPP
+16 −3
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@
#include <complex>
#include <cstring>
#include "dca/linalg/device_type.hpp"
#include "cuda_stream.hpp"

#ifdef DCA_HAVE_CUDA
#include <cuda_runtime.h>
@@ -141,10 +142,22 @@ void memoryCopy(ScalarType* dest, int ld_dest, const ScalarType* src, int ld_src
  memoryCopyCpu(dest, ld_dest, src, ld_src, size);
}

// Synchronous 1D memory copy fallback.
template <typename ScalarType>
void memoryCopyAsync(ScalarType* dest, const ScalarType* src, size_t size,
                     const util::CudaStream& /*s*/) {
  memoryCopyCpu(dest, src, size);
}
template <typename ScalarType>
void memoryCopyAsync(ScalarType* dest, int ld_dest, const ScalarType* src, int ld_src,
                     std::pair<int, int> size, const util::CudaStream& /*s*/) {
  memoryCopyCpu(dest, ld_dest, src, ld_src, size);
}

#endif  // DCA_HAVE_CUDA

}  // util
}  // linalg
}  // dca
}  // namespace util
}  // namespace linalg
}  // namespace dca

#endif  // DCA_LINALG_UTIL_COPY_HPP
+5 −0
Original line number Diff line number Diff line
@@ -31,10 +31,15 @@ public:
  }

  CudaStream(const CudaStream& other) = delete;
  CudaStream& operator=(const CudaStream& other) = delete;

  CudaStream(CudaStream&& other) {
    std::swap(stream_, other.stream_);
  }
  CudaStream& operator=(CudaStream&& other) {
    std::swap(stream_, other.stream_);
    return *this;
  }

  void sync() const {
    checkRC(cudaStreamSynchronize(stream_));
+12 −10
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@
#include "dca/linalg/lapack/magma.hpp"
#include "dca/linalg/util/allocators/vectors_typedefs.hpp"
#include "dca/linalg/util/cuda_event.hpp"
#include "dca/linalg/util/magma_queue.hpp"
#include "dca/linalg/vector.hpp"

namespace dca {
@@ -30,7 +31,7 @@ template <typename ScalarType>
class MagmaBatchedGemm {
public:
  // Creates a plan for a batched gemm.
  MagmaBatchedGemm(magma_queue_t queue);
  MagmaBatchedGemm(const linalg::util::MagmaQueue& queue);
  // Creates a plan for a batched gemm and allocates the memory for the arguments of `size`
  // multiplications.
  MagmaBatchedGemm(int size, magma_queue_t queue);
@@ -52,8 +53,8 @@ public:
  void synchronizeCopy();

private:
  magma_queue_t queue_;
  const cudaStream_t stream_;
  const linalg::util::MagmaQueue& queue_;
  const linalg::util::CudaStream& stream_;
  CudaEvent copied_;

  linalg::util::HostVector<const ScalarType*> a_ptr_, b_ptr_;
@@ -64,8 +65,8 @@ private:
};

template <typename ScalarType>
MagmaBatchedGemm<ScalarType>::MagmaBatchedGemm(magma_queue_t queue)
    : queue_(queue), stream_(magma_queue_get_cuda_stream(queue_)) {}
MagmaBatchedGemm<ScalarType>::MagmaBatchedGemm(const linalg::util::MagmaQueue& queue)
    : queue_(queue), stream_(queue_) {}

template <typename ScalarType>
MagmaBatchedGemm<ScalarType>::MagmaBatchedGemm(const int size, magma_queue_t queue)
@@ -99,6 +100,7 @@ void MagmaBatchedGemm<ScalarType>::execute(const char transa, const char transb,
                                           const int n, const int k, const ScalarType alpha,
                                           const ScalarType beta, const int lda, const int ldb,
                                           const int ldc) {
  // TODO: store in a buffer if the performance gain is necessary.
  a_ptr_dev_.setAsync(a_ptr_, stream_);
  b_ptr_dev_.setAsync(b_ptr_, stream_);
  c_ptr_dev_.setAsync(c_ptr_, stream_);
@@ -111,9 +113,9 @@ void MagmaBatchedGemm<ScalarType>::execute(const char transa, const char transb,
  assert(cudaPeekAtLastError() == cudaSuccess);
}

}  // util
}  // linalg
}  // dca
}  // namespace util
}  // namespace linalg
}  // namespace dca

#endif  // DCA_HAVE_CUDA
#endif  // DCA_LINALG_UTIL_MAGMA_BATCHED_GEMM_HPP
+37 −10
Original line number Diff line number Diff line
@@ -13,39 +13,66 @@
#define DCA_LINALG_UTIL_MAGMA_QUEUE_HPP
#ifdef DCA_HAVE_CUDA

#include <cublas_v2.h>
#include <cuda.h>
#include <magma.h>
#include <cusparse_v2.h>
#include <magma_v2.h>

#include "dca/linalg/util/cuda_stream.hpp"

namespace dca {
namespace linalg {
namespace util {
// dca::linalg::util::

class MagmaQueue {
class MagmaQueue : public linalg::util::CudaStream {
public:
  MagmaQueue() {
    magma_queue_create(&queue_);
    cublasCreate(&cublas_handle_);
    cusparseCreate(&cusparse_handle_);
    int device;
    cudaGetDevice(&device);
    magma_queue_create_from_cuda(device, *this, cublas_handle_, cusparse_handle_, &queue_);
  }

  MagmaQueue(const MagmaQueue& rhs) = delete;
  MagmaQueue& operator=(const MagmaQueue& rhs) = delete;

  MagmaQueue(MagmaQueue&& rhs) : CudaStream(std::move(rhs)) {
    swapMembers(rhs);
  }

  MagmaQueue& operator=(MagmaQueue&& rhs) {
    CudaStream::operator=(std::move(rhs));
    swapMembers(rhs);
    return *this;
  }

  ~MagmaQueue() {
    magma_queue_destroy(queue_);
    cublasDestroy(cublas_handle_);
    cusparseDestroy(cusparse_handle_);
  }

  inline operator magma_queue_t() {
  operator magma_queue_t() const {
    return queue_;
  }

  cudaStream_t getStream() const {
    return magma_queue_get_cuda_stream(queue_);
private:
  void swapMembers(MagmaQueue& rhs) {
    std::swap(cublas_handle_, rhs.cublas_handle_);
    std::swap(cusparse_handle_, rhs.cusparse_handle_);
    std::swap(queue_, rhs.queue_);
  }

private:
  magma_queue_t queue_ = nullptr;
  cublasHandle_t cublas_handle_ = nullptr;
  cusparseHandle_t cusparse_handle_ = nullptr;
};

}  // util
}  // linalg
}  // dca
}  // namespace util
}  // namespace linalg
}  // namespace dca

#endif  // DCA_HAVE_CUDA
#endif  // DCA_LINALG_UTIL_MAGMA_QUEUE_HPP
Loading