Commit b769826d authored by Yin, Junqi

add some tuning

parent ed883d9f
@@ -7,6 +7,9 @@ class Scheduler(object):
         # init
         self.conf = conf
         self.local_index = 0 if "local_index" not in conf else conf.local_index
+        self.epoch_ = (
+            self.local_index / self.conf.num_batches_train_per_device_per_epoch
+        )
         self.init_learning_rate()
         self.init_lr_scheduler()
@@ -14,7 +17,9 @@ class Scheduler(object):
         self.conf.local_index = checkpoint["local_index"]
         self.local_index = checkpoint["local_index"]
         self.conf.best_perf = checkpoint["best_perf"]
+        self.epoch_ = (
+            self.local_index / self.conf.num_batches_train_per_device_per_epoch
+        )

     def set_best_tracker(self, best_tracker):
         self.best_tracker = best_tracker
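Both hunks derive a fractional epoch from the batch counter, so the lr scheduler below can be queried mid-epoch and immediately after resuming from a checkpoint. A quick illustration with made-up numbers (not taken from the repo):

    local_index = 1250                               # optimizer steps taken so far
    num_batches_train_per_device_per_epoch = 500     # batches per device per epoch
    epoch_ = local_index / num_batches_train_per_device_per_epoch   # -> 2.5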
@@ -99,7 +104,12 @@ class Scheduler(object):
         self.lr_scheduler = AdaptiveLRScheduler(self.conf).get_lr_scheduler()

     def get_lr(self, **kargs):
-        return self.lr_scheduler(self.epoch_, **kargs)
+        # return self.lr_scheduler(self.epoch_, **kargs)
+        lr = self.lr_scheduler(self.epoch_, **kargs)
+        # scale lr by current/initial ring-lattice degree (degree drops by 2 every 2 epochs, min 2)
+        init_ringlatt_degree = max(self.conf.n_mpi_process // 6 * 2, 2)
+        current_ringlatt_degree = max(self.conf.n_mpi_process // 6 * 2 - 2 * (self.epoch_ // 2), 2)
+        scale = 1.0 * current_ringlatt_degree / init_ringlatt_degree
+        return lr * scale

     def step(self, optimizer, **kargs):
         self.update_training_progress()
...
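For readability, the new get_lr logic amounts to a topology-dependent scale applied on top of the existing scheduler. A minimal standalone sketch (ringlatt_lr_scale is a name I made up, not part of the repo), assuming n_mpi_process ranks and the fractional epoch_ defined above:

    def ringlatt_lr_scale(n_mpi_process, epoch_):
        """Ratio of the current to the initial ring-lattice degree.

        The degree starts at max(n_mpi_process // 6 * 2, 2) and loses 2 every
        2 epochs, but never goes below 2.
        """
        init_degree = max(n_mpi_process // 6 * 2, 2)
        current_degree = max(n_mpi_process // 6 * 2 - 2 * (epoch_ // 2), 2)
        return 1.0 * current_degree / init_degree

    # e.g. 36 ranks: scale 1.0 at epoch 0 (degree 12), ~0.17 from epoch 10 on (degree 2)

So get_lr returns the scheduler's lr early in training and a progressively smaller fraction of it as the presumed ring-lattice topology gets sparser.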
@@ -38,6 +38,7 @@ def train_and_validate(
     # break until finish expected full epoch training.
     print("=>>>> enter the training.\n")
+    conf.logger.log(f"start epoch: {scheduler.epoch_} lr: {scheduler.get_lr()}")
     while True:
         dist.barrier()
...
@@ -12,7 +12,7 @@ python -u main.py \
     --num_epochs 90 \
     --avg_model True \
     --reshuffle_per_epoch True \
-    --batch_size 32 \
+    --batch_size 16 \
     --base_batch_size 256 \
     --lr 0.1 \
     --lr_scaleup True \
@@ -21,7 +21,7 @@ python -u main.py \
     --lr_warmup True \
     --lr_warmup_epochs 5 \
     --lr_schedule_scheme custom_multistep \
-    --lr_change_epochs 30,60,80 \
+    --lr_change_epochs 30,40,60 \
     --optimizer sgd \
     --graph_topology TODO_TOPOLOGY \
     --evaluate_consensus False \
...
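Taken together with the scheduler changes, the new flags imply roughly the following per-device schedule; the linear scale-up rule and the 0.1 decay per milestone are conventional assumptions, not read from custom_multistep itself:

    # sketch only; decay factor, scale-up and warmup forms are assumed, not confirmed
    def multistep_lr(epoch, n_devices, lr=0.1, batch_size=16, base_batch_size=256,
                     warmup_epochs=5, milestones=(30, 40, 60), decay=0.1):
        target_lr = lr * (batch_size * n_devices) / base_batch_size      # --lr_scaleup True
        if epoch < warmup_epochs:                                        # --lr_warmup_epochs 5
            return target_lr * (epoch + 1) / warmup_epochs               # linear warmup (assumed)
        return target_lr * decay ** sum(epoch >= m for m in milestones)  # milestones 30,40,60

Under this rule, halving --batch_size from 32 to 16 halves the scaled-up target lr, and moving the last milestones from 60,80 to 40,60 decays the lr earlier in the 90-epoch run.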