Commit fa6f7b88 authored by Yin, Junqi's avatar Yin, Junqi

add more knobs

parent fb7ad3de
#!/bin/bash
# LSF batch job header for the choco-sgd run.
# NOTE(review): the original header carried conflicting duplicate directives
# (-W 12:00 vs 8:00, -nnodes 8 vs 16) left over from a diff; bsub honors the
# last occurrence, so only the effective values are kept here.
#BSUB -P stf011
#BSUB -W 8:00
#BSUB -nnodes 16
#BSUB -alloc_flags "nvme gpumps"
#BSUB -J chocosgd
#BSUB -o logs/chocosgd.o%J
......@@ -13,17 +13,19 @@ NNODES=$(cat ${LSB_DJOB_HOSTFILE} | sort | uniq | grep -v login | grep -v batch
# Experiment configuration knobs. The original section contained dead
# duplicate assignments (old/new diff pairs); since the last assignment wins,
# only the effective values are kept.
source choco_env.sh

# true: one MPI rank per node; false: one rank per GPU.
single_rank_per_node=false

# Model choices: resnet20, lstm, resnet50, densenet100
EXPERIMENT=resnet50
# Log per-parameter gradient norms every step (verbose; debugging only).
PRINT_GRAD=False
# centralized: complete; decentralized: ring, torus, expander, margulis_expander, social
TOPOLOGY=ring
# ddp only supports complete topology
DDP=False
# Reshuffle the communication graph every FREQ_SHUFFLE steps.
SHUFFLE_GRAPH=False
FREQ_SHUFFLE=10
# Hybrid mode: perform a global model average every FREQ_HYBRID steps.
HYBRID=True
FREQ_HYBRID=5
if [ "$EXPERIMENT" == "resnet50" ]; then
......@@ -42,12 +44,16 @@ else
sed -i "s/TODO_NSUB/1/" run.sh
fi
NOW=$(date '+%Y%m%d%H%M%S')
# Descriptive run tag: node count, topology and knob settings plus the job id.
TIMESTAMP=${NNODES}-nodes_${TOPOLOGY}_DDP-${DDP}_SHUFFLE_GRAPH-${SHUFFLE_GRAPH}_FREQ-${FREQ_SHUFFLE}_HYBRID-${HYBRID}_FREQ-${FREQ_HYBRID}_${LSB_JOBID}

# Fill the TODO_* placeholders in run.sh with the configured knobs.
# NOTE(review): the original had two substitutions for TODO_TIMESTAMP
# ($NOW first, then $TIMESTAMP); the first sed consumed the placeholder so
# the descriptive tag was never applied. Keep only the $TIMESTAMP one.
sed -i "s/TODO_GPURANKS/$WORLD/" run.sh
sed -i "s/TODO_TOPOLOGY/$TOPOLOGY/" run.sh
sed -i "s/TODO_TIMESTAMP/$TIMESTAMP/" run.sh
sed -i "s/TODO_DDP/$DDP/" run.sh
sed -i "s/TODO_SHUFFLE_GRAPH/$SHUFFLE_GRAPH/" run.sh
sed -i "s/TODO_HYBRID/$HYBRID/" run.sh
sed -i "s/TODO_FREQ_HYBRID/$FREQ_HYBRID/" run.sh
sed -i "s/TODO_FREQ_SHUFFLE/$FREQ_SHUFFLE/" run.sh
sed -i "s/TODO_PRINT_GRAD/$PRINT_GRAD/" run.sh
if [ "$single_rank_per_node" = true ]; then
jsrun -n${NNODES} -a1 -g6 -c42 -r1 --smpiargs "-gpu" --bind=rs --launch_distribution=packed ./run.sh
......
......@@ -197,8 +197,8 @@ def init_config(conf):
# configure cuda related.
if conf.graph.on_cuda:
assert torch.cuda.is_available()
torch.manual_seed(conf.manual_seed)
torch.cuda.manual_seed(conf.manual_seed)
torch.manual_seed(conf.manual_seed*cur_rank)
torch.cuda.manual_seed(conf.manual_seed*cur_rank)
torch.cuda.set_device(conf.graph.device[0])
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
......
......@@ -229,8 +229,15 @@ def get_args():
parser.add_argument("--local_rank", default=None, type=str)
parser.add_argument("--clean_python", default=False, type=str2bool)
parser.add_argument("--ddp", type=str2bool, default=False)
parser.add_argument("--shuffle_graph_per_epoch", type=str2bool, default=False)
parser.add_argument("--shuffle_graph", type=str2bool, default=False)
parser.add_argument("--hybrid", type=str2bool, default=False)
parser.add_argument(
"--hybrid_freq", default=10, type=int, help="# of steps per sync."
)
parser.add_argument(
"--shuffle_graph_freq", default=10, type=int, help="# of steps per shuffle."
)
parser.add_argument("--print_grad", type=str2bool, default=False)
# parse conf.
conf = parser.parse_args()
......
......@@ -57,6 +57,8 @@ class Scheduler(object):
else:
if self.conf.lr_scaleup_factor == "graph":
_scale = self.conf.graph.scaling
if self.conf.hybrid:
_scale = 3 + 1.0*(self.conf.graph.n_nodes - 3)/(10.*self.conf.hybrid_freq/self.conf.num_batches_train_per_device_per_epoch + 1 - 10./self.conf.num_batches_train_per_device_per_epoch)
elif self.conf.lr_scaleup_factor == "world":
_scale = self.conf.graph.n_nodes
else:
......
......@@ -12,6 +12,7 @@ from pcode.utils.logging import (
display_training_stat,
display_test_stat,
dispaly_best_test_stat,
print_grad_norm
)
from pcode.utils.stat_tracker import RuntimeTracker
import pcode.utils.error_handler as error_handler
......@@ -55,6 +56,8 @@ def train_and_validate(
with timer("backward_pass", epoch=scheduler.epoch_):
loss.backward()
if conf.print_grad:
print_grad_norm(conf, model, scheduler)
with timer("sync_complete", epoch=scheduler.epoch_):
if not conf.ddp:
......@@ -139,6 +142,27 @@ def train_and_validate(
error_handler.abort()
return
# shuffle graph.
if (
conf.shuffle_graph
and scheduler.local_index % conf.shuffle_graph_freq == 0
):
print("\nReshuffle the graph.")
with timer("reshuffle_graph", epoch=scheduler.epoch_):
np.random.seed(int(scheduler.epoch_))
shuffle_graph(conf.graph)
print_neighbors(conf)
# hybrid mode
if (
conf.hybrid
and not conf.is_centralized
and scheduler.local_index % conf.hybrid_freq == 0
):
print("\nHybrid mode on.")
with timer("hybrid_sync", epoch=scheduler.epoch_):
optimizer.world_aggregator.agg_model(model, op="avg")
# display tracking time.
if (
conf.graph.rank == 0
......@@ -154,20 +178,6 @@ def train_and_validate(
gc.collect()
data_loader = define_dataset(conf)
# shuffle graph.
if conf.shuffle_graph_per_epoch:
print("\nReshuffle the graph.")
with timer("reshuffle_graph", epoch=scheduler.epoch_):
np.random.seed(int(scheduler.epoch_))
shuffle_graph(conf.graph)
print_neighbors(conf)
# hybrid mode
if conf.hybrid and not conf.is_centralized:
print("\nHybrid mode on.")
with timer("hybrid_sync", epoch=scheduler.epoch_):
optimizer.world_aggregator.agg_model(model, op="avg")
def inference(model, criterion, metrics, _input, _target, tracker=None):
"""Inference on the given model and get loss and accuracy."""
output = model(_input)
......
......@@ -153,3 +153,11 @@ def dispaly_best_test_stat(conf, scheduler):
scheduler.best_tracker.best_perf,
)
)
def print_grad_norm(conf, model, scheduler):
    """Log the L2 norm of every trainable parameter's gradient.

    Args:
        conf: experiment config; only ``conf.logger.log`` is used here.
        model: a ``torch.nn.Module`` whose parameter gradients are logged.
        scheduler: training scheduler providing ``epoch_`` and ``local_index``.
    """
    conf.logger.log(f"epoch: {scheduler.epoch_} step: {scheduler.local_index}")
    for name, param in model.named_parameters():
        # Skip frozen parameters, and parameters whose gradient has not been
        # populated yet (param.grad is None before the first backward pass) —
        # calling .norm() on None raised AttributeError in the original.
        if param.requires_grad and param.grad is not None:
            conf.logger.log(
                f"parameter: {name} grad_norm: {param.grad.data.norm().item()}"
            )
#!/bin/bash
# Launch template: CIFAR-10 / densenet100 training with choco-sgd.
# The TODO_* tokens below are placeholders that the submission script fills
# in with sed before this script is executed (one copy per jsrun launch).
# NOTE(review): comments cannot be placed inside the backslash-continued
# command below, so all documentation lives in this header.
#   - graph_topology / world / n_mpi_process: distributed layout knobs.
#   - shuffle_graph(+_freq): periodically reshuffle the communication graph.
#   - hybrid(+_freq): periodic global model averaging on top of gossip.
#   - print_grad: log per-parameter gradient norms (debugging).
python -u main.py \
    --work_dir $(pwd) \
    --remote_exec False \
    --data cifar10 \
    --data_dir ./data/ \
    --use_lmdb_data False \
    --pin_memory True \
    --arch densenet100 \
    --train_fast False \
    --stop_criteria epoch \
    --num_epochs 300 \
    --num_iterations 32000 \
    --avg_model True \
    --reshuffle_per_epoch True \
    --batch_size 128 \
    --base_batch_size 64 \
    --lr 0.1 \
    --lr_scaleup True \
    --lr_scaleup_type linear \
    --lr_scaleup_factor graph \
    --lr_warmup True \
    --lr_warmup_epochs 5 \
    --lr_decay 0.1 \
    --lr_onecycle_low 0.15 \
    --lr_onecycle_high 3 \
    --lr_onecycle_extra_low 0.0015 \
    --lr_onecycle_num_epoch 46 \
    --lr_schedule_scheme custom_one_cycle \
    --optimizer sgd \
    --graph_topology TODO_TOPOLOGY \
    --evaluate_consensus False \
    --momentum_factor 0.9 \
    --use_nesterov True \
    --weight_decay 0.0001 \
    --drop_rate 0.0 \
    --manual_seed 6 \
    --evaluate False \
    --eval_freq 1 \
    --summary_freq 100 \
    --timestamp TODO_TIMESTAMP \
    --track_time True \
    --track_detailed_time False \
    --display_tracked_time True \
    --evaluate_avg False \
    --checkpoint ./data/checkpoint \
    --save_all_models False \
    --experiment test \
    --backend mpi \
    --use_ipc False \
    --num_workers 0 \
    --n_mpi_process TODO_NRANK \
    --n_sub_process TODO_NSUB \
    --world TODO_GPURANKS \
    --on_cuda True \
    --comm_device cuda \
    --ddp TODO_DDP \
    --shuffle_graph TODO_SHUFFLE_GRAPH \
    --shuffle_graph_freq TODO_FREQ_SHUFFLE \
    --hybrid TODO_HYBRID \
    --hybrid_freq TODO_FREQ_HYBRID \
    --print_grad TODO_PRINT_GRAD
......@@ -80,5 +80,8 @@ python -u main.py \
--on_cuda True \
--comm_device cuda \
--ddp TODO_DDP \
--shuffle_graph_per_epoch TODO_SHUFFLE_GRAPH \
--hybrid TODO_HYBRID
--shuffle_graph TODO_SHUFFLE_GRAPH \
--shuffle_graph_freq TODO_FREQ_SHUFFLE \
--hybrid TODO_HYBRID \
--hybrid_freq TODO_FREQ_HYBRID \
--print_grad TODO_PRINT_GRAD
......@@ -56,5 +56,8 @@ python -u main.py \
--on_cuda True \
--comm_device cuda \
--ddp TODO_DDP \
--shuffle_graph_per_epoch TODO_SHUFFLE_GRAPH \
--hybrid TODO_HYBRID
--shuffle_graph TODO_SHUFFLE_GRAPH \
--shuffle_graph_freq TODO_FREQ_SHUFFLE \
--hybrid TODO_HYBRID \
--hybrid_freq TODO_FREQ_HYBRID \
--print_grad TODO_PRINT_GRAD
......@@ -5,7 +5,6 @@ python -u main.py \
--data imagenet \
--data_dir /mnt/bb/$USER/data \
--use_lmdb_data False \
--partition_data random \
--pin_memory True \
--arch resnet50 \
--train_fast False \
......@@ -51,5 +50,8 @@ python -u main.py \
--on_cuda True \
--comm_device cuda \
--ddp TODO_DDP \
--shuffle_graph_per_epoch TODO_SHUFFLE_GRAPH \
--hybrid TODO_HYBRID
--shuffle_graph TODO_SHUFFLE_GRAPH \
--shuffle_graph_freq TODO_FREQ_SHUFFLE \
--hybrid TODO_HYBRID \
--hybrid_freq TODO_FREQ_HYBRID \
--print_grad TODO_PRINT_GRAD
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment