Commit 4900cdc8 authored by Yin, Junqi's avatar Yin, Junqi

add async

parent d4912969
......@@ -261,16 +261,18 @@ class DecentralizedAggregation(Aggregation):
node_next = self.rank + 1
if node_prev < 0: node_prev = node_prev + self.world_size
if node_next >= self.world_size: node_next = node_next - self.world_size
req = dist.irecv(tensor=local_data[node_prev], src=node_prev)
dist.send(tensor=local_data[self.rank], dst=node_next)
req.wait()
req = dist.irecv(tensor=local_data[node_next], src=node_next)
dist.send(tensor=local_data[self.rank], dst=node_prev)
req.wait()
reqs = []
reqs.append(dist.irecv(tensor=local_data[node_prev], src=node_prev))
reqs.append(dist.isend(tensor=local_data[self.rank], dst=node_next))
#req.wait()
reqs.append(dist.irecv(tensor=local_data[node_next], src=node_next))
reqs.append(dist.isend(tensor=local_data[self.rank], dst=node_prev))
#req.wait()
# wait until finish.
if force_wait:
#self.complete_wait(reqs)
self.complete_wait(reqs)
# Aggregate local_data
if op == "avg":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment