
Subsampler and qr

Merged: Lupo Pasini, Massimiliano requested to merge subsampler_and_QR into master
1 file  +175  -40
@@ -10,18 +10,21 @@ from AccelerationModule import AccelerationModule
 # Abstract class that provides basic guidelines to implement an acceleration
 class Optimizer(object, metaclass=ABCMeta):
-    def __init__(self, data_loader: torch.utils.data.dataloader.DataLoader, learning_rate: float,
-                 weight_decay: float = 0.0, verbose: bool = False):
+    def __init__(self, training_dataloader: torch.utils.data.dataloader.DataLoader, validation_dataloader: torch.utils.data.dataloader.DataLoader,
+                 learning_rate: float, weight_decay: float = 0.0, verbose: bool = False):
         """
-        :type data_loader: torch.utils.data.dataloader.DataLoader
+        :type training_dataloader: torch.utils.data.dataloader.DataLoader
+        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
         :type learning_rate: float
         :type weight_decay: float
         """
         self.iteration_counter = 0
-        assert isinstance(data_loader, torch.utils.data.dataloader.DataLoader)
-        self.data_loader = data_loader
+        assert isinstance(training_dataloader, torch.utils.data.dataloader.DataLoader)
+        assert isinstance(validation_dataloader, torch.utils.data.dataloader.DataLoader)
+        self.training_dataloader = training_dataloader
+        self.validation_dataloader = validation_dataloader
         assert isinstance(learning_rate, float)
         self.lr = learning_rate
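For context on the new two-loader constructor: a minimal sketch of how a caller might build training_dataloader and validation_dataloader, assuming an illustrative tensor dataset and an 80/20 split (none of this is part of the MR):

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Illustrative dataset: 1000 samples of 10 features with scalar targets
dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))

# 80/20 split into disjoint training and validation subsets
train_set, val_set = random_split(dataset, [800, 200])

training_dataloader = DataLoader(train_set, batch_size=32, shuffle=True)
validation_dataloader = DataLoader(val_set, batch_size=32, shuffle=False)

# Both loaders are torch.utils.data.dataloader.DataLoader instances,
# so they pass the isinstance checks added to Optimizer.__init__ above.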
@@ -33,8 +36,10 @@ class Optimizer(object, metaclass=ABCMeta):
         self.model = None
         self.training_loss_history = []
+        self.validation_loss_history = []
         self.criterion_specified = False
         self.criterion = None
+        self.optimizer_str = None
         self.optimizer_specified = False
         self.optimizer = None
         self.loss_name = None
@@ -65,8 +70,8 @@ class Optimizer(object, metaclass=ABCMeta):
         if criterion_string.lower() == 'mse':
             self.criterion = torch.nn.MSELoss()
             self.criterion_specified = True
-        elif criterion_string.lower() == 'ce':
-            self.criterion = torch.nn.CrossEntropyLoss()
+        elif criterion_string.lower() == 'nll':
+            self.criterion = torch.nn.functional.nll_loss
             self.criterion_specified = True
         else:
             raise ValueError("Loss function is not recognized: currently only MSE and CE are allowed")
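Switching the 'ce' branch to torch.nn.functional.nll_loss changes what the model is expected to produce: nll_loss consumes log-probabilities, whereas CrossEntropyLoss applies log_softmax internally to raw logits. A small self-contained sketch of the difference (the tensors are illustrative, not taken from the MR):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)            # raw scores for 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 0])

# Old 'ce' branch: CrossEntropyLoss works directly on raw logits
ce = torch.nn.CrossEntropyLoss()(logits, target)

# New 'nll' branch: nll_loss expects log-probabilities, so the network's
# forward pass has to end in log_softmax for the loss to be meaningful
nll = F.nll_loss(F.log_softmax(logits, dim=1), target)

assert torch.allclose(ce, nll)        # identical once log_softmax is applied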
@@ -85,14 +90,22 @@ class Optimizer(object, metaclass=ABCMeta):
         if optimizer_string.lower() == 'sgd':
             self.optimizer = torch.optim.SGD(self.model.get_model().parameters(), lr=self.lr,
                                              weight_decay=self.weight_decay)
+            self.optimizer_str = optimizer_string.lower()
             self.optimizer_specified = True
         elif optimizer_string.lower() == 'rmsprop':
             self.optimizer = torch.optim.RMSprop(self.model.get_model().parameters(), lr=self.lr, alpha=0.99,
                                                  weight_decay=self.weight_decay)
+            self.optimizer_str = optimizer_string.lower()
             self.optimizer_specified = True
         elif optimizer_string.lower() == 'adam':
             self.optimizer = torch.optim.Adam(self.model.get_model().parameters(), lr=self.lr, betas=(0.9, 0.999),
                                               weight_decay=self.weight_decay)
+            self.optimizer_str = optimizer_string.lower()
             self.optimizer_specified = True
+        elif optimizer_string.lower() == 'lbfgs':
+            self.optimizer = torch.optim.LBFGS(self.model.get_model().parameters(), lr=self.lr, history_size=10,
+                                               max_iter=20, line_search_fn=True, batch_mode=True)
+            self.optimizer_str = optimizer_string.lower()
+            self.optimizer_specified = True
         else:
             raise ValueError("Optimizer is not recognized: currently only SGD, RMSProp and Adam are allowed")
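The new 'lbfgs' branch is why the train() methods below gain a closure-based step: torch.optim.LBFGS re-evaluates the loss several times per step and therefore needs a closure. A standalone sketch with the stock optimizer follows; note that stock torch.optim.LBFGS takes line_search_fn='strong_wolfe' (or None) and has no batch_mode argument, so the keyword arguments in this hunk presumably target a modified LBFGS implementation.

import torch

model = torch.nn.Linear(10, 1)                  # illustrative model
data, target = torch.randn(32, 10), torch.randn(32, 1)
criterion = torch.nn.MSELoss()

optimizer = torch.optim.LBFGS(model.parameters(), lr=1.0, history_size=10,
                              max_iter=20, line_search_fn='strong_wolfe')

def closure():
    # LBFGS calls this repeatedly to re-evaluate the loss during its line search
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    return loss

optimizer.step(closure)                         # one quasi-Newton step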
@@ -107,54 +120,116 @@ class Optimizer(object, metaclass=ABCMeta):
 class FixedPointIteration(Optimizer, ABC):
-    def __init__(self, data_loader: torch.utils.data.dataloader.DataLoader, learning_rate: float,
-                 weight_decay: float = 0.0, verbose: bool = False):
+    def __init__(self, training_dataloader: torch.utils.data.dataloader.DataLoader, validation_dataloader: torch.utils.data.dataloader.DataLoader,
+                 learning_rate: float, weight_decay: float = 0.0, verbose: bool = False):
         """
+        :type training_dataloader: torch.utils.data.dataloader.DataLoader
+        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
         :param learning_rate: :type: float
         :param weight_decay: :type: float
         """
-        super(FixedPointIteration, self).__init__(data_loader, learning_rate, weight_decay, verbose)
+        super(FixedPointIteration, self).__init__(training_dataloader, validation_dataloader, learning_rate, weight_decay, verbose)
     def train(self, num_epochs, threshold, batch_size):
-        self.model.get_model().train(True)  # True indicates actual training
         assert self.optimizer_specified
         epoch_counter = 0
         value_loss = float('Inf')
         self.training_loss_history = []
+        self.validation_loss_history = []
         while epoch_counter < num_epochs and value_loss > threshold:
-            for batch_idx, (data, target) in enumerate(self.data_loader):
-                data, target = data.to(self.model.get_device()), target.to(self.model.get_device())
+            self.model.get_model().train(True)
+            train_loss = 0.0
+            # Training
+            for batch_idx, (data, target) in enumerate(self.training_dataloader):
+                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                 self.optimizer.zero_grad()
                 output = self.model.forward(data)
+                # print("Input_data: "+str(data.shape)+' - Output: '+str(output.shape)+' - Target: '+str(target.shape)+' - '+' - Value:'+str(output))
                 loss = self.criterion(output, target)
                 loss.backward()
-                self.optimizer.step()
+                if self.optimizer_str == 'lbfgs':
+                    def closure():
+                        if torch.is_grad_enabled():
+                            self.optimizer.zero_grad()
+                        output = self.model.forward(data)
+                        loss = self.criterion(output, target)
+                        if loss.requires_grad:
+                            loss.backward()
+                        return loss
+                    self.optimizer.step(closure)
+                else:
+                    self.optimizer.step()
+                self.print_verbose(
+                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch_counter, batch_idx * len(data),
+                        len(self.training_dataloader.dataset), 100.0 * batch_idx / len(self.training_dataloader), loss.item())
+                )
+            train_loss = loss.item()
+            self.training_loss_history.append(train_loss)
+            # Validation
+            self.model.get_model().train(False)
+            val_loss = 0.0
+            count_val = 0
+            correct = 0
+            for batch_idx, (data, target) in enumerate(self.validation_dataloader):
+                count_val = count_val + 1
+                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
+                output = self.model.forward(data)
+                loss = self.criterion(output, target)
+                val_loss = val_loss + loss
+                """
+                pred = output.argmax(
+                    dim=1, keepdim=True
+                )  # get the index of the max log-probability
+                correct += pred.eq(target.view_as(pred)).sum().item()
+                """
+            val_loss = val_loss / count_val
+            self.validation_loss_history.append(val_loss)
+            """
+            self.print_verbose(
+                '\n Epoch: '
+                + str(epoch_counter)
+                + ' - Training Loss: '
+                + str(train_loss)
+                + ' - Validation - Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+                    val_loss,
+                    correct,
+                    len(self.validation_dataloader.dataset),
+                    100.0 * correct / len(self.validation_dataloader.dataset),
+                )
+            )
             self.print_verbose("###############################")
             self.print_verbose('Epoch: ' + str(epoch_counter) + ' - Loss function: ' + str(loss.item()))
+            """
-            value_loss = loss.item()
+            value_loss = val_loss
             epoch_counter = epoch_counter + 1
-            self.training_loss_history.append(loss)
-        return self.training_loss_history
+        return self.training_loss_history, self.validation_loss_history
 class DeterministicAcceleration(Optimizer, ABC):
-    def __init__(self, data_loader: torch.utils.data.dataloader.DataLoader, acceleration_type: str, learning_rate: float, relaxation: float,
-                 weight_decay: float = 0.0, wait_iterations: int = 1, history_depth: int = 15, frequency: int = 1,
-                 reg_acc: float = 0.0, store_each_nth: int = 1, verbose: bool = False):
+    def __init__(self, training_dataloader: torch.utils.data.dataloader.DataLoader, validation_dataloader: torch.utils.data.dataloader.DataLoader,
+                 acceleration_type: str = 'anderson', learning_rate: float = 1e-3, relaxation: float = 0.1, weight_decay: float = 0.0,
+                 wait_iterations: int = 1, history_depth: int = 15, frequency: int = 1, reg_acc: float = 0.0, store_each_nth: int = 1, verbose: bool = False):
         """
+        :type training_dataloader: torch.utils.data.dataloader.DataLoader
+        :type validation_dataloader: torch.utils.data.dataloader.DataLoader
         :param learning_rate: :type: float
         :param weight_decay: :type: float
         """
-        super(DeterministicAcceleration, self).__init__(data_loader, learning_rate, weight_decay, verbose)
+        super(DeterministicAcceleration, self).__init__(training_dataloader, validation_dataloader, learning_rate, weight_decay, verbose)
         self.acceleration_type = acceleration_type.lower()
         self.wait_iterations = wait_iterations
         self.relaxation = relaxation
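With the signature changes above, a driver for FixedPointIteration now supplies both loaders and receives both loss histories from train(). A sketch of such a driver (model_wrapper and the import/setter hook names are assumptions for illustration; they are not shown in this excerpt):

# Hypothetical driver, not part of this MR
fp = FixedPointIteration(training_dataloader, validation_dataloader,
                         learning_rate=1e-3, weight_decay=0.0, verbose=True)
fp.import_model(model_wrapper)   # assumed hook exposing get_model()/get_device()/forward()
fp.set_loss_function('mse')      # assumed setter behind criterion_string above
fp.set_optimizer('sgd')          # assumed setter behind optimizer_string above
train_hist, val_hist = fp.train(num_epochs=100, threshold=1e-6, batch_size=32)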
@@ -168,34 +243,94 @@ class DeterministicAcceleration(Optimizer, ABC):
         assert self.model_imported
         # Initialization of acceleration module
-        self.acc_mod = AccelerationModule(self.acceleration_type, self.model.get_model(), self.history_depth, self.reg_acc, self.store_each_nth)
+        self.acc_mod = AccelerationModule(self.acceleration_type,self.model.get_model(),self.history_depth,self.reg_acc,self.store_each_nth)
         self.acc_mod.store(self.model.get_model())
-        self.model.get_model().train(True)
         assert self.optimizer_specified
         epoch_counter = 0
         value_loss = float('Inf')
         self.training_loss_history = []
+        self.validation_loss_history = []
         while epoch_counter < num_epochs and value_loss > threshold:
-            for batch_idx, (data, target) in enumerate(self.data_loader):
-                data, target = data.to(self.model.get_device()), target.to(self.model.get_device())
+            self.model.get_model().train(True)
+            # Training
+            for batch_idx, (data, target) in enumerate(self.training_dataloader):
+                data, target = (data.to(self.model.get_device()), target.to(self.model.get_device()))
                 self.optimizer.zero_grad()
                 output = self.model.forward(data)
                 loss = self.criterion(output, target)
                 loss.backward()
-                self.optimizer.step()
-                self.print_verbose("###############################")
-                self.print_verbose('Epoch: ' + str(epoch_counter) + ' - Loss function: ' + str(loss.item()))
+                if self.optimizer_str == 'lbfgs':
+                    def closure():
+                        if torch.is_grad_enabled():
+                            self.optimizer.zero_grad()
+                        output = self.model.forward(data)
+                        loss = self.criterion(output, target)
+                        if loss.requires_grad:
+                            loss.backward()
+                        return loss
+                    self.optimizer.step(closure)
+                else:
+                    self.optimizer.step()
+                self.print_verbose(
+                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch_counter, batch_idx * len(data), len(self.training_dataloader.dataset), 100.0 * batch_idx / len(self.training_dataloader), loss.item())
+                )
+            train_loss = loss.item()
+            self.training_loss_history.append(train_loss)
+            # Acceleration
             self.acc_mod.store(self.model.get_model())
             if (epoch_counter > self.wait_iterations) and (epoch_counter % self.frequency == 0):
                 self.acc_mod.accelerate(self.model.get_model(), self.relaxation)
-            value_loss = loss.item()
+            # Validation
+            self.model.get_model().train(False)
+            val_loss = 0.0
+            count_val = 0
+            correct = 0
+            for batch_idx, (data, target) in enumerate(self.validation_dataloader):
+                count_val = count_val + 1
+                data, target = (
+                    data.to(self.model.get_device()),
+                    target.to(self.model.get_device()),
+                )
+                output = self.model.forward(data)
+                loss = self.criterion(output, target)
+                val_loss = val_loss + loss
+                """
+                pred = output.argmax(
+                    dim=1, keepdim=True
+                )  # get the index of the max log-probability
+                correct += pred.eq(target.view_as(pred)).sum().item()
+                """
+            val_loss = val_loss / count_val
+            self.validation_loss_history.append(val_loss)
+            """
+            self.print_verbose(
+                '\n Epoch: '
+                + str(epoch_counter)
+                + ' - Training Loss: '
+                + str(train_loss)
+                + ' - Validation - Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+                    val_loss,
+                    correct,
+                    len(self.validation_dataloader.dataset),
+                    100.0 * correct / len(self.validation_dataloader.dataset),
+                )
+            )
+            self.print_verbose("###############################")
+            """
+            value_loss = val_loss
             epoch_counter = epoch_counter + 1
-            self.training_loss_history.append(loss)
-        return self.training_loss_history
+        return self.training_loss_history, self.validation_loss_history
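AccelerationModule's internals are not part of this diff; only its store()/accelerate() calls and the relaxation and history_depth knobs appear above. As background, a generic Anderson-mixing step over a history of stored iterates looks roughly like the sketch below (a textbook formulation, not the repository's implementation):

import torch

def anderson_mix(X, GX, relaxation=0.1):
    """One Anderson-mixing update from a history of iterates.

    X  : (m, n) stored iterates x_{k-m+1}, ..., x_k (oldest first)
    GX : (m, n) their fixed-point images g(x_i)
    """
    F = GX - X                                   # residuals f_i = g(x_i) - x_i
    if X.shape[0] == 1:
        x_bar, g_bar = X[-1], GX[-1]             # no history yet: plain relaxed update
    else:
        dF = (F[1:] - F[:-1]).T                  # residual differences, shape (n, m-1)
        gamma = torch.linalg.lstsq(dF, F[-1].unsqueeze(1)).solution.squeeze(1)
        x_bar = X[-1] - gamma @ (X[1:] - X[:-1])
        g_bar = GX[-1] - gamma @ (GX[1:] - GX[:-1])
    return (1.0 - relaxation) * x_bar + relaxation * g_bar

# Toy usage: the fixed point of g(x) = 0.5 * x + 1 is x = 2
g = lambda x: 0.5 * x + 1.0
history = [torch.zeros(1)]
for _ in range(10):
    X = torch.stack(history[-5:])                # keep at most 5 iterates (history_depth analogue)
    GX = torch.stack([g(x) for x in X])
    history.append(anderson_mix(X, GX, relaxation=0.5))
print(history[-1])                               # approaches tensor([2.])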