Commit 14423c48 authored by gbalduzz

fixup! better CPU and GPU code unification with stream wrappers. Better move assignment and constructor for stream and queue.
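The commit message refers to an RAII stream wrapper with improved move semantics. For orientation, a minimal sketch of what such a wrapper might look like follows; the names and details are assumptions, and the actual dca::linalg::util::CudaStream in the DCA++ sources may differ. The pieces the diffs below rely on are the implicit conversion to cudaStream_t, the sync() member, and move operations that transfer ownership of the raw handle.

#include <cuda_runtime.h>
#include <utility>

// Hypothetical sketch; not the actual DCA++ implementation.
class CudaStream {
public:
  CudaStream() {
    cudaStreamCreate(&stream_);
  }

  CudaStream(const CudaStream&) = delete;
  CudaStream& operator=(const CudaStream&) = delete;

  // Move operations transfer ownership of the raw handle, so a
  // moved-from object destructs as a no-op.
  CudaStream(CudaStream&& other) noexcept : stream_(other.stream_) {
    other.stream_ = nullptr;
  }
  CudaStream& operator=(CudaStream&& other) noexcept {
    std::swap(stream_, other.stream_);
    return *this;
  }

  ~CudaStream() {
    if (stream_)
      cudaStreamDestroy(stream_);
  }

  void sync() const {
    cudaStreamSynchronize(stream_);
  }

  // Implicit conversion allows passing the wrapper to raw CUDA calls.
  operator cudaStream_t() const {
    return stream_;
  }

private:
  cudaStream_t stream_ = nullptr;
};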
parent 30bb897e
+5 −25
@@ -199,22 +199,15 @@ public:
   template <DeviceType rhs_device_name>
   void set(const Matrix<ScalarType, rhs_device_name>& rhs, int thread_id, int stream_id);
 
-#ifdef DCA_HAVE_CUDA
   // Asynchronous assignment.
   template <DeviceType rhs_device_name>
-  void setAsync(const Matrix<ScalarType, rhs_device_name>& rhs, cudaStream_t stream);
+  void setAsync(const Matrix<ScalarType, rhs_device_name>& rhs, const util::CudaStream& stream);
 
   // Asynchronous assignment (copy with stream = getStream(thread_id, stream_id))
   template <DeviceType rhs_device_name>
   void setAsync(const Matrix<ScalarType, rhs_device_name>& rhs, int thread_id, int stream_id);
 
-  void setToZero(cudaStream_t stream);
-#else
-  // Synchronous assignment fallback for SetAsync.
-  template <DeviceType rhs_device_name>
-  void setAsync(const Matrix<ScalarType, rhs_device_name>& rhs, int thread_id, int stream_id);
-
-#endif  // DCA_HAVE_CUDA
+  void setToZero(const util::CudaStream& stream);
 
   // Prints the values of the matrix elements.
   void print() const;
@@ -427,12 +420,10 @@ void Matrix<ScalarType, device_name>::set(const Matrix<ScalarType, rhs_device_na
                    stream_id);
 }
 
-#ifdef DCA_HAVE_CUDA
-
 template <typename ScalarType, DeviceType device_name>
 template <DeviceType rhs_device_name>
 void Matrix<ScalarType, device_name>::setAsync(const Matrix<ScalarType, rhs_device_name>& rhs,
-                                               const cudaStream_t stream) {
+                                               const util::CudaStream& stream) {
   resizeNoCopy(rhs.size_);
   util::memoryCopyAsync(data_, leadingDimension(), rhs.data_, rhs.leadingDimension(), size_, stream);
 }
@@ -445,21 +436,10 @@ void Matrix<ScalarType, device_name>::setAsync(const Matrix<ScalarType, rhs_devi
 }
 
 template <typename ScalarType, DeviceType device_name>
-void Matrix<ScalarType, device_name>::setToZero(cudaStream_t stream) {
-  cudaMemsetAsync(data_, 0, leadingDimension() * nrCols() * sizeof(ScalarType), stream);
+void Matrix<ScalarType, device_name>::setToZero(const util::CudaStream& stream) {
+  util::Memory<device_name>::setToZeroAsync(data_, leadingDimension() * nrCols(), stream);
 }
 
-#else
-
-template <typename ScalarType, DeviceType device_name>
-template <DeviceType rhs_device_name>
-void Matrix<ScalarType, device_name>::setAsync(const Matrix<ScalarType, rhs_device_name>& rhs,
-                                               int /*thread_id*/, int /*stream_id*/) {
-  set(rhs);
-}
-
-#endif  // DCA_HAVE_CUDA
-
 template <typename ScalarType, DeviceType device_name>
 void Matrix<ScalarType, device_name>::print() const {
   if (device_name == GPU)
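The new setToZero delegates to util::Memory<device_name>::setToZeroAsync, which is what allows dropping the #ifdef DCA_HAVE_CUDA branch from Matrix: the device dispatch now lives in the helper. Below is a minimal sketch of such a helper, assuming per-device specializations and the CudaStream wrapper sketched above; the actual util::Memory in DCA++ may differ.

#include <algorithm>
#include <cstddef>

// Hypothetical sketch of a device-dispatched zeroing helper. DeviceType,
// CPU, and GPU are taken from the enclosing dca::linalg declarations.
template <DeviceType device_name>
struct Memory;

template <>
struct Memory<GPU> {
  template <typename ScalarType>
  static void setToZeroAsync(ScalarType* ptr, std::size_t size, const CudaStream& stream) {
    // Asynchronous zeroing on the device, enqueued on the given stream
    // via the wrapper's implicit conversion to cudaStream_t.
    cudaMemsetAsync(ptr, 0, size * sizeof(ScalarType), stream);
  }
};

template <>
struct Memory<CPU> {
  template <typename ScalarType>
  static void setToZeroAsync(ScalarType* ptr, std::size_t size, const CudaStream& /*stream*/) {
    // No asynchronous path on the host: zero synchronously.
    std::fill_n(ptr, size, ScalarType(0));
  }
};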
+3 −5
@@ -220,14 +220,12 @@ TEST(MatrixCPUGPUTest, SetAsync) {
 
   auto el_value = [](int i, int j) { return 3 * i - 2 * j; };
   testing::setMatrixElements(mat, el_value);
-  cudaStream_t stream;
-  cudaStreamCreate(&stream);
+
+  dca::linalg::util::CudaStream stream;
 
   mat_copy.setAsync(mat, stream);
   mat_copy_copy.setAsync(mat_copy, stream);
-  cudaStreamSynchronize(stream);
+  stream.sync();
 
   EXPECT_EQ(mat, mat_copy_copy);
-
-  cudaStreamDestroy(stream);
 }
+2 −5
@@ -664,16 +664,13 @@ TEST(MatrixGPUTest, setToZero) {
   auto func = [](int i, int j) { return 10 * i - j; };
   testing::setMatrixElements(mat, func);
 
-  cudaStream_t stream;
-  cudaStreamCreate(&stream);
+  dca::linalg::util::CudaStream stream;
   mat.setToZero(stream);
-  cudaStreamSynchronize(stream);
+  stream.sync();
 
   dca::linalg::Matrix<long, dca::linalg::CPU> mat_copy(mat);
 
   for (int j = 0; j < mat_copy.nrCols(); ++j)
     for (int i = 0; i < mat_copy.nrRows(); ++i)
       EXPECT_EQ(0, mat_copy(i, j));
-
-  cudaStreamDestroy(stream);
 }
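Taken together, the two test changes show the point of the wrapper: the manual cudaStreamCreate / cudaStreamSynchronize / cudaStreamDestroy triple is replaced by a constructor, a sync() call, and an implicitly invoked destructor, so the stream can no longer leak on an early return or a failed assertion.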