Commit 71f614a2 authored by gbalduzz

Limit the size of messages in MPI sum.

parent fd4457ff
+28 −18
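
The change routes every collective sum through a single private helper that issues MPI_Allreduce in bounded chunks. The cap is std::numeric_limits<int>::max() / 10, i.e. about 2.1e8 elements per call. As a rough worked example, and assuming MPITypeMap<std::complex<double>>::factor() == 2, a full chunk of complex<double> data corresponds to an MPI count of roughly 4.3e8 (about a fifth of the 2^31 - 1 limit imposed by the int count argument) and a payload of about 3.2 GiB per reduction message.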
@@ -16,6 +16,7 @@
#ifndef DCA_PARALLEL_MPI_CONCURRENCY_MPI_COLLECTIVE_SUM_HPP
#define DCA_PARALLEL_MPI_CONCURRENCY_MPI_COLLECTIVE_SUM_HPP

+#include <algorithm>  // std::min
#include <map>
#include <string>
#include <utility>  // std::move, std::swap
@@ -138,6 +139,10 @@ public:
  template <typename Scalar, class Domain>
  std::vector<Scalar> avgNormalizedMomenta(const func::function<Scalar, Domain>& f,
                                           const std::vector<int>& orders) const;

+private:
+  template <typename T>
+  void sum(const T* in, T* out, std::size_t n) const;
};

template <typename scalar_type>
@@ -152,10 +157,9 @@ void MPICollectiveSum::sum(scalar_type& value) const {

template <typename scalar_type>
void MPICollectiveSum::sum(std::vector<scalar_type>& m) const {
-  std::vector<scalar_type> result(m.size(), scalar_type(0));
+  std::vector<scalar_type> result(m.size());

-  MPI_Allreduce(&(m[0]), &(result[0]), MPITypeMap<scalar_type>::factor() * m.size(),
-                MPITypeMap<scalar_type>::value(), MPI_SUM, MPIProcessorGrouping::get());
+  sum(m.data(), result.data(), m.size());

  m = std::move(result);
}
@@ -183,8 +187,7 @@ template <typename scalar_type, class domain>
void MPICollectiveSum::sum(func::function<scalar_type, domain>& f) const {
  func::function<scalar_type, domain> f_sum;

-  MPI_Allreduce(f.values(), f_sum.values(), MPITypeMap<scalar_type>::factor() * f.size(),
-                MPITypeMap<scalar_type>::value(), MPI_SUM, MPIProcessorGrouping::get());
+  sum(f, f_sum);

  f = std::move(f_sum);

@@ -200,8 +203,7 @@ void MPICollectiveSum::sum(func::function<scalar_type, domain>& f) const {
template <typename scalar_type, class domain>
void MPICollectiveSum::sum(const func::function<scalar_type, domain>& f_in,
                           func::function<scalar_type, domain>& f_out) const {
-  MPI_Allreduce(f_in.values(), f_out.values(), MPITypeMap<scalar_type>::factor() * f_in.size(),
-                MPITypeMap<scalar_type>::value(), MPI_SUM, MPIProcessorGrouping::get());
+  sum(f_in.values(), f_out.values(), f_in.size());
}

template <typename scalar_type, class domain>
@@ -226,10 +228,9 @@ template <typename scalar_type>
void MPICollectiveSum::sum(linalg::Vector<scalar_type, linalg::CPU>& vec) const {
  linalg::Vector<scalar_type, linalg::CPU> vec_sum("vec_sum", vec.size());

-  MPI_Allreduce(&vec[0], &vec_sum[0], MPITypeMap<scalar_type>::factor() * vec.size(),
-                MPITypeMap<scalar_type>::value(), MPI_SUM, MPIProcessorGrouping::get());
+  sum(vec.ptr(), vec_sum.ptr(), vec.size());

-  vec = vec_sum;
+  vec = std::move(vec_sum);

#ifndef NDEBUG
  for (int i = 0; i < vec.size(); ++i)
@@ -248,12 +249,9 @@ void MPICollectiveSum::sum(linalg::Matrix<scalar_type, linalg::CPU>& f) const {
  int Nr = f.capacity().first;
  int Nc = f.capacity().second;

-  MPI_Allreduce(&f(0, 0), &F(0, 0), MPITypeMap<scalar_type>::factor() * Nr * Nc,
-                MPITypeMap<scalar_type>::value(), MPI_SUM, MPIProcessorGrouping::get());
+  sum(f.ptr(), F.ptr(), Nr * Nc);

-  for (int j = 0; j < F.size().second; j++)
-    for (int i = 0; i < F.size().first; i++)
-      f(i, j) = F(i, j);
+  f = std::move(F);
}

template <typename some_type>
@@ -508,7 +506,19 @@ std::vector<Scalar> MPICollectiveSum::avgNormalizedMomenta(const func::function<
  return momenta_avg;
}

-}  // parallel
-}  // dca
+template <typename T>
+void MPICollectiveSum::sum(const T* in, T* out, std::size_t n) const {
+  // On Summit, large messages hang even when the size is below 2^31 - 1.
+  constexpr std::size_t max_size = std::numeric_limits<int>::max() / 10;
+
+  for (std::size_t start = 0; start < n; start += max_size) {
+    const int msg_size = std::min(n - start, max_size);
+    MPI_Allreduce(in + start, out + start, MPITypeMap<T>::factor() * msg_size,
+                  MPITypeMap<T>::value(), MPI_SUM, MPIProcessorGrouping::get());
+  }
+}
+
+}  // namespace parallel
+}  // namespace dca

#endif  // DCA_PARALLEL_MPI_CONCURRENCY_MPI_COLLECTIVE_SUM_HPP
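
For context, a minimal self-contained sketch of the chunking pattern used by the new private MPICollectiveSum::sum(const T*, T*, std::size_t) helper, written against plain MPI. The names chunked_allreduce and kMaxChunk, the fixed MPI_DOUBLE datatype, and the use of MPI_COMM_WORLD are illustrative assumptions; the DCA++ code above goes through MPITypeMap and MPIProcessorGrouping instead.

// chunked_allreduce.cpp -- illustrative sketch only; compile with mpicxx.
#include <algorithm>  // std::min
#include <cstddef>    // std::size_t
#include <limits>     // std::numeric_limits
#include <vector>
#include <mpi.h>

// Same cap as in the diff: stay well below the int count limit of MPI_Allreduce.
constexpr std::size_t kMaxChunk = std::numeric_limits<int>::max() / 10;

// Sum n doubles across all ranks, splitting the reduction into bounded messages.
void chunked_allreduce(const double* in, double* out, std::size_t n, MPI_Comm comm) {
  for (std::size_t start = 0; start < n; start += kMaxChunk) {
    const int count = static_cast<int>(std::min(n - start, kMaxChunk));
    MPI_Allreduce(in + start, out + start, count, MPI_DOUBLE, MPI_SUM, comm);
  }
}

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  std::vector<double> local(1 << 22, 1.);  // per-rank data; small here, but the loop also handles buffers with more than INT_MAX elements
  std::vector<double> global(local.size());
  chunked_allreduce(local.data(), global.data(), local.size(), MPI_COMM_WORLD);
  MPI_Finalize();
  return 0;
}

Advancing the input and output pointers by the same offset keeps the chunks aligned, so the concatenated result equals what a single MPI_Allreduce over the whole buffer would produce, while each individual message stays far below the int count limit.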