Commit 993da23e authored by Yin, Junqi

add consensus evaluation for NLP tasks

parent fba88ae8
@@ -27,7 +27,7 @@ def define_dataset(conf, force_shuffle=False):
 """nlp related task."""
-def define_nlp_dataset(conf, force_shuffle):
+def define_nlp_dataset(conf, force_shuffle, repeat=True):
     print("create {} dataset for rank {}".format(conf.data, conf.graph.rank))
     # create dataset.
     TEXT, train, valid, _ = get_dataset(conf, conf.data, conf.data_dir)
@@ -51,7 +51,7 @@ def define_nlp_dataset(conf, force_shuffle):
         batch_size=conf.batch_size * conf.graph.n_nodes,
         bptt_len=conf.rnn_bptt_len,
         device="cuda:{}".format(conf.graph.device[0]) if conf.graph.on_cuda else None,
-        repeat=True,
+        repeat=repeat,
         shuffle=force_shuffle or conf.reshuffle_per_epoch,
     )
     _, val_loader = torchtext.data.BPTTIterator.splits(
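The new `repeat` flag matters because torchtext's legacy `BPTTIterator` cycles forever when `repeat=True`, which suits a long-running training loop but would hang a one-shot pass over the full training set. Below is a minimal sketch of the distinction, assuming the legacy torchtext (<= 0.8) API and a WikiText-2-style corpus standing in for this repo's actual data pipeline:

```python
# Hedged sketch: legacy torchtext API; WikiText2 stands in for conf.data.
import torchtext

TEXT = torchtext.data.Field(lower=True)
train, valid, test = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train)

train_iter = torchtext.data.BPTTIterator(
    train,
    batch_size=32,
    bptt_len=35,
    repeat=False,  # stop after one full pass; repeat=True would loop forever
    shuffle=True,
)
# This terminates only because repeat=False:
n_batches = sum(1 for _ in train_iter)
```

The second file in the commit, the distributed train/validate driver, consumes this hook below.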
 # -*- coding: utf-8 -*-
-from copy import deepcopy
+import copy
 import numpy as np
 import torch
@@ -14,7 +14,7 @@ from pcode.utils.stat_tracker import RuntimeTracker
 from pcode.utils.timer import Timer
 from pcode.utils.auxiliary import get_model_difference
 import pcode.utils.error_handler as error_handler
-from pcode.create_dataset import load_data_batch
+from pcode.create_dataset import load_data_batch, define_nlp_dataset
 # sys.excepthook = error_handler.global_except_hook
@@ -115,6 +115,51 @@ def train_and_validate(
         # refresh the logging cache at the beginning of each epoch.
         tracker_tr.reset()
+
+        # evaluate (inference only) on the whole training loader.
+        if (
+            conf.evaluate_consensus or scheduler.is_stop()
+        ) and not conf.train_fast:
+            # prepare the dataloader for the consensus evaluation.
+            _data_loader = {
+                "val_loader": define_nlp_dataset(
+                    conf, force_shuffle=True, repeat=False
+                )["train_loader"]
+            }
+
+            # evaluate on the local model.
+            conf.logger.log("eval the local model on full training data.")
+            validate(
+                conf,
+                model,
+                optimizer,
+                criterion,
+                scheduler,
+                metrics,
+                data_loader=_data_loader,
+                label="eval_local_model_on_full_training_data",
+            )
+
+            # evaluate on the averaged model.
+            conf.logger.log("eval the averaged model on full training data.")
+            copied_model = copy.deepcopy(
+                model.module
+                if "DataParallel" == model.__class__.__name__
+                else model
+            )
+            optimizer.world_aggregator.agg_model(copied_model, op="avg")
+            validate(
+                conf,
+                copied_model,
+                optimizer,
+                criterion,
+                scheduler,
+                metrics,
+                data_loader=_data_loader,
+                label="eval_averaged_model_on_full_training_data",
+            )
+
         # determine if the training is finished.
         if scheduler.is_stop():
             conf.logger.save_json()
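The averaged-model branch is the "consensus" half of the feature: each worker deep-copies its local model (unwrapping `DataParallel` first so the copy holds the bare module), then `optimizer.world_aggregator.agg_model(..., op="avg")` replaces the copy's parameters with the across-worker average. `agg_model`'s actual implementation lives elsewhere in pcode; as a hedged sketch, the step amounts to an all-reduce over parameters, assuming an initialized, fully connected `torch.distributed` process group:

```python
# Hedged sketch of the consensus step; pcode's agg_model is the real API.
import copy

import torch.distributed as dist
import torch.nn as nn


def consensus_copy(model: nn.Module) -> nn.Module:
    """Return a deep copy whose parameters are averaged across all workers."""
    copied = copy.deepcopy(model)  # leave the local model untouched
    world_size = dist.get_world_size()
    for param in copied.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= world_size
    return copied
```

Deep-copying first is what lets the commit report both numbers: the local model keeps training undisturbed while the consensus copy is evaluated on the same full-training-data loader.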
@@ -169,7 +214,8 @@ def do_validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader):
     print("Finished validation.")

-def validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader):
+def validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader,
+             label="local_model"):
     """A function for model evaluation."""

     def _evaluate(_model, label):
@@ -180,10 +226,14 @@ def validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader):
         _model.eval()

         # define hidden state for RNN.
+        if "train" in label:
+            batch_size = conf.batch_size * conf.graph.n_nodes
+        else:
+            batch_size = conf.batch_size
         _hidden = (
-            model.module.init_hidden(conf.batch_size)
+            model.module.init_hidden(batch_size)
             if "DataParallel" == model.__class__.__name__
-            else model.init_hidden(conf.batch_size)
+            else model.init_hidden(batch_size)
         )

         for batch in data_loader["val_loader"]:
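The batch-size branch exists because the consensus pass reuses the training iterator, which `define_nlp_dataset` builds with `conf.batch_size * conf.graph.n_nodes`, while the ordinary validation loader uses plain `conf.batch_size`; an RNN's initial hidden state must match the loader's batch dimension or the forward pass fails with a shape mismatch. A sketch with a hypothetical RNN language model (not the pcode model) of the shape contract `init_hidden` has to satisfy:

```python
# Hypothetical RNNLM standing in for the pcode model; only init_hidden's
# shape contract matters here.
import torch.nn as nn


class RNNLM(nn.Module):
    def __init__(self, vocab_size, emb_dim=128, hid_dim=256, n_layers=2):
        super().__init__()
        self.hid_dim, self.n_layers = hid_dim, n_layers
        self.encoder = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers)
        self.decoder = nn.Linear(hid_dim, vocab_size)

    def init_hidden(self, batch_size):
        # (h_0, c_0), each of shape (n_layers, batch_size, hid_dim): the
        # batch dimension must equal the iterator's batch size, i.e.
        # conf.batch_size * conf.graph.n_nodes on the full-training pass.
        weight = next(self.parameters())
        return (weight.new_zeros(self.n_layers, batch_size, self.hid_dim),
                weight.new_zeros(self.n_layers, batch_size, self.hid_dim))
```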
@@ -238,5 +288,5 @@ def validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader):
     # )

     # evaluate each local model on the validation dataset.
-    global_performance = _evaluate(model, label="local_model")
+    global_performance = _evaluate(model, label=label)
     return global_performance