Commit d8fe2bef authored by Yin, Junqi

add more timers

parent 91e11344
@@ -48,7 +48,7 @@ def define_nlp_dataset(conf, force_shuffle, repeat=True):
     # Partition training data.
     train_loader, _ = torchtext.data.BPTTIterator.splits(
         (train, valid),
-        batch_size=conf.batch_size * conf.graph.n_nodes,
+        batch_size=conf.batch_size * conf.graph.n_nodes if repeat else conf.batch_size,
         bptt_len=conf.rnn_bptt_len,
         device="cuda:{}".format(conf.graph.device[0]) if conf.graph.on_cuda else None,
         repeat=repeat,
@@ -157,14 +157,16 @@ def train_and_validate(
         # shuffle graph.
         if conf.shuffle_graph_per_epoch:
             print("\nReshuffle the graph.")
-            np.random.seed(int(scheduler.epoch_))
-            conf.graph.shuffle_graph()
+            with timer("reshuffle_graph", epoch=scheduler.epoch_):
+                np.random.seed(int(scheduler.epoch_))
+                conf.graph.shuffle_graph()
             print_neighbors(conf)
         # hybrid mode
         if conf.hybrid and not conf.is_centralized:
-            print("\nHyrbid mode on.")
-            optimizer.world_aggregator.agg_model(model, op="avg")
+            print("\nHybrid mode on.")
+            with timer("hybrid_sync", epoch=scheduler.epoch_):
+                optimizer.world_aggregator.agg_model(model, op="avg")


 def inference(model, criterion, metrics, _input, _target, tracker=None):
     """Inference on the given model and get loss and accuracy."""
@@ -15,6 +15,7 @@ from pcode.utils.timer import Timer
 from pcode.utils.auxiliary import get_model_difference
 import pcode.utils.error_handler as error_handler
 from pcode.create_dataset import load_data_batch, define_nlp_dataset
+from pcode.utils.topology import print_neighbors

 # sys.excepthook = error_handler.global_except_hook
@@ -122,7 +123,7 @@ def train_and_validate(
         # evaluate (and only inference) on the whole training loader.
         if (
             conf.evaluate_consensus or scheduler.is_stop()
-        ) and not conf.train_fast:
+        ) and not conf.train_fast and not conf.ddp:
             # prepare the dataloader for the consensus evaluation.
             _data_loader = {
                 "val_loader": define_nlp_dataset(
@@ -169,14 +170,27 @@ def train_and_validate(
             conf.logger.save_json()
             return

         # display tracking time.
         if (
             conf.graph.rank == 0
             and conf.display_tracked_time
             and scheduler.local_index % conf.summary_freq == 0
         ):
             print(timer.summary())
+        # shuffle graph.
+        if conf.shuffle_graph_per_epoch:
+            print("\nReshuffle the graph.")
+            with timer("reshuffle_graph", epoch=scheduler.epoch_):
+                np.random.seed(int(scheduler.epoch_))
+                conf.graph.shuffle_graph()
+            print_neighbors(conf)
+
+        # hybrid mode
+        if conf.hybrid and not conf.is_centralized:
+            print("\nHybrid mode on.")
+            with timer("hybrid_sync", epoch=scheduler.epoch_):
+                optimizer.world_aggregator.agg_model(model, op="avg")


 def inference(conf, model, criterion, metrics, _input, _target, _hidden, tracker=None):
     """Inference on the given model and get loss and accuracy."""
@@ -230,10 +244,7 @@ def validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader,
         _model.eval()

         # define hidden state for RNN.
-        if "train" in label:
-            batch_size = conf.batch_size*conf.graph.n_nodes
-        else:
-            batch_size = conf.batch_size
+        batch_size = conf.batch_size
         _hidden = (
             model.module.init_hidden(batch_size)
             if "DataParallel" in model.__class__.__name__
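Note on the timer call pattern: the blocks added in this commit wrap the graph reshuffle and the hybrid model averaging in `with timer("reshuffle_graph", epoch=...)` / `with timer("hybrid_sync", epoch=...)` and report accumulated timings through `timer.summary()`. The sketch below only illustrates that context-manager pattern under assumed semantics (accumulate wall-clock time per label); it is not the actual `pcode.utils.timer.Timer`, whose constructor, arguments, and summary format may differ.

```python
import time
from collections import defaultdict


class SketchTimer:
    """Minimal context-manager timer mirroring the call pattern in this
    commit. Hypothetical sketch; not the real pcode.utils.timer.Timer."""

    def __init__(self):
        self.totals = defaultdict(float)   # label -> accumulated seconds
        self.counts = defaultdict(int)     # label -> number of timed blocks

    def __call__(self, label, epoch=None):
        # `epoch` is accepted to match the call sites (epoch=scheduler.epoch_)
        # but is not used in this sketch.
        return _TimedBlock(self, label)

    def summary(self):
        return "\n".join(
            "{}: total={:.4f}s, calls={}".format(
                label, self.totals[label], self.counts[label]
            )
            for label in sorted(self.totals)
        )


class _TimedBlock:
    def __init__(self, timer, label):
        self.timer, self.label = timer, label

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.timer.totals[self.label] += time.perf_counter() - self.start
        self.timer.counts[self.label] += 1
        return False  # do not suppress exceptions raised inside the block


if __name__ == "__main__":
    timer = SketchTimer()
    with timer("reshuffle_graph", epoch=0):
        time.sleep(0.01)
    with timer("hybrid_sync", epoch=0):
        time.sleep(0.02)
    print(timer.summary())
```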