Commit ebb1019a authored by dpugmire's avatar dpugmire
Browse files

Add CheckRequests method.

parent af14f583
...@@ -119,44 +119,64 @@ void Messenger::PostRecv(int tag, std::size_t sz, int src) ...@@ -119,44 +119,64 @@ void Messenger::PostRecv(int tag, std::size_t sz, int src)
} }
void Messenger::CheckPendingSendRequests() void Messenger::CheckPendingSendRequests()
{
std::vector<RequestTagPair> reqTags;
this->CheckRequests(this->SendBuffers, {}, false, reqTags);
//Cleanup any send buffers that have completed.
for (auto&& rt : reqTags)
{
auto entry = this->SendBuffers.find(rt);
if (entry != this->SendBuffers.end())
{
delete[] entry->second;
this->SendBuffers.erase(entry);
}
}
}
void Messenger::CheckRequests(const std::map<RequestTagPair, char*>& buffers,
const std::set<int>& tagsToCheck,
bool BlockAndWait,
std::vector<RequestTagPair>& reqTags)
{ {
std::vector<MPI_Request> req, copy; std::vector<MPI_Request> req, copy;
std::vector<int> tags; std::vector<int> tags;
for (auto&& it : this->SendBuffers) reqTags.resize(0);
//Check the buffers for the specified tags.
for (auto&& it : buffers)
{ {
req.push_back(it.first.first); if (tagsToCheck.empty() || tagsToCheck.find(it.first.second) != tagsToCheck.end())
copy.push_back(it.first.first); {
tags.push_back(it.first.second); req.push_back(it.first.first);
copy.push_back(it.first.first);
tags.push_back(it.first.second);
}
} }
//Nothing..
if (req.empty()) if (req.empty())
return; return;
//See if any sends are done. //Check the outstanding requests.
int num = 0, *indices = new int[req.size()]; std::vector<MPI_Status> status(req.size());
MPI_Status* status = new MPI_Status[req.size()]; std::vector<int> indices(req.size());
int err = MPI_Testsome(req.size(), &req[0], &num, indices, status); int num = 0;
int err;
if (BlockAndWait)
err = MPI_Waitsome(req.size(), req.data(), &num, indices.data(), status.data());
else
err = MPI_Testsome(req.size(), req.data(), &num, indices.data(), status.data());
if (err != MPI_SUCCESS) if (err != MPI_SUCCESS)
throw vtkm::cont::ErrorFilterExecution( throw vtkm::cont::ErrorFilterExecution("Error with MPI_Testsome in Messenger::RecvData");
"Error iwth MPI_Testsome in Messenger::CheckPendingSendRequests");
for (int i = 0; i < num; i++)
{
MPI_Request r = copy[indices[i]];
int tag = tags[indices[i]];
RequestTagPair k(r, tag); //Add the req/tag to the return vector.
auto entry = this->SendBuffers.find(k); for (int i = 0; i < num; i++)
if (entry != this->SendBuffers.end()) reqTags.push_back(RequestTagPair(copy[indices[i]], tags[indices[i]]));
{
delete[] entry->second;
this->SendBuffers.erase(entry);
}
}
delete[] indices;
delete[] status;
} }
bool Messenger::PacketCompare(const char* a, const char* b) bool Messenger::PacketCompare(const char* a, const char* b)
...@@ -259,56 +279,36 @@ bool Messenger::RecvData(int tag, std::vector<vtkmdiy::MemoryBuffer>& buffers, b ...@@ -259,56 +279,36 @@ bool Messenger::RecvData(int tag, std::vector<vtkmdiy::MemoryBuffer>& buffers, b
return false; return false;
} }
bool Messenger::RecvData(std::set<int>& tags, bool Messenger::RecvData(const std::set<int>& tags,
std::vector<std::pair<int, vtkmdiy::MemoryBuffer>>& buffers, std::vector<std::pair<int, vtkmdiy::MemoryBuffer>>& buffers,
bool blockAndWait) bool blockAndWait)
{ {
buffers.resize(0); buffers.resize(0);
//Find all recv of type tag. std::vector<RequestTagPair> reqTags;
std::vector<MPI_Request> req, copy; this->CheckRequests(this->RecvBuffers, tags, blockAndWait, reqTags);
std::vector<int> reqTags;
for (const auto& i : this->RecvBuffers)
{
if (tags.find(i.first.second) != tags.end())
{
req.push_back(i.first.first);
copy.push_back(i.first.first);
reqTags.push_back(i.first.second);
}
}
if (req.empty())
return false;
std::vector<MPI_Status> status(req.size()); //Nothing came in.
std::vector<int> indices(req.size()); if (reqTags.empty())
int num = 0;
if (blockAndWait)
MPI_Waitsome(req.size(), req.data(), &num, indices.data(), status.data());
else
MPI_Testsome(req.size(), req.data(), &num, indices.data(), status.data());
if (num == 0)
return false; return false;
std::vector<char*> incomingBuffers(num); std::vector<char*> incomingBuffers;
for (int i = 0; i < num; i++) incomingBuffers.reserve(reqTags.size());
for (auto&& rt : reqTags)
{ {
RequestTagPair entry(copy[indices[i]], reqTags[indices[i]]); auto it = this->RecvBuffers.find(rt);
auto it = this->RecvBuffers.find(entry);
if (it == this->RecvBuffers.end()) if (it == this->RecvBuffers.end())
throw vtkm::cont::ErrorFilterExecution("receive buffer not found"); throw vtkm::cont::ErrorFilterExecution("receive buffer not found");
incomingBuffers[i] = it->second; incomingBuffers.push_back(it->second);
this->RecvBuffers.erase(it); this->RecvBuffers.erase(it);
} }
this->ProcessReceivedBuffers(incomingBuffers, buffers); this->ProcessReceivedBuffers(incomingBuffers, buffers);
for (int i = 0; i < num; i++) //Re-post receives
PostRecv(reqTags[indices[i]]); for (auto&& rt : reqTags)
this->PostRecv(rt.second);
return !buffers.empty(); return !buffers.empty();
} }
......
...@@ -48,16 +48,24 @@ public: ...@@ -48,16 +48,24 @@ public:
VTKM_CONT void RegisterTag(int tag, std::size_t numRecvs, std::size_t size); VTKM_CONT void RegisterTag(int tag, std::size_t numRecvs, std::size_t size);
protected: protected:
static std::size_t CalcMessageBufferSize(std::size_t msgSz);
int GetRank() const { return this->Rank; }
int GetNumRanks() const { return this->NumRanks; }
void InitializeBuffers(); void InitializeBuffers();
void CleanupRequests(int tag = TAG_ANY);
void CheckPendingSendRequests(); void CheckPendingSendRequests();
void PostRecv(int tag); void CleanupRequests(int tag = TAG_ANY);
void PostRecv(int tag, std::size_t sz, int src = -1);
void SendData(int dst, int tag, const vtkmdiy::MemoryBuffer& buff); void SendData(int dst, int tag, const vtkmdiy::MemoryBuffer& buff);
bool RecvData(std::set<int>& tags, bool RecvData(const std::set<int>& tags,
std::vector<std::pair<int, vtkmdiy::MemoryBuffer>>& buffers, std::vector<std::pair<int, vtkmdiy::MemoryBuffer>>& buffers,
bool blockAndWait = false); bool blockAndWait = false);
private:
void PostRecv(int tag);
void PostRecv(int tag, std::size_t sz, int src = -1);
//Message headers. //Message headers.
typedef struct typedef struct
{ {
...@@ -87,12 +95,16 @@ protected: ...@@ -87,12 +95,16 @@ protected:
std::map<RankIdPair, std::list<char*>> RecvPackets; std::map<RankIdPair, std::list<char*>> RecvPackets;
std::map<RequestTagPair, char*> SendBuffers; std::map<RequestTagPair, char*> SendBuffers;
static constexpr int TAG_ANY = -1; static constexpr int TAG_ANY = -1;
void CheckRequests(const std::map<RequestTagPair, char*>& buffer,
const std::set<int>& tags,
bool BlockAndWait,
std::vector<RequestTagPair>& reqTags);
#else #else
protected:
static constexpr int NumRanks = 1; static constexpr int NumRanks = 1;
static constexpr int Rank = 0; static constexpr int Rank = 0;
#endif #endif
static std::size_t CalcMessageBufferSize(std::size_t msgSz);
}; };
} }
} }
......
...@@ -106,7 +106,7 @@ void ParticleMessenger::Exchange( ...@@ -106,7 +106,7 @@ void ParticleMessenger::Exchange(
numTerminateMessages = 0; numTerminateMessages = 0;
inDataBlockIDsMap.clear(); inDataBlockIDsMap.clear();
if (this->NumRanks == 1) if (this->GetNumRanks() == 1)
return this->SerialExchange( return this->SerialExchange(
outData, outBlockIDsMap, numLocalTerm, inData, inDataBlockIDsMap, blockAndWait); outData, outBlockIDsMap, numLocalTerm, inData, inDataBlockIDsMap, blockAndWait);
...@@ -160,7 +160,7 @@ void ParticleMessenger::RegisterMessages(int msgSz, int nParticles, int numBlock ...@@ -160,7 +160,7 @@ void ParticleMessenger::RegisterMessages(int msgSz, int nParticles, int numBlock
std::size_t messageBuffSz = CalcMessageBufferSize(msgSz + 1); std::size_t messageBuffSz = CalcMessageBufferSize(msgSz + 1);
std::size_t particleBuffSz = CalcParticleBufferSize(nParticles, numBlockIds); std::size_t particleBuffSz = CalcParticleBufferSize(nParticles, numBlockIds);
int numRecvs = std::min(64, this->NumRanks - 1); int numRecvs = std::min(64, this->GetNumRanks() - 1);
this->RegisterTag(ParticleMessenger::MESSAGE_TAG, numRecvs, messageBuffSz); this->RegisterTag(ParticleMessenger::MESSAGE_TAG, numRecvs, messageBuffSz);
this->RegisterTag(ParticleMessenger::PARTICLE_TAG, numRecvs, particleBuffSz); this->RegisterTag(ParticleMessenger::PARTICLE_TAG, numRecvs, particleBuffSz);
...@@ -174,7 +174,7 @@ void ParticleMessenger::SendMsg(int dst, const std::vector<int>& msg) ...@@ -174,7 +174,7 @@ void ParticleMessenger::SendMsg(int dst, const std::vector<int>& msg)
vtkmdiy::MemoryBuffer buff; vtkmdiy::MemoryBuffer buff;
//Write data. //Write data.
vtkmdiy::save(buff, this->Rank); vtkmdiy::save(buff, this->GetRank());
vtkmdiy::save(buff, msg); vtkmdiy::save(buff, msg);
this->SendData(dst, ParticleMessenger::MESSAGE_TAG, buff); this->SendData(dst, ParticleMessenger::MESSAGE_TAG, buff);
} }
...@@ -182,8 +182,8 @@ void ParticleMessenger::SendMsg(int dst, const std::vector<int>& msg) ...@@ -182,8 +182,8 @@ void ParticleMessenger::SendMsg(int dst, const std::vector<int>& msg)
VTKM_CONT VTKM_CONT
void ParticleMessenger::SendAllMsg(const std::vector<int>& msg) void ParticleMessenger::SendAllMsg(const std::vector<int>& msg)
{ {
for (int i = 0; i < this->NumRanks; i++) for (int i = 0; i < this->GetNumRanks(); i++)
if (i != this->Rank) if (i != this->GetRank())
this->SendMsg(i, msg); this->SendMsg(i, msg);
} }
......
...@@ -114,7 +114,7 @@ VTKM_CONT ...@@ -114,7 +114,7 @@ VTKM_CONT
template <typename P, template <typename, typename> class Container, typename Allocator> template <typename P, template <typename, typename> class Container, typename Allocator>
inline void ParticleMessenger::SendParticles(int dst, const Container<P, Allocator>& c) inline void ParticleMessenger::SendParticles(int dst, const Container<P, Allocator>& c)
{ {
if (dst == this->Rank) if (dst == this->GetRank())
{ {
VTKM_LOG_S(vtkm::cont::LogLevel::Error, "Error. Sending a particle to yourself."); VTKM_LOG_S(vtkm::cont::LogLevel::Error, "Error. Sending a particle to yourself.");
return; return;
...@@ -123,7 +123,7 @@ inline void ParticleMessenger::SendParticles(int dst, const Container<P, Allocat ...@@ -123,7 +123,7 @@ inline void ParticleMessenger::SendParticles(int dst, const Container<P, Allocat
return; return;
vtkmdiy::MemoryBuffer bb; vtkmdiy::MemoryBuffer bb;
vtkmdiy::save(bb, this->Rank); vtkmdiy::save(bb, this->GetRank());
vtkmdiy::save(bb, c); vtkmdiy::save(bb, c);
this->SendData(dst, ParticleMessenger::PARTICLE_TAG, bb); this->SendData(dst, ParticleMessenger::PARTICLE_TAG, bb);
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment