Commit dc292b93 authored by Yin, Junqi

add vgg for cifar100

parent 7eb85c6d
......@@ -28,16 +28,8 @@ class VGG(nn.Module):
# init models.
self.features = self._make_layers()
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(512, 512),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(512, 512),
nn.ReLU(True),
nn.Linear(512, self.num_classes),
)
self.classifier = nn.Linear(512, self.num_classes)
# weight initialization.
self._weight_initialization()
......@@ -66,11 +58,9 @@ class VGG(nn.Module):
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if self.use_bn:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
in_channels = v
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
return nn.Sequential(*layers)
def forward(self, x):
......@@ -81,16 +71,15 @@ class VGG(nn.Module):
# Build a VGG model whose depth is chosen by a substring match on conf.arch:
# '11' -> nn_arch 'A', '13' -> 'B', '16' -> 'D', '19' -> 'E'. The checks run
# in that order and the first match wins; any other value raises
# NotImplementedError. conf.data is forwarded as the dataset argument.
# NOTE(review): every branch below lists two return statements, one passing
# use_bn and one without it; only the first of each pair can ever execute.
# This looks like the removed and added lines of a diff shown together --
# confirm which call signature VGG actually accepts, then drop the other
# return line (and the use_bn assignment if it ends up unused).
def vgg(conf):
use_bn = 'bn' in conf.arch
dataset = conf.data
if '11' in conf.arch:
return VGG(nn_arch='A', dataset=dataset, use_bn=use_bn)
return VGG(nn_arch='A', dataset=dataset)
elif '13' in conf.arch:
return VGG(nn_arch='B', dataset=dataset, use_bn=use_bn)
return VGG(nn_arch='B', dataset=dataset)
elif '16' in conf.arch:
return VGG(nn_arch='D', dataset=dataset, use_bn=use_bn)
return VGG(nn_arch='D', dataset=dataset)
elif '19' in conf.arch:
return VGG(nn_arch='E', dataset=dataset, use_bn=use_bn)
return VGG(nn_arch='E', dataset=dataset)
else:
raise NotImplementedError
#!/bin/bash
# Launch main.py with the vgg19 architecture on cifar100 (300 epochs,
# batch size 64, SGD with Nesterov momentum, lr 0.1 with 5 warmup epochs and
# step changes at epochs 60/150/225).
#
# `python -u` keeps stdout/stderr unbuffered so log lines appear immediately.
# The working directory is quoted so that a path containing spaces or glob
# characters is passed to --work_dir as a single argument.
#
# NOTE(review): the TODO_* values look like template placeholders that a
# wrapper substitutes before this script is run -- confirm against the
# launcher; the command will not work with them left in place.
python -u main.py \
    --work_dir "$(pwd)" \
    --remote_exec False \
    --data cifar100 \
    --data_dir ./data/ \
    --use_lmdb_data False \
    --partition_data random \
    --pin_memory True \
    --arch vgg19 \
    --train_fast False \
    --stop_criteria epoch \
    --num_epochs 300 \
    --num_iterations 32000 \
    --avg_model True \
    --reshuffle_per_epoch True \
    --batch_size 64 \
    --base_batch_size 64 \
    --lr 0.1 \
    --lr_scaleup True \
    --lr_scaleup_type linear \
    --lr_scaleup_factor graph \
    --lr_warmup True \
    --lr_warmup_epochs 5 \
    --lr_schedule_scheme custom_multistep \
    --lr_change_epochs 60,150,225 \
    --optimizer sgd \
    --graph_topology TODO_TOPOLOGY \
    --evaluate_consensus False \
    --momentum_factor 0.9 \
    --use_nesterov True \
    --weight_decay 0.0005 \
    --drop_rate 0.0 \
    --manual_seed 6 \
    --evaluate False \
    --eval_freq 1 \
    --summary_freq 100 \
    --timestamp TODO_TIMESTAMP \
    --track_time True \
    --track_detailed_time False \
    --display_tracked_time True \
    --evaluate_avg False \
    --checkpoint ./data/checkpoint \
    --save_all_models False \
    --experiment test \
    --backend mpi \
    --use_ipc False \
    --num_workers 0 \
    --n_mpi_process TODO_NRANK \
    --n_sub_process TODO_NSUB \
    --world TODO_GPURANKS \
    --on_cuda True \
    --comm_device cuda \
    --ddp TODO_DDP \
    --shuffle_graph TODO_SHUFFLE_GRAPH \
    --shuffle_graph_freq TODO_FREQ_SHUFFLE \
    --hybrid TODO_HYBRID \
    --hybrid_freq TODO_FREQ_HYBRID \
    --print_grad TODO_PRINT_GRAD \
    --resume TODO_RESUME_DIR
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment