Commit 65abd5bd authored by gbalduzz's avatar gbalduzz
Browse files

cleaned up MagmaQueue.

parent 399c6736
Loading
Loading
Loading
Loading
+7 −3
Original line number Diff line number Diff line
@@ -34,10 +34,10 @@ public:
  CudaStream& operator=(const CudaStream& other) = delete;

  CudaStream(CudaStream&& other) {
    std::swap(stream_, other.stream_);
    swap(other);
  }
  CudaStream& operator=(CudaStream&& other) {
    std::swap(stream_, other.stream_);
    swap(other);
    return *this;
  }

@@ -45,7 +45,7 @@ public:
    checkRC(cudaStreamSynchronize(stream_));
  }

  ~CudaStream() {
  virtual ~CudaStream() {
    if (stream_)
      cudaStreamDestroy(stream_);
  }
@@ -54,6 +54,10 @@ public:
    return stream_;
  }

  void swap(CudaStream& other) {
    std::swap(stream_, other.stream_);
  }

private:
  cudaStream_t stream_ = nullptr;
};
+5 −6
Original line number Diff line number Diff line
@@ -54,7 +54,6 @@ public:

private:
  const linalg::util::MagmaQueue& queue_;
  const linalg::util::CudaStream& stream_;
  CudaEvent copied_;

  linalg::util::HostVector<const ScalarType*> a_ptr_, b_ptr_;
@@ -66,7 +65,7 @@ private:

template <typename ScalarType>
MagmaBatchedGemm<ScalarType>::MagmaBatchedGemm(const linalg::util::MagmaQueue& queue)
    : queue_(queue), stream_(queue_) {}
    : queue_(queue) {}

template <typename ScalarType>
MagmaBatchedGemm<ScalarType>::MagmaBatchedGemm(const int size, magma_queue_t queue)
@@ -101,10 +100,10 @@ void MagmaBatchedGemm<ScalarType>::execute(const char transa, const char transb,
                                           const ScalarType beta, const int lda, const int ldb,
                                           const int ldc) {
  // TODO: store in a buffer if the performance gain is necessary.
  a_ptr_dev_.setAsync(a_ptr_, stream_);
  b_ptr_dev_.setAsync(b_ptr_, stream_);
  c_ptr_dev_.setAsync(c_ptr_, stream_);
  copied_.record(stream_);
  a_ptr_dev_.setAsync(a_ptr_, queue_);
  b_ptr_dev_.setAsync(b_ptr_, queue_);
  c_ptr_dev_.setAsync(c_ptr_, queue_);
  copied_.record(queue_);

  const int n_batched = a_ptr_.size();
  magma::magmablas_gemm_batched(transa, transb, m, n, k, alpha, a_ptr_dev_.ptr(), lda,
+12 −11
Original line number Diff line number Diff line
@@ -25,26 +25,26 @@ namespace linalg {
namespace util {
// dca::linalg::util::

class MagmaQueue : public linalg::util::CudaStream {
class MagmaQueue : public CudaStream {
public:
  MagmaQueue() {
    cublasCreate(&cublas_handle_);
    cusparseCreate(&cusparse_handle_);
    int device;
    cudaGetDevice(&device);
    magma_queue_create_from_cuda(device, *this, cublas_handle_, cusparse_handle_, &queue_);
    magma_queue_create_from_cuda(device, static_cast<cudaStream_t>(*this), cublas_handle_,
                                 cusparse_handle_, &queue_);
  }

  MagmaQueue(const MagmaQueue& rhs) = delete;
  MagmaQueue& operator=(const MagmaQueue& rhs) = delete;

  MagmaQueue(MagmaQueue&& rhs) : CudaStream(std::move(rhs)) {
    swapMembers(rhs);
  MagmaQueue(MagmaQueue&& rhs) {
    swap(rhs);
  }

  MagmaQueue& operator=(MagmaQueue&& rhs) {
    CudaStream::operator=(std::move(rhs));
    swapMembers(rhs);
    swap(rhs);
    return *this;
  }

@@ -58,13 +58,14 @@ public:
    return queue_;
  }

private:
  void swapMembers(MagmaQueue& rhs) {
    std::swap(cublas_handle_, rhs.cublas_handle_);
    std::swap(cusparse_handle_, rhs.cusparse_handle_);
    std::swap(queue_, rhs.queue_);
  void swap(MagmaQueue& other) {
    static_cast<CudaStream&>(*this).swap(static_cast<CudaStream&>(other));
    std::swap(cublas_handle_, other.cublas_handle_);
    std::swap(cusparse_handle_, other.cusparse_handle_);
    std::swap(queue_, other.queue_);
  }

private:
  magma_queue_t queue_ = nullptr;
  cublasHandle_t cublas_handle_ = nullptr;
  cusparseHandle_t cusparse_handle_ = nullptr;