diff --git a/Framework/Parallel/inc/MantidParallel/Collectives.h b/Framework/Parallel/inc/MantidParallel/Collectives.h index 05f3d1f5f073417f1dfbdbeef5c762936b58a829..e50f1ddfebd1906e2f610df926c6652a2e4b83a7 100644 --- a/Framework/Parallel/inc/MantidParallel/Collectives.h +++ b/Framework/Parallel/inc/MantidParallel/Collectives.h @@ -70,6 +70,12 @@ void gather(const Communicator &comm, const T &in_value, int root) { } } +template <typename... T> +void all_gather(const Communicator &comm, T &&... args) { + for (int root = 0; root < comm.size(); ++root) + gather(comm, std::forward<T>(args)..., root); +} + template <typename T> void all_to_all(const Communicator &comm, const std::vector<T> &in_values, std::vector<T> &out_values) { @@ -92,6 +98,15 @@ template <typename... T> void gather(const Communicator &comm, T &&... args) { detail::gather(comm, std::forward<T>(args)...); } +template <typename... T> +void all_gather(const Communicator &comm, T &&... args) { +#ifdef MPI_EXPERIMENTAL + if (!comm.hasBackend()) + return boost::mpi::all_gather(comm, std::forward<T>(args)...); +#endif + detail::all_gather(comm, std::forward<T>(args)...); +} + template <typename... T> void all_to_all(const Communicator &comm, T &&... args) { #ifdef MPI_EXPERIMENTAL diff --git a/Framework/Parallel/test/CollectivesTest.h b/Framework/Parallel/test/CollectivesTest.h index 3bbe4c061b35a7b9ee45e6bd6dd94bcad3f4ca82..a7333d1d400c4cfe1eca95eb0db000917ff541f6 100644 --- a/Framework/Parallel/test/CollectivesTest.h +++ b/Framework/Parallel/test/CollectivesTest.h @@ -40,6 +40,16 @@ void run_gather_short_version(const Communicator &comm) { } } +void run_all_gather(const Communicator &comm) { + int value = 123 * comm.rank(); + std::vector<int> result; + TS_ASSERT_THROWS_NOTHING(Parallel::all_gather(comm, value, result)); + TS_ASSERT_EQUALS(result.size(), comm.size()); + for (int i = 0; i < comm.size(); ++i) { + TS_ASSERT_EQUALS(result[i], 123 * i); + } +} + void run_all_to_all(const Communicator &comm) { std::vector<int> data; for (int rank = 0; rank < comm.size(); ++rank) @@ -66,6 +76,8 @@ public: ParallelTestHelpers::runParallel(run_gather_short_version); } + void test_all_gather() { ParallelTestHelpers::runParallel(run_all_gather); } + void test_all_to_all() { ParallelTestHelpers::runParallel(run_all_to_all); } };