Commit 940db55d authored by Yin, Junqi

make it async

parent 871eca12
......@@ -241,6 +241,8 @@ class DecentralizedAggregation(Aggregation):
]
self.group_size = comm_group_size
self.world_size = graph.n_nodes
self.neighbor_size = len(self.neighbor_ranks)
self.group_rank = self.rank % self.group_size
# TODO: needs to be mapped to the graph
self.group_number = self.rank // self.group_size
self.group = dist.new_group(list(range(self.group_number*self.group_size,(self.group_number+1)*self.group_size)))
......@@ -255,26 +257,41 @@ class DecentralizedAggregation(Aggregation):
"""
# Create some tensors to host the values from group
data_buffer = [torch.empty_like(data)]*self.group_size
dist.all_gather(data_buffer, data, group=self.group, async_op=False)
reqs = []
reqs.append(dist.all_gather(data_buffer, data, group=self.group, async_op=True))
node_prev = self.rank - 1
node_next = self.rank + 1
if self.rank % self.group_size == 0:
is_head = self.group_rank == 0
is_tail = self.group_rank == self.group_size - 1
if is_head:
if node_prev < 0:
node_prev = node_prev + self.world_size
dist.send(tensor=data_buffer[0], dst=node_prev)
dist.recv(tensor=data_buffer[self.group_size-1], src=node_prev)
elif self.rank % self.group_size == self.group_size - 1:
if node_next >= self.world_size :
local_buffer = [torch.empty_like(data)]*self.neighbor_size
local_buffer[0] = data
reqs.append(dist.isend(tensor=local_buffer[0], dst=node_prev))
reqs.append(dist.irecv(tensor=local_buffer[1], src=node_prev))
elif is_tail:
if node_next >= self.world_size:
node_next = node_next - self.world_size
dist.recv(tensor=data_buffer[0], src=node_next)
dist.send(tensor=data_buffer[self.group_size-1], dst=node_next)
output = sum(
[ data_buffer[rank%self.group_size]
for rank in (node_prev, self.rank, node_next)
]
)/(len(self.neighbor_ranks)+1)
local_buffer = [torch.empty_like(data)]*self.neighbor_size
local_buffer[0] = data
reqs.append(dist.irecv(tensor=local_buffer[1], src=node_next))
reqs.append(dist.isend(tensor=local_buffer[0], dst=node_next))
self.complete_wait(reqs)
if is_head:
output = (
local_buffer[1]+data_buffer[self.group_rank]+data_buffer[self.group_rank+1]
)/(len(self.neighbor_ranks)+1)
elif is_tail:
output = (
data_buffer[self.group_rank-1]+data_buffer[self.group_rank]+local_buffer[1]
)/(len(self.neighbor_ranks)+1)
else:
output = (
data_buffer[self.group_rank-1]+data_buffer[self.group_rank]+data_buffer[self.group_rank+1]
)/(len(self.neighbor_ranks)+1)
return output
# wait until finish.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment