Commit 1a2c0792 authored by gbalduzz's avatar gbalduzz
Browse files

fix move construct performance and conversion MagmaQueue->CudaStream

parent c9e8f792
Loading
Loading
Loading
Loading
+11 −10
Original line number Diff line number Diff line
@@ -32,19 +32,20 @@ public:
    cusparseCreate(&cusparse_handle_);
    int device;
    cudaGetDevice(&device);
    magma_queue_create_from_cuda(device, stream_, cublas_handle_,
                                 cusparse_handle_, &queue_);
    magma_queue_create_from_cuda(device, stream_, cublas_handle_, cusparse_handle_, &queue_);
  }

  // Copying is disabled: the copy constructor and copy assignment are both deleted,
  // so a MagmaQueue can only be transferred by move.
  MagmaQueue(const MagmaQueue&) = delete;
  MagmaQueue& operator=(const MagmaQueue&) = delete;

  // Move constructor: takes over the stream and the library handles held by rhs.
  //
  // stream_ is move-constructed directly from rhs.stream_ instead of being
  // default-constructed and swapped afterwards.
  // NOTE(review): assumes CudaStream's move constructor leaves rhs.stream_ in a
  // state that is safe to destroy - TODO confirm against CudaStream.
  //
  // queue_ must NOT be initialized from rhs.queue_ in the initializer list:
  // doing so copies the handle, the later swap then exchanges two equal values,
  // and both objects end up referring to the same queue. Here queue_ and
  // cublas_handle_ start as nullptr through their default member initializers
  // (cusparse_handle_ presumably likewise - its declaration is not in view), so
  // each swap hands the null value to rhs and the live handle to this object.
  MagmaQueue(MagmaQueue&& rhs) noexcept : stream_(std::move(rhs.stream_)) {
    std::swap(cublas_handle_, rhs.cublas_handle_);
    std::swap(cusparse_handle_, rhs.cusparse_handle_);
    std::swap(queue_, rhs.queue_);
  }

  // Move assignment: exchanges this object's members with those of rhs through the
  // private swap helper and returns *this.
  // NOTE(review): two helper calls appear below (`swapMembers` and `swap`). They look
  // like the before/after forms of a single renamed call; only one of them should be
  // present in the final file, otherwise the members are swapped twice.
  MagmaQueue& operator=(MagmaQueue&& rhs) noexcept {
    swapMembers(rhs);
    swap(rhs);
    return *this;
  }

@@ -62,21 +63,21 @@ public:
  // take a MagmaQueue, this makes all this code less intelligible
  // but less verbose.  Consider this carefully.
  // Implicit conversion to a raw cudaStream_t, obtained from the stream_ member.
  // NOTE(review): two return statements appear below; the second is unreachable as
  // written. They look like the before/after forms of the same statement - the
  // second one spells out the CudaStream -> cudaStream_t conversion explicitly.
  operator cudaStream_t() const {
    return stream_;
    return static_cast<cudaStream_t>(stream_);
  }

  // Read-only access to the CudaStream member held by this object.
  const CudaStream& getStream() const { return stream_; }

private:
  // Exchanges stream_, cublas_handle_, cusparse_handle_ and queue_ with the
  // corresponding members of rhs.
  // NOTE(review): two signature lines appear below (`swapMembers` and `swap`). They
  // look like the before/after forms of a rename; only one should remain in the
  // final file.
  void swapMembers(MagmaQueue& rhs) noexcept {
  void swap(MagmaQueue& rhs) noexcept {
    std::swap(stream_, rhs.stream_);
    std::swap(cublas_handle_, rhs.cublas_handle_);
    std::swap(cusparse_handle_, rhs.cusparse_handle_);
    std::swap(queue_, rhs.queue_);
  }

private:
  CudaStream stream_;
  magma_queue_t queue_ = nullptr;
  cublasHandle_t cublas_handle_ = nullptr;
+1 −1
Original line number Diff line number Diff line
@@ -100,7 +100,7 @@ private:
  MatrixConfiguration configuration_;
  int sign_ = 0;

  std::vector<linalg::util::CudaStream*> streams_;
  std::vector<const linalg::util::CudaStream*> streams_;
  linalg::util::CudaEvent event_;

  util::Accumulator<int> accumulated_sign_;
+2 −2
Original line number Diff line number Diff line
@@ -123,8 +123,8 @@ public:
    return 0;
  }

  // Returns the address of a function-local static CudaStream, so every call yields
  // a pointer to the same object.
  // NOTE(review): the signature and the static declaration each appear twice below
  // (without and with `const`). They look like the before/after forms of the same
  // two lines; only one pair should remain in the final file.
  linalg::util::CudaStream* get_stream() const {
    static dca::linalg::util::CudaStream mock_stream;
  const linalg::util::CudaStream* get_stream() const {
    static const dca::linalg::util::CudaStream mock_stream;
    return &mock_stream;
  }

+2 −2
Original line number Diff line number Diff line
@@ -87,8 +87,8 @@ public:
  // other_acc.
  void sumTo(this_type& other_acc);

  // Returns a pointer to the stream associated with the first element of queues_.
  // NOTE(review): assumes queues_ holds at least one element - confirm at the
  // construction site.
  // NOTE(review): the signature and the return statement each appear twice below.
  // They look like the before/after forms of the same two lines (the later pair
  // goes through getStream() and returns a pointer-to-const); only one pair should
  // remain in the final file.
  linalg::util::CudaStream* get_stream() {
    return &queues_[0];
  const linalg::util::CudaStream* get_stream() {
    return &queues_[0].getStream();
  }

  void synchronizeCopy() {