Commit 9d3ed392 authored by Yin, Junqi's avatar Yin, Junqi

fixes checkpointing

parent f132b4f2
......@@ -72,7 +72,7 @@ def train_and_validate(
display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)
# finish one epoch training and to decide if we want to val our model.
if scheduler.epoch_ % conf.eval_freq == 0:
if scheduler.epoch_ % 1 == 0:
if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
tracker_tr.stat["loss"].avg
):
......@@ -194,14 +194,16 @@ def inference(model, criterion, metrics, _input, _target, tracker=None):
def do_validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader):
"""Evaluate the model on the test dataset and save to the checkpoint."""
# wait until the whole group enters this function, and then evaluate.
print("Enter validation phase.")
performance = validate(
conf, model, optimizer, criterion, scheduler, metrics, data_loader
)
if scheduler.epoch_ % conf.eval_freq == 0:
print("Enter validation phase.")
performance = validate(
conf, model, optimizer, criterion, scheduler, metrics, data_loader
)
# remember best performance and display the val info.
scheduler.best_tracker.update(performance[0], scheduler.epoch_)
dispaly_best_test_stat(conf, scheduler)
# remember best performance and display the val info.
scheduler.best_tracker.update(performance[0], scheduler.epoch_)
dispaly_best_test_stat(conf, scheduler)
print("Finished validation.")
# save to the checkpoint.
if not conf.train_fast:
......@@ -220,8 +222,6 @@ def do_validate(conf, model, optimizer, criterion, scheduler, metrics, data_load
filename="checkpoint.pth.tar",
save_all=conf.save_all_models,
)
print("Finished validation.")
def validate(
conf,
......
......@@ -131,10 +131,10 @@ def maybe_resume_from_checkpoint(conf, model, optimizer, scheduler):
scheduler.update_from_checkpoint(checkpoint)
# reset path for log.
try:
remove_folder(conf.checkpoint_root)
except RuntimeError as e:
print(f"ignore the error={e}")
#try:
# remove_folder(conf.checkpoint_root)
#except RuntimeError as e:
# print(f"ignore the error={e}")
conf.checkpoint_root = conf.resume
conf.checkpoint_dir = join(conf.resume, str(conf.graph.rank))
# restore model.
......
......@@ -21,6 +21,7 @@ class Logger:
self.file_json = os.path.join(file_folder, "log-1.json")
self.file_txt = os.path.join(file_folder, "log.txt")
self.values = []
self.redirect_new_json()
def log_metric(self, name, values, tags, display=False):
"""
......
......@@ -114,7 +114,7 @@ class BestPerf(object):
self.cur_perf = None
self.best_perf_locs = []
self.larger_is_better = larger_is_better
self.is_best = False
# define meter
self._define_meter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment