Commit ded55b5e authored by Yin, Junqi

minor fixes

parent a325efbc
......@@ -9,6 +9,7 @@ from pcode.utils.logging import (
display_training_stat,
display_test_stat,
dispaly_best_test_stat,
print_grad_norm
)
from pcode.utils.stat_tracker import RuntimeTracker
from pcode.utils.timer import Timer
......@@ -91,6 +92,8 @@ def train_and_validate(
with timer("backward_pass", epoch=scheduler.epoch_):
loss.backward()
print(conf.graph.rank, "finish backward", idx)
if conf.print_grad:
print_grad_norm(conf, model, scheduler)
with timer("sync_complete", epoch=scheduler.epoch_):
# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
......@@ -179,15 +182,22 @@ def train_and_validate(
print(timer.summary())
# shuffle graph.
if conf.shuffle_graph_per_epoch:
if (
conf.shuffle_graph
and scheduler.local_index % conf.shuffle_graph_freq == 0
):
print("\nReshuffle the graph.")
with timer("reshuffle_graph", epoch=scheduler.epoch_):
np.random.seed(int(scheduler.epoch_))
shuffle_graph(conf.graph)
print_neighbors(conf)
# hybrid mode
if conf.hybrid and not conf.is_centralized:
# hybrid mode
if (
conf.hybrid
and not conf.is_centralized
and scheduler.local_index % conf.hybrid_freq == 0
):
print("\nHybrid mode on.")
with timer("hybrid_sync", epoch=scheduler.epoch_):
optimizer.world_aggregator.agg_model(model, op="avg")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment