Commit e4a5f7e1 authored by Doak, Peter W.'s avatar Doak, Peter W.
Browse files

Changes to remove unnecessary inheritance

parent 77c72a0a
Loading
Loading
Loading
Loading
+16 −4
Original line number Diff line number Diff line
@@ -25,26 +25,25 @@ namespace linalg {
namespace util {
// dca::linalg::util::

class MagmaQueue : public CudaStream {
class MagmaQueue {
public:
  // Default constructor: creates the cuBLAS and cuSPARSE handles, queries the
  // current CUDA device, and creates the MAGMA queue from that device, the stream
  // and both handles.
  // The status codes returned by cublasCreate, cusparseCreate and cudaGetDevice
  // are not inspected here.
  // NOTE(review): the two consecutive `magma_queue_create_from_cuda(` lines below
  // look like the removed and added sides of one diff hunk; only one of them can
  // belong to the compiled file — confirm against the repository.
  MagmaQueue() {
    cublasCreate(&cublas_handle_);
    cusparseCreate(&cusparse_handle_);
    int device;
    cudaGetDevice(&device);
    magma_queue_create_from_cuda(device, static_cast<cudaStream_t>(*this), cublas_handle_,
    magma_queue_create_from_cuda(device, stream_, cublas_handle_,
                                 cusparse_handle_, &queue_);
  }

  // Non-copyable: copy construction and copy assignment are both disabled.
  MagmaQueue(const MagmaQueue&) = delete;
  MagmaQueue& operator=(const MagmaQueue&) = delete;

  // Move constructor: exchanges every member of the new object with `rhs` through
  // swapMembers(), so `rhs` is left holding this object's initial member values.
  // NOTE(review): the two signature lines below look like the removed and added
  // sides of one diff hunk (with and without the CudaStream base initializer);
  // only one of them can belong to the compiled file — confirm against the repository.
  MagmaQueue(MagmaQueue&& rhs) noexcept : CudaStream(std::move(rhs)) {
  MagmaQueue(MagmaQueue&& rhs) noexcept {
    swapMembers(rhs);
  }

  // Move assignment: implemented as a member-wise swap with `rhs`, so after the
  // call `rhs` holds what this object held before. Returns *this.
  // NOTE(review): the `CudaStream::operator=` line presumably belongs to the
  // removed side of the diff, since it requires CudaStream to be a base class —
  // confirm against the repository.
  MagmaQueue& operator=(MagmaQueue&& rhs) noexcept {
    CudaStream::operator=(std::move(rhs));
    swapMembers(rhs);
    return *this;
  }
@@ -59,13 +58,26 @@ public:
    return queue_;
  }

  // Implicit conversion to the raw CUDA stream handle.
  // It lets the many call sites that previously took a stream accept a MagmaQueue
  // directly. The trade-off: those call sites become less verbose but also less
  // intelligible, so consider carefully before relying on this conversion.
  operator cudaStream_t() const {
    return static_cast<cudaStream_t>(stream_);
  }
  
  // Returns a const reference to the CudaStream member held by this queue.
  const CudaStream& getStream() const {
    return stream_;
  }
  
private:
  void swapMembers(MagmaQueue& rhs) noexcept {
    std::swap(stream_, rhs.stream_);
    std::swap(cublas_handle_, rhs.cublas_handle_);
    std::swap(cusparse_handle_, rhs.cusparse_handle_);
    std::swap(queue_, rhs.queue_);
  }

  // Stream held by value as a member; passed to magma_queue_create_from_cuda in
  // the default constructor.
  CudaStream stream_;
  // MAGMA queue; null until the default constructor creates it.
  magma_queue_t queue_ = nullptr;
  // cuBLAS handle; null until cublasCreate is called in the default constructor.
  cublasHandle_t cublas_handle_ = nullptr;
  // cuSPARSE handle; null until cusparseCreate is called in the default constructor.
  cusparseHandle_t cusparse_handle_ = nullptr;
+2 −4
Original line number Diff line number Diff line
@@ -63,7 +63,7 @@ public:

  // Returns: the stream associated with the magma queue, as a raw cudaStream_t.
  // NOTE(review): the two `return` lines below look like the removed and added
  // sides of one diff hunk; the second one reads the stream through
  // queue_.getStream() — confirm against the repository which one is live.
  cudaStream_t get_stream() const {
    return stream_;
    return queue_.getStream();
  }

  std::size_t deviceFingerprint() const {
@@ -87,7 +87,6 @@ private:
  // Set in the constructor's initializer list from RDmn::dmn_size().
  const int nc_;

  // Const reference bound in the constructor's initializer list (`queue_(queue)`);
  // only the reference is stored here, not a MagmaQueue object.
  const linalg::util::MagmaQueue& queue_;
  // NOTE(review): presumably the removed side of the diff (its `stream_(queue)`
  // initializer is shown alongside it) — confirm against the repository.
  const linalg::util::CudaStream& stream_;

  // Assigned a freshly created RMatrix in the constructor body.
  std::shared_ptr<RMatrix> workspace_;

@@ -102,7 +101,6 @@ SpaceTransform2DGpu<RDmn, KDmn, Real>::SpaceTransform2DGpu(const int nw_pos,
      nw_(2 * nw_pos),
      nc_(RDmn::dmn_size()),
      queue_(queue),
      stream_(queue),
      plan1_(queue_),
      plan2_(queue_) {
  workspace_ = std::make_shared<RMatrix>();
@@ -159,7 +157,7 @@ void SpaceTransform2DGpu<RDmn, KDmn, Real>::phaseFactorsAndRearrange(const RMatr
      BaseClass::hasPhaseFactors() ? getPhaseFactors().ptr() : nullptr;
  details::phaseFactorsAndRearrange(in.ptr(), in.leadingDimension(), out.ptr(),
                                    out.leadingDimension(), n_bands_, nc_, nw_, phase_factors_ptr,
                                    stream_);
                                    queue_);
}

template <class RDmn, class KDmn, typename Real>
+1 −1
Original line number Diff line number Diff line
@@ -107,7 +107,7 @@ TYPED_TEST(SpaceTransform2DGpuTest, Execute) {
  dca::math::transform::SpaceTransform2DGpu<RDmn, KDmn, Real> transform_obj(nw, queue);
  transform_obj.execute(M_dev);

  queue.sync();
  queue.getStream().sync();

  constexpr Real tolerance = std::numeric_limits<Real>::epsilon() * 500;