Commit 61a92656 authored 4 years ago by Reshniak, Viktor
update optimizers.py
parent 77a6cc68
1 merge request: !12 Torch anderson
Showing 1 changed file: modules/optimizers.py (+135 −207)
...
...
@@ -5,17 +5,23 @@ from torch import Tensor
from torch import autograd
from abc import ABCMeta, abstractmethod, ABC
import math
from AccelerationModule import AccelerationModule
from collections import deque
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Abstract class that provides basic guidelines to implement an acceleration
class Optimizer(object, metaclass=ABCMeta):

import sys
sys.path.append("../utils")
import rna_acceleration as rna
import anderson_acceleration as anderson


class FixedPointIteration(object):
    def __init__(self,
                 training_dataloader: torch.utils.data.dataloader.DataLoader,
                 validation_dataloader: torch.utils.data.dataloader.DataLoader,
                 learning_rate: float,
                 weight_decay: float = 0.0,
                 verbose: bool = False):
        """
        :type training_dataloader: torch.utils.data.dataloader.DataLoader
        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
        :type learning_rate: float
        :type weight_decay: float
        """
...
...
@@ -60,11 +66,99 @@ class Optimizer(object, metaclass=ABCMeta):
        assert self.model_imported
        return self.model

    @abstractmethod
    def train(self, input_data: torch.Tensor, target: torch.Tensor, num_iterations: int, threshold: float, batch_size: int):

    def accelerate(self):
        pass

    def train(self, num_epochs, threshold, batch_size):
        assert self.model_imported
        assert self.optimizer_specified
        epoch_counter = 0
        value_loss = float('Inf')
        self.training_loss_history = []
        self.validation_loss_history = []

        while epoch_counter < num_epochs and value_loss > threshold:
            self.model.get_model().train(True)

            # Training
            for batch_idx, (data, target) in enumerate(self.training_dataloader):
                self.accelerate()
                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                self.optimizer.zero_grad()
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                loss.backward()

                if self.optimizer_str == 'lbfgs':
                    def closure():
                        if torch.is_grad_enabled():
                            self.optimizer.zero_grad()
                        output = self.model.forward(data)
                        loss = self.criterion(output, target)
                        if loss.requires_grad:
                            loss.backward()
                        return loss
                    self.optimizer.step(closure)
                else:
                    self.optimizer.step()

                self.print_verbose(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch_counter,
                        batch_idx * len(data),
                        len(self.training_dataloader.dataset),
                        100.0 * batch_idx / len(self.training_dataloader),
                        loss.item())
                )
                train_loss = loss.item()
                self.training_loss_history.append(train_loss)

            # Validation
            with torch.no_grad():
                self.model.get_model().train(False)
                val_loss = 0.0
                count_val = 0
                correct = 0
                for batch_idx, (data, target) in enumerate(self.validation_dataloader):
                    count_val = count_val + 1
                    data, target = (
                        data.to(self.model.get_device()),
                        target.to(self.model.get_device()),
                    )
                    output = self.model.forward(data)
                    loss = self.criterion(output, target)
                    val_loss = val_loss + loss
                    """
                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
                    """
                val_loss = val_loss / count_val
                self.validation_loss_history.append(val_loss)

            """
            self.print_verbose(
                '\nEpoch: ' + str(epoch_counter)
                + ' - Training Loss: ' + str(train_loss)
                + ' - Validation - Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                    val_loss,
                    correct,
                    len(self.validation_dataloader.dataset),
                    100.0 * correct / len(self.validation_dataloader.dataset),
                )
            )
            self.print_verbose("###############################")
            """

            value_loss = val_loss
            epoch_counter = epoch_counter + 1

        return self.training_loss_history, self.validation_loss_history

    def set_loss_function(self, criterion_string):
        if criterion_string.lower() == 'mse':
...
...
@@ -103,8 +197,8 @@ class Optimizer(object, metaclass=ABCMeta):
            self.optimizer_str = optimizer_string.lower()
            self.optimizer_specified = True
        elif optimizer_string.lower() == 'lbfgs':
            self.optimizer = torch.optim.LBFGS(self.model.get_model().parameters(), lr=self.lr, history_size=10, max_iter=20, line_search_fn=True, batch_mode=True)
            self.optimizer = torch.optim.LBFGS(self.model.get_model().parameters(), lr=self.lr, history_size=10, max_iter=20, line_search_fn=True, batch_mode=True)
            self.optimizer_str = optimizer_string.lower()
            self.optimizer_specified = True
        else:
...
...
@@ -119,113 +213,15 @@ class Optimizer(object, metaclass=ABCMeta):
            print(*args, **kwargs)


class FixedPointIteration(Optimizer, ABC):
    def __init__(self,
                 training_dataloader: torch.utils.data.dataloader.DataLoader,
                 validation_dataloader: torch.utils.data.dataloader.DataLoader,
                 learning_rate: float,
                 weight_decay: float = 0.0,
                 verbose: bool = False):
        """
        :type training_dataloader: torch.utils.data.dataloader.DataLoader
        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
        :param learning_rate: :type: float
        :param weight_decay: :type: float
        """
        super(FixedPointIteration, self).__init__(training_dataloader, validation_dataloader, learning_rate, weight_decay, verbose)

    def train(self, num_epochs, threshold, batch_size):
        assert self.optimizer_specified
        epoch_counter = 0
        value_loss = float('Inf')
        self.training_loss_history = []
        self.validation_loss_history = []

        while epoch_counter < num_epochs and value_loss > threshold:
            self.model.get_model().train(True)
            train_loss = 0.0

            # Training
            for batch_idx, (data, target) in enumerate(self.training_dataloader):
                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                self.optimizer.zero_grad()
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                loss.backward()

                if self.optimizer_str == 'lbfgs':
                    def closure():
                        if torch.is_grad_enabled():
                            self.optimizer.zero_grad()
                        output = self.model.forward(data)
                        loss = self.criterion(output, target)
                        if loss.requires_grad:
                            loss.backward()
                        return loss
                    self.optimizer.step(closure)
                else:
                    self.optimizer.step()

                self.print_verbose(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch_counter,
                        batch_idx * len(data),
                        len(self.training_dataloader.dataset),
                        100.0 * batch_idx / len(self.training_dataloader),
                        loss.item())
                )
                train_loss = loss.item()
                self.training_loss_history.append(train_loss)

            # Validation
            self.model.get_model().train(False)
            val_loss = 0.0
            count_val = 0
            correct = 0
            for batch_idx, (data, target) in enumerate(self.validation_dataloader):
                count_val = count_val + 1
                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                val_loss = val_loss + loss
                """
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
                """
            val_loss = val_loss / count_val
            self.validation_loss_history.append(val_loss)

            """
            self.print_verbose(
                '\nEpoch: ' + str(epoch_counter)
                + ' - Training Loss: ' + str(train_loss)
                + ' - Validation - Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                    val_loss,
                    correct,
                    len(self.validation_dataloader.dataset),
                    100.0 * correct / len(self.validation_dataloader.dataset),
                )
            )
            self.print_verbose("###############################")
            """

            value_loss = val_loss
            epoch_counter = epoch_counter + 1

        return self.training_loss_history, self.validation_loss_history


class DeterministicAcceleration(Optimizer, ABC):
class DeterministicAcceleration(FixedPointIteration):
    def __init__(self,
                 training_dataloader: torch.utils.data.dataloader.DataLoader,
                 validation_dataloader: torch.utils.data.dataloader.DataLoader,
                 acceleration_type: str = 'anderson',
                 learning_rate: float = 1e-3,
                 relaxation: float = 0.1,
                 weight_decay: float = 0.0,
                 wait_iterations: int = 1,
                 history_depth: int = 15,
                 frequency: int = 1,
                 reg_acc: float = 0.0,
                 store_each_nth: int = 1,
                 verbose: bool = False):
        """
        :type training_dataloader: torch.utils.data.dataloader.DataLoader
        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
        :param learning_rate: :type: float
        :param weight_decay: :type: float
        """
...
...
@@ -238,99 +234,31 @@ class DeterministicAcceleration(Optimizer, ABC):
        self.frequency = frequency
        self.reg_acc = reg_acc

    def train(self, num_epochs, threshold, batch_size):
        assert self.model_imported

        # Initialization of acceleration module
        self.acc_mod = AccelerationModule(self.acceleration_type, self.model.get_model(), self.history_depth, self.reg_acc, self.store_each_nth)
        self.acc_mod.store(self.model.get_model())

        assert self.optimizer_specified
        epoch_counter = 0
        value_loss = float('Inf')
        self.training_loss_history = []
        self.validation_loss_history = []

        while epoch_counter < num_epochs and value_loss > threshold:
            self.model.get_model().train(True)

            # Training
            for batch_idx, (data, target) in enumerate(self.training_dataloader):
                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                self.optimizer.zero_grad()
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                loss.backward()

                if self.optimizer_str == 'lbfgs':
                    def closure():
                        if torch.is_grad_enabled():
                            self.optimizer.zero_grad()
                        output = self.model.forward(data)
                        loss = self.criterion(output, target)
                        if loss.requires_grad:
                            loss.backward()
                        return loss
                    self.optimizer.step(closure)
                else:
                    self.optimizer.step()

                self.print_verbose(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch_counter,
                        batch_idx * len(data),
                        len(self.training_dataloader.dataset),
                        100.0 * batch_idx / len(self.training_dataloader),
                        loss.item())
                )
                train_loss = loss.item()
                self.training_loss_history.append(train_loss)

            # Acceleration
            self.acc_mod.store(self.model.get_model())
            if (epoch_counter > self.wait_iterations) and (epoch_counter % self.frequency == 0):
                self.acc_mod.accelerate(self.model.get_model(), self.relaxation)

            # Validation
            self.model.get_model().train(False)
            val_loss = 0.0
            count_val = 0
            correct = 0
            for batch_idx, (data, target) in enumerate(self.validation_dataloader):
                count_val = count_val + 1
                data, target = (
                    data.to(self.model.get_device()),
                    target.to(self.model.get_device()),
                )
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                val_loss = val_loss + loss
                """
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
                """
            val_loss = val_loss / count_val
            self.validation_loss_history.append(val_loss)

            """
            self.print_verbose(
                '\nEpoch: ' + str(epoch_counter)
                + ' - Training Loss: ' + str(train_loss)
                + ' - Validation - Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                    val_loss,
                    correct,
                    len(self.validation_dataloader.dataset),
                    100.0 * correct / len(self.validation_dataloader.dataset),
                )
            )
            self.print_verbose("###############################")
            """

            value_loss = val_loss
            epoch_counter = epoch_counter + 1

        return self.training_loss_history, self.validation_loss_history

        self.store_counter = 0
        self.call_counter = 0
        self.x_hist = deque([], maxlen=history_depth)

    def accelerate(self):
        # update history of model parameters
        self.store_counter += 1
        if self.store_counter >= self.store_each_nth:
            self.store_counter = 0  # reset and continue
            self.x_hist.append(parameters_to_vector(self.model.get_model().parameters()).detach())

        # perform acceleration
        self.call_counter += 1
        if len(self.x_hist) >= 3 and (self.call_counter > self.wait_iterations) and (self.call_counter % self.frequency == 0):
            # make matrix of updates from the history list
            X = torch.stack(list(self.x_hist), dim=1)

            # compute acceleration
            if self.acceleration_type == 'anderson':
                x_acc = anderson.anderson(X, self.relaxation)
            elif self.acceleration_type == 'rna':
                x_acc, c = rna.rna(X, self.reg_acc)

            # load acceleration back into model and update history
            vector_to_parameters(x_acc, self.model.get_model().parameters())
            self.x_hist.pop()
            self.x_hist.append(x_acc)
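
Note on the acceleration step: accelerate() delegates to anderson.anderson(X, relaxation) and rna.rna(X, reg_acc), whose implementations are not part of this diff. Purely as a sketch of the kind of Anderson-type mixing such a helper might perform over the column-stacked parameter history X (the function name anderson_extrapolate and its details are assumptions for illustration, not the repository's anderson_acceleration module):

import torch

def anderson_extrapolate(X: torch.Tensor, relaxation: float = 1.0, reg: float = 1e-10) -> torch.Tensor:
    # Sketch only; not the repository's implementation.
    # X: (n, k) history of parameter vectors x_0, ..., x_{k-1} stored as columns (oldest first).
    # Returns one extrapolated parameter vector of shape (n,).
    n, k = X.shape
    if k < 3:
        return X[:, -1]
    R = X[:, 1:] - X[:, :-1]  # residuals r_i = x_{i+1} - x_i, shape (n, k-1)
    m = R.shape[1]
    # Solve min_a ||R a||^2 subject to sum(a) = 1 via regularized normal equations.
    RtR = R.t() @ R + reg * torch.eye(m, dtype=X.dtype, device=X.device)
    ones = torch.ones(m, 1, dtype=X.dtype, device=X.device)
    z = torch.linalg.solve(RtR, ones)
    alpha = (z / (ones.t() @ z)).squeeze(1)  # normalized mixing coefficients
    # Mixed iterate: combination of stored points plus relaxed residual directions.
    return X[:, :-1] @ alpha + relaxation * (R @ alpha)

The returned vector would then be written back with vector_to_parameters, exactly as accelerate() does with x_acc.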
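
For a self-contained sanity check of the history/mixing pattern used by accelerate() (deque of iterates, periodic replacement of the newest entry by the mixed point), the toy below applies the anderson_extrapolate sketch above to a slowly converging linear fixed-point iteration; it is an illustration under those assumptions, not part of the commit:

from collections import deque
import torch

torch.manual_seed(0)
n, depth = 50, 10
b = torch.randn(n)
x = torch.zeros(n)
x_star = b / 0.1                      # fixed point of x = 0.9 * x + b
x_hist = deque([], maxlen=depth)

for it in range(30):
    x = 0.9 * x + b                   # plain fixed-point step (stand-in for an optimizer step)
    x_hist.append(x.clone())
    if len(x_hist) >= 3:
        x = anderson_extrapolate(torch.stack(list(x_hist), dim=1))
        x_hist.pop()
        x_hist.append(x.clone())      # mirror accelerate(): replace newest entry with the mixed point
    print(it, (x - x_star).norm().item())

The printed error drops far faster with the mixing step enabled than with the plain iteration, which is the behavior the DeterministicAcceleration class aims to exploit for training.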