
Torch anderson

Merged Reshniak, Viktor requested to merge torch_Anderson into master
1 file   + 135   − 207
@@ -5,17 +5,23 @@ from torch import Tensor
from torch import autograd
from abc import ABCMeta, abstractmethod, ABC
import math
from AccelerationModule import AccelerationModule
from collections import deque
from torch.nn.utils import parameters_to_vector, vector_to_parameters
# Abstract class that provides basic guidelines to implement an acceleration
class Optimizer(object, metaclass=ABCMeta):
import sys
sys.path.append("../utils")
import rna_acceleration as rna
import anderson_acceleration as anderson
class FixedPointIteration(object):
    def __init__(self, training_dataloader: torch.utils.data.dataloader.DataLoader, validation_dataloader: torch.utils.data.dataloader.DataLoader,
                 learning_rate: float, weight_decay: float = 0.0, verbose: bool = False):
        """
        :type training_dataloader: torch.utils.data.dataloader.DataLoader
        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
        :type learning_rate: float
        :type weight_decay: float
        """
@@ -60,11 +66,99 @@ class Optimizer(object, metaclass=ABCMeta):
        assert self.model_imported
        return self.model

    @abstractmethod
    def train(self, input_data: torch.Tensor, target: torch.Tensor, num_iterations: int, threshold: float,
              batch_size: int):
    def accelerate(self):
        pass
    def train(self, num_epochs, threshold, batch_size):
        assert self.model_imported
        assert self.optimizer_specified

        epoch_counter = 0
        value_loss = float('Inf')

        self.training_loss_history = []
        self.validation_loss_history = []

        while epoch_counter < num_epochs and value_loss > threshold:
            self.model.get_model().train(True)

            # Training
            for batch_idx, (data, target) in enumerate(self.training_dataloader):
                self.accelerate()  # no-op in the base class; overridden by accelerating subclasses
                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                self.optimizer.zero_grad()
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                loss.backward()

                if self.optimizer_str == 'lbfgs':
                    # L-BFGS needs a closure that re-evaluates the model and returns the loss
                    def closure():
                        if torch.is_grad_enabled():
                            self.optimizer.zero_grad()
                        output = self.model.forward(data)
                        loss = self.criterion(output, target)
                        if loss.requires_grad:
                            loss.backward()
                        return loss
                    self.optimizer.step(closure)
                else:
                    self.optimizer.step()

                self.print_verbose(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch_counter, batch_idx * len(data), len(self.training_dataloader.dataset), 100.0 * batch_idx / len(self.training_dataloader), loss.item())
                )

            train_loss = loss.item()
            self.training_loss_history.append(train_loss)

            # Validation
            with torch.no_grad():
                self.model.get_model().train(False)
                val_loss = 0.0
                count_val = 0
                correct = 0
                for batch_idx, (data, target) in enumerate(self.validation_dataloader):
                    count_val = count_val + 1
                    data, target = (
                        data.to(self.model.get_device()),
                        target.to(self.model.get_device()),
                    )
                    output = self.model.forward(data)
                    loss = self.criterion(output, target)
                    val_loss = val_loss + loss
                    """
                    pred = output.argmax(
                        dim=1, keepdim=True
                    )  # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
                    """
                val_loss = val_loss / count_val
                self.validation_loss_history.append(val_loss)

            """
            self.print_verbose(
                '\n Epoch: '
                + str(epoch_counter)
                + ' - Training Loss: '
                + str(train_loss)
                + ' - Validation - Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                    val_loss,
                    correct,
                    len(self.validation_dataloader.dataset),
                    100.0 * correct / len(self.validation_dataloader.dataset),
                )
            )
            self.print_verbose("###############################")
            """

            value_loss = val_loss
            epoch_counter = epoch_counter + 1

        return self.training_loss_history, self.validation_loss_history
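The train() loop above calls self.accelerate() once per mini-batch, so acceleration schemes plug in by subclassing FixedPointIteration and overriding that hook. A minimal sketch of the pattern follows; the class name AveragingAcceleration and its running-average update are purely illustrative and do not appear in this merge request.

class AveragingAcceleration(FixedPointIteration):
    """Illustrative subclass: replaces the parameters with their running average."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.running_avg = None
        self.n_calls = 0

    def accelerate(self):
        # uses parameters_to_vector / vector_to_parameters imported at the top of this file
        x = parameters_to_vector(self.model.get_model().parameters()).detach()
        self.n_calls += 1
        if self.running_avg is None:
            self.running_avg = x.clone()
        else:
            self.running_avg += (x - self.running_avg) / self.n_calls
        vector_to_parameters(self.running_avg, self.model.get_model().parameters())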
    def set_loss_function(self, criterion_string):
        if criterion_string.lower() == 'mse':
@@ -103,8 +197,8 @@ class Optimizer(object, metaclass=ABCMeta):
            self.optimizer_str = optimizer_string.lower()
            self.optimizer_specified = True
        elif optimizer_string.lower() == 'lbfgs':
            self.optimizer = torch.optim.LBFGS(self.model.get_model().parameters(), lr=self.lr, history_size=10,
                                               max_iter=20, line_search_fn=True, batch_mode=True)
            self.optimizer_str = optimizer_string.lower()
            self.optimizer_specified = True
        else:
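Note that the stock torch.optim.LBFGS shipped with PyTorch takes line_search_fn as a string ('strong_wolfe' or None) and has no batch_mode argument, so the call above presumably targets a modified L-BFGS implementation bundled with the project. Under that assumption, the same assignment with the stock optimizer would read roughly:

            # stock PyTorch L-BFGS: line_search_fn is a string, no batch_mode flag
            self.optimizer = torch.optim.LBFGS(self.model.get_model().parameters(), lr=self.lr,
                                               history_size=10, max_iter=20,
                                               line_search_fn='strong_wolfe')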
@@ -119,113 +213,15 @@ class Optimizer(object, metaclass=ABCMeta):
            print(*args, **kwargs)
class FixedPointIteration(Optimizer, ABC):
    def __init__(self, training_dataloader: torch.utils.data.dataloader.DataLoader, validation_dataloader: torch.utils.data.dataloader.DataLoader,
                 learning_rate: float, weight_decay: float = 0.0, verbose: bool = False):
        """
        :type training_dataloader: torch.utils.data.dataloader.DataLoader
        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
        :param learning_rate: :type: float
        :param weight_decay: :type: float
        """
        super(FixedPointIteration, self).__init__(training_dataloader, validation_dataloader, learning_rate, weight_decay, verbose)
    def train(self, num_epochs, threshold, batch_size):
        assert self.optimizer_specified

        epoch_counter = 0
        value_loss = float('Inf')

        self.training_loss_history = []
        self.validation_loss_history = []

        while epoch_counter < num_epochs and value_loss > threshold:
            self.model.get_model().train(True)
            train_loss = 0.0

            # Training
            for batch_idx, (data, target) in enumerate(self.training_dataloader):
                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                self.optimizer.zero_grad()
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                loss.backward()

                if self.optimizer_str == 'lbfgs':
                    def closure():
                        if torch.is_grad_enabled():
                            self.optimizer.zero_grad()
                        output = self.model.forward(data)
                        loss = self.criterion(output, target)
                        if loss.requires_grad:
                            loss.backward()
                        return loss
                    self.optimizer.step(closure)
                else:
                    self.optimizer.step()

                self.print_verbose(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch_counter, batch_idx * len(data), len(self.training_dataloader.dataset), 100.0 * batch_idx / len(self.training_dataloader), loss.item())
                )

            train_loss = loss.item()
            self.training_loss_history.append(train_loss)

            # Validation
            self.model.get_model().train(False)
            val_loss = 0.0
            count_val = 0
            correct = 0
            for batch_idx, (data, target) in enumerate(self.validation_dataloader):
                count_val = count_val + 1
                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                val_loss = val_loss + loss
                """
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
                """
            val_loss = val_loss / count_val
            self.validation_loss_history.append(val_loss)

            """
            self.print_verbose(
                '\n Epoch: '
                + str(epoch_counter)
                + ' - Training Loss: '
                + str(train_loss)
                + ' - Validation - Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                    val_loss,
                    correct,
                    len(self.validation_dataloader.dataset),
                    100.0 * correct / len(self.validation_dataloader.dataset),
                )
            )
            self.print_verbose("###############################")
            """

            value_loss = val_loss
            epoch_counter = epoch_counter + 1

        return self.training_loss_history, self.validation_loss_history
class DeterministicAcceleration(Optimizer, ABC):
class DeterministicAcceleration(FixedPointIteration):
    def __init__(self, training_dataloader: torch.utils.data.dataloader.DataLoader, validation_dataloader: torch.utils.data.dataloader.DataLoader,
                 acceleration_type: str = 'anderson', learning_rate: float = 1e-3, relaxation: float = 0.1, weight_decay: float = 0.0,
                 wait_iterations: int = 1, history_depth: int = 15, frequency: int = 1, reg_acc: float = 0.0, store_each_nth: int = 1, verbose: bool = False):
        """
        :type training_dataloader: torch.utils.data.dataloader.DataLoader
        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
        :param learning_rate: :type: float
        :param weight_decay: :type: float
        """
@@ -238,99 +234,31 @@ class DeterministicAcceleration(Optimizer, ABC):
        self.frequency = frequency
        self.reg_acc = reg_acc
    def train(self, num_epochs, threshold, batch_size):
        assert self.model_imported

        # Initialization of acceleration module
        self.acc_mod = AccelerationModule(self.acceleration_type, self.model.get_model(), self.history_depth, self.reg_acc, self.store_each_nth)
        self.acc_mod.store(self.model.get_model())

        assert self.optimizer_specified

        epoch_counter = 0
        value_loss = float('Inf')

        self.training_loss_history = []
        self.validation_loss_history = []

        while epoch_counter < num_epochs and value_loss > threshold:
            self.model.get_model().train(True)

            # Training
            for batch_idx, (data, target) in enumerate(self.training_dataloader):
                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                self.optimizer.zero_grad()
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                loss.backward()

                if self.optimizer_str == 'lbfgs':
                    def closure():
                        if torch.is_grad_enabled():
                            self.optimizer.zero_grad()
                        output = self.model.forward(data)
                        loss = self.criterion(output, target)
                        if loss.requires_grad:
                            loss.backward()
                        return loss
                    self.optimizer.step(closure)
                else:
                    self.optimizer.step()

                self.print_verbose(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch_counter, batch_idx * len(data), len(self.training_dataloader.dataset), 100.0 * batch_idx / len(self.training_dataloader), loss.item())
                )

            train_loss = loss.item()
            self.training_loss_history.append(train_loss)

            # Acceleration
            self.acc_mod.store(self.model.get_model())
            if (epoch_counter > self.wait_iterations) and (epoch_counter % self.frequency == 0):
                self.acc_mod.accelerate(self.model.get_model(), self.relaxation)

            # Validation
            self.model.get_model().train(False)
            val_loss = 0.0
            count_val = 0
            correct = 0
            for batch_idx, (data, target) in enumerate(self.validation_dataloader):
                count_val = count_val + 1
                data, target = (
                    data.to(self.model.get_device()),
                    target.to(self.model.get_device()),
                )
                output = self.model.forward(data)
                loss = self.criterion(output, target)
                val_loss = val_loss + loss
                """
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
                """
            val_loss = val_loss / count_val
            self.validation_loss_history.append(val_loss)

            """
            self.print_verbose(
                '\n Epoch: '
                + str(epoch_counter)
                + ' - Training Loss: '
                + str(train_loss)
                + ' - Validation - Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                    val_loss,
                    correct,
                    len(self.validation_dataloader.dataset),
                    100.0 * correct / len(self.validation_dataloader.dataset),
                )
            )
            self.print_verbose("###############################")
            """

            value_loss = val_loss
            epoch_counter = epoch_counter + 1

        return self.training_loss_history, self.validation_loss_history
        self.store_counter = 0
        self.call_counter = 0
        self.x_hist = deque([], maxlen=history_depth)

    def accelerate(self):
        # update history of model parameters
        self.store_counter += 1
        if self.store_counter >= self.store_each_nth:
            self.store_counter = 0  # reset and continue
            self.x_hist.append(parameters_to_vector(self.model.get_model().parameters()).detach())

        # perform acceleration
        self.call_counter += 1
        if len(self.x_hist) >= 3 and (self.call_counter > self.wait_iterations) and (self.call_counter % self.frequency == 0):
            # make matrix of updates from the history list
            X = torch.stack(list(self.x_hist), dim=1)

            # compute acceleration
            if self.acceleration_type == 'anderson':
                x_acc = anderson.anderson(X, self.relaxation)
            elif self.acceleration_type == 'rna':
                x_acc, c = rna.rna(X, self.reg_acc)

            # load acceleration back into model and update history
            vector_to_parameters(x_acc, self.model.get_model().parameters())
            self.x_hist.pop()
            self.x_hist.append(x_acc)
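The helpers anderson_acceleration.anderson(X, relaxation) and rna_acceleration.rna(X, reg) live in ../utils and are not part of this diff. As a rough sketch of the extrapolation they are expected to perform on the column-stacked history X, with the regularization term and the exact weighting treated as assumptions:

def anderson_sketch(X, relaxation, reg=0.0):
    # X: (n_params, k) history of parameter vectors, oldest column first
    DX = X[:, 1:] - X[:, :-1]  # residuals r_i = x_{i+1} - x_i
    # weights minimizing ||DX @ c|| subject to sum(c) = 1 (regularized normal equations)
    G = DX.t() @ DX
    G = G + reg * torch.eye(G.shape[0], dtype=G.dtype, device=G.device)
    gamma = torch.linalg.solve(G, torch.ones(G.shape[0], dtype=G.dtype, device=G.device))
    c = gamma / gamma.sum()
    # extrapolated point: mix of past iterates plus a relaxed step along the mixed residual
    return X[:, :-1] @ c + relaxation * (DX @ c)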