Commit ebb1019a authored by dpugmire's avatar dpugmire
Browse files

Add CheckRequests method.

parent af14f583
......@@ -119,44 +119,64 @@ void Messenger::PostRecv(int tag, std::size_t sz, int src)
}
//Reclaim storage for any outstanding sends that MPI has finished with.
void Messenger::CheckPendingSendRequests()
{
  //Poll (non-blocking) every send buffer, with no tag filter.
  std::vector<RequestTagPair> completed;
  this->CheckRequests(this->SendBuffers, {}, false, completed);

  //Free the buffer behind each completed send and drop its map entry.
  for (const auto& key : completed)
  {
    auto it = this->SendBuffers.find(key);
    if (it == this->SendBuffers.end())
      continue;
    delete[] it->second;
    this->SendBuffers.erase(it);
  }
}
void Messenger::CheckRequests(const std::map<RequestTagPair, char*>& buffers,
const std::set<int>& tagsToCheck,
bool BlockAndWait,
std::vector<RequestTagPair>& reqTags)
{
std::vector<MPI_Request> req, copy;
std::vector<int> tags;
for (auto&& it : this->SendBuffers)
reqTags.resize(0);
//Check the buffers for the specified tags.
for (auto&& it : buffers)
{
if (tagsToCheck.empty() || tagsToCheck.find(it.first.second) != tagsToCheck.end())
{
req.push_back(it.first.first);
copy.push_back(it.first.first);
tags.push_back(it.first.second);
}
}
//Nothing..
if (req.empty())
return;
//See if any sends are done.
int num = 0, *indices = new int[req.size()];
MPI_Status* status = new MPI_Status[req.size()];
int err = MPI_Testsome(req.size(), &req[0], &num, indices, status);
//Check the outstanding requests.
std::vector<MPI_Status> status(req.size());
std::vector<int> indices(req.size());
int num = 0;
int err;
if (BlockAndWait)
err = MPI_Waitsome(req.size(), req.data(), &num, indices.data(), status.data());
else
err = MPI_Testsome(req.size(), req.data(), &num, indices.data(), status.data());
if (err != MPI_SUCCESS)
throw vtkm::cont::ErrorFilterExecution(
"Error iwth MPI_Testsome in Messenger::CheckPendingSendRequests");
throw vtkm::cont::ErrorFilterExecution("Error with MPI_Testsome in Messenger::RecvData");
for (int i = 0; i < num; i++)
{
MPI_Request r = copy[indices[i]];
int tag = tags[indices[i]];
RequestTagPair k(r, tag);
auto entry = this->SendBuffers.find(k);
if (entry != this->SendBuffers.end())
{
delete[] entry->second;
this->SendBuffers.erase(entry);
}
}
delete[] indices;
delete[] status;
//Add the req/tag to the return vector.
for (int i = 0; i < num; i++)
reqTags.push_back(RequestTagPair(copy[indices[i]], tags[indices[i]]));
}
bool Messenger::PacketCompare(const char* a, const char* b)
......@@ -259,56 +279,36 @@ bool Messenger::RecvData(int tag, std::vector<vtkmdiy::MemoryBuffer>& buffers, b
return false;
}
bool Messenger::RecvData(std::set<int>& tags,
bool Messenger::RecvData(const std::set<int>& tags,
std::vector<std::pair<int, vtkmdiy::MemoryBuffer>>& buffers,
bool blockAndWait)
{
buffers.resize(0);
//Find all recv of type tag.
std::vector<MPI_Request> req, copy;
std::vector<int> reqTags;
for (const auto& i : this->RecvBuffers)
{
if (tags.find(i.first.second) != tags.end())
{
req.push_back(i.first.first);
copy.push_back(i.first.first);
reqTags.push_back(i.first.second);
}
}
if (req.empty())
return false;
std::vector<RequestTagPair> reqTags;
this->CheckRequests(this->RecvBuffers, tags, blockAndWait, reqTags);
std::vector<MPI_Status> status(req.size());
std::vector<int> indices(req.size());
int num = 0;
if (blockAndWait)
MPI_Waitsome(req.size(), req.data(), &num, indices.data(), status.data());
else
MPI_Testsome(req.size(), req.data(), &num, indices.data(), status.data());
if (num == 0)
//Nothing came in.
if (reqTags.empty())
return false;
std::vector<char*> incomingBuffers(num);
for (int i = 0; i < num; i++)
std::vector<char*> incomingBuffers;
incomingBuffers.reserve(reqTags.size());
for (auto&& rt : reqTags)
{
RequestTagPair entry(copy[indices[i]], reqTags[indices[i]]);
auto it = this->RecvBuffers.find(entry);
auto it = this->RecvBuffers.find(rt);
if (it == this->RecvBuffers.end())
throw vtkm::cont::ErrorFilterExecution("receive buffer not found");
incomingBuffers[i] = it->second;
incomingBuffers.push_back(it->second);
this->RecvBuffers.erase(it);
}
this->ProcessReceivedBuffers(incomingBuffers, buffers);
for (int i = 0; i < num; i++)
PostRecv(reqTags[indices[i]]);
//Re-post receives
for (auto&& rt : reqTags)
this->PostRecv(rt.second);
return !buffers.empty();
}
......
......@@ -48,16 +48,24 @@ public:
VTKM_CONT void RegisterTag(int tag, std::size_t numRecvs, std::size_t size);
protected:
static std::size_t CalcMessageBufferSize(std::size_t msgSz);
int GetRank() const { return this->Rank; }
int GetNumRanks() const { return this->NumRanks; }
void InitializeBuffers();
void CleanupRequests(int tag = TAG_ANY);
void CheckPendingSendRequests();
void PostRecv(int tag);
void PostRecv(int tag, std::size_t sz, int src = -1);
void CleanupRequests(int tag = TAG_ANY);
void SendData(int dst, int tag, const vtkmdiy::MemoryBuffer& buff);
bool RecvData(std::set<int>& tags,
bool RecvData(const std::set<int>& tags,
std::vector<std::pair<int, vtkmdiy::MemoryBuffer>>& buffers,
bool blockAndWait = false);
private:
void PostRecv(int tag);
void PostRecv(int tag, std::size_t sz, int src = -1);
//Message headers.
typedef struct
{
......@@ -87,12 +95,16 @@ protected:
std::map<RankIdPair, std::list<char*>> RecvPackets;
std::map<RequestTagPair, char*> SendBuffers;
static constexpr int TAG_ANY = -1;
void CheckRequests(const std::map<RequestTagPair, char*>& buffer,
const std::set<int>& tags,
bool BlockAndWait,
std::vector<RequestTagPair>& reqTags);
#else
protected:
static constexpr int NumRanks = 1;
static constexpr int Rank = 0;
#endif
static std::size_t CalcMessageBufferSize(std::size_t msgSz);
};
}
}
......
......@@ -106,7 +106,7 @@ void ParticleMessenger::Exchange(
numTerminateMessages = 0;
inDataBlockIDsMap.clear();
if (this->NumRanks == 1)
if (this->GetNumRanks() == 1)
return this->SerialExchange(
outData, outBlockIDsMap, numLocalTerm, inData, inDataBlockIDsMap, blockAndWait);
......@@ -160,7 +160,7 @@ void ParticleMessenger::RegisterMessages(int msgSz, int nParticles, int numBlock
std::size_t messageBuffSz = CalcMessageBufferSize(msgSz + 1);
std::size_t particleBuffSz = CalcParticleBufferSize(nParticles, numBlockIds);
int numRecvs = std::min(64, this->NumRanks - 1);
int numRecvs = std::min(64, this->GetNumRanks() - 1);
this->RegisterTag(ParticleMessenger::MESSAGE_TAG, numRecvs, messageBuffSz);
this->RegisterTag(ParticleMessenger::PARTICLE_TAG, numRecvs, particleBuffSz);
......@@ -174,7 +174,7 @@ void ParticleMessenger::SendMsg(int dst, const std::vector<int>& msg)
vtkmdiy::MemoryBuffer buff;
//Write data.
vtkmdiy::save(buff, this->Rank);
vtkmdiy::save(buff, this->GetRank());
vtkmdiy::save(buff, msg);
this->SendData(dst, ParticleMessenger::MESSAGE_TAG, buff);
}
......@@ -182,8 +182,8 @@ void ParticleMessenger::SendMsg(int dst, const std::vector<int>& msg)
VTKM_CONT
//Broadcast 'msg' to every rank except this one (point-to-point sends).
void ParticleMessenger::SendAllMsg(const std::vector<int>& msg)
{
  for (int i = 0; i < this->GetNumRanks(); i++)
    if (i != this->GetRank())
      this->SendMsg(i, msg);
}
......
......@@ -114,7 +114,7 @@ VTKM_CONT
template <typename P, template <typename, typename> class Container, typename Allocator>
inline void ParticleMessenger::SendParticles(int dst, const Container<P, Allocator>& c)
{
if (dst == this->Rank)
if (dst == this->GetRank())
{
VTKM_LOG_S(vtkm::cont::LogLevel::Error, "Error. Sending a particle to yourself.");
return;
......@@ -123,7 +123,7 @@ inline void ParticleMessenger::SendParticles(int dst, const Container<P, Allocat
return;
vtkmdiy::MemoryBuffer bb;
vtkmdiy::save(bb, this->Rank);
vtkmdiy::save(bb, this->GetRank());
vtkmdiy::save(bb, c);
this->SendData(dst, ParticleMessenger::PARTICLE_TAG, bb);
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment