Commit 4dc60369 authored by Yin, Junqi's avatar Yin, Junqi

add generic graph reshuffle

parent d8fe2bef
......@@ -16,7 +16,7 @@ from pcode.utils.logging import (
from pcode.utils.stat_tracker import RuntimeTracker
import pcode.utils.error_handler as error_handler
import pcode.utils.auxiliary as auxiliary
from pcode.utils.topology import print_neighbors
from pcode.utils.topology import print_neighbors, shuffle_graph
# sys.excepthook = error_handler.global_except_hook
......@@ -159,7 +159,7 @@ def train_and_validate(
print("\nReshuffle the graph.")
with timer("reshuffle_graph", epoch=scheduler.epoch_):
np.random.seed(int(scheduler.epoch_))
conf.graph.shuffle_graph()
shuffle_graph(conf.graph)
print_neighbors(conf)
# hybrid mode
......
......@@ -15,7 +15,7 @@ from pcode.utils.timer import Timer
from pcode.utils.auxiliary import get_model_difference
import pcode.utils.error_handler as error_handler
from pcode.create_dataset import load_data_batch, define_nlp_dataset
from pcode.utils.topology import print_neighbors
from pcode.utils.topology import print_neighbors, shuffle_graph
# sys.excepthook = error_handler.global_except_hook
......@@ -183,7 +183,7 @@ def train_and_validate(
print("\nReshuffle the graph.")
with timer("reshuffle_graph", epoch=scheduler.epoch_):
np.random.seed(int(scheduler.epoch_))
conf.graph.shuffle_graph()
shuffle_graph(conf.graph)
print_neighbors(conf)
# hybrid mode
......
......@@ -23,7 +23,17 @@ def print_neighbors(conf, save=True):
conf.logger.log(f"node: {platform.node()}")
conf.logger.log(f"neighbors: {neighbor_ranks}")
conf.logger.log(f"matrix: {conf.graph.matrix}")
def shuffle_graph(graph):
size = graph.n_nodes
mixing_matrix = graph.matrix
permutation = np.random.permutation(size)
matrix = np.zeros((size, size))
for i in range(size):
for j in range(size):
matrix[permutation[i]][permutation[j]] = mixing_matrix[i][j]
graph.set_matrix(matrix)
return permutation
class UndirectedGraph(ABC):
@property
......@@ -314,21 +324,9 @@ class RingGraph(PhysicalLayout):
vals = row.data
return {int(c): v for c, v in zip(cols, vals)}
def shuffle_graph(self):
size = self._n_mpi_process
permutation = np.random.permutation(size)
m = np.zeros((size, size))
w = 1.0/3
for i in range(size):
cur = permutation[i]
m[cur][cur] = w
prev = permutation[(i-1)%size]
nxt = permutation[(i+1)%size]
m[cur][prev] = m[prev][cur] = w
m[cur][nxt] = m[nxt][cur] = w
def set_matrix(self, m):
self._mixing_matrix = sp.sparse.csr_matrix(m)
class TorusGraph(PhysicalLayout):
def __init__(self, n_mpi_process, n_sub_process, world, comm_device, on_cuda, rank):
super(TorusGraph, self).__init__(
......@@ -353,7 +351,7 @@ class TorusGraph(PhysicalLayout):
mixing_matrix = networkx.adjacency_matrix(graph).toarray()
for i in range(0, mixing_matrix.shape[0]):
mixing_matrix[i][i] = 1
mixing_matrix = mixing_matrix / 5
mixing_matrix = mixing_matrix / 4
return mixing_matrix
@property
......@@ -383,6 +381,9 @@ class TorusGraph(PhysicalLayout):
def get_neighborhood(self):
row = self._mixing_matrix[self._rank]
return {c: v for c, v in zip(range(len(row)), row) if v != 0}
def set_matrix(self, m):
self._mixing_matrix = m
class ExpanderGraph(PhysicalLayout):
def __init__(self, n_mpi_process, n_sub_process, world, comm_device, on_cuda, rank):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment