Unverified Commit e9cd4f67 authored by gbalduzz's avatar gbalduzz Committed by GitHub
Browse files

Merge branch 'master' into fix_arm_environment

parents 22ab2042 18a1de8b
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -65,8 +65,11 @@ inline cublasHandle_t getHandle(const int thread_id, const int stream_id) {

#else

// Implement SFINAE.
inline void resizeHandleContainer(int /*max_threads*/) {}

inline void resizeHandleContainer(const std::size_t max_threads) {
  if (getStreamContainer().get_max_threads() < max_threads)
    resizeStreamContainer(max_threads);
}

#endif  // DCA_HAVE_CUDA

+9 −8
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#ifndef DCA_LINALG_UTIL_STREAM_CONTAINER_HPP
#define DCA_LINALG_UTIL_STREAM_CONTAINER_HPP

#include <array>
#include <cassert>
#include <functional>
#include <vector>
@@ -26,18 +27,18 @@ namespace util {

class StreamContainer {
public:
  StreamContainer(int max_threads = 0) : streams_(max_threads * streams_per_thread_) {}
  StreamContainer(std::size_t max_threads = 0) : streams_(max_threads) {}

  int get_max_threads() const {
    return streams_.size() / streams_per_thread_;
  std::size_t get_max_threads() const {
    return streams_.size();
  }

  int get_streams_per_thread() const {
  std::size_t get_streams_per_thread() const {
    return streams_per_thread_;
  }

  void resize(const int max_threads) {
    streams_.resize(max_threads * streams_per_thread_);
    streams_.resize(max_threads);
  }

  StreamContainer(const StreamContainer&) = delete;
@@ -49,7 +50,7 @@ public:
  CudaStream& operator()(int thread_id, int stream_id) {
    assert(thread_id >= 0 && thread_id < get_max_threads());
    assert(stream_id >= 0 && stream_id < streams_per_thread_);
    return streams_[stream_id + streams_per_thread_ * thread_id];
    return streams_[thread_id][stream_id];
  }

  // Synchronizes the 'stream_id'-th stream associated with thread 'thread_id'.
@@ -60,8 +61,8 @@ public:
  }

private:
  constexpr static int streams_per_thread_ = 2;
  std::vector<CudaStream> streams_;
  constexpr static std::size_t streams_per_thread_ = 2;
  std::vector<std::array<CudaStream, streams_per_thread_>> streams_;
};

}  // namespace util