Skip to content
Snippets Groups Projects

Subsampler and qr

Merged Lupo Pasini, Massimiliano requested to merge subsampler_and_QR into master
1 file
+ 13
8
Compare changes
  • Side-by-side
  • Inline
+ 13
8
@@ -10,18 +10,21 @@ from AccelerationModule import AccelerationModule
# Abstract class that provides basic guidelines to implement an acceleration
class Optimizer(object, metaclass=ABCMeta):
def __init__(self, data_loader: torch.utils.data.dataloader.DataLoader, learning_rate: float,
weight_decay: float = 0.0, verbose: bool = False):
def __init__(self, training_dataloader: torch.utils.data.dataloader.DataLoader, validation_dataloader: torch.utils.data.dataloader.DataLoader,
learning_rate: float, weight_decay: float = 0.0, verbose: bool = False):
"""
:type data_loader: torch.utils.data.dataloader.DataLoader
:type training_dataloader: torch.utils.data.dataloader.DataLoader
:type validation_dataloader: torch.utils.data.dataloader.DataLoader
:type learning_rate: float
:type weight_decay: float
"""
self.iteration_counter = 0
assert isinstance(data_loader, torch.utils.data.dataloader.DataLoader)
self.data_loader = data_loader
assert isinstance(training_dataloader, torch.utils.data.dataloader.DataLoader)
assert isinstance(validation_dataloader, torch.utils.data.dataloader.DataLoader)
self.training_dataloader = training_dataloader
self.validation_dataloader = validation_dataloader
assert isinstance(learning_rate, float)
self.lr = learning_rate
@@ -33,8 +36,10 @@ class Optimizer(object, metaclass=ABCMeta):
self.model = None
self.training_loss_history = []
self.validation_loss_history = []
self.criterion_specified = False
self.criterion = None
self.optimizer_str = None
self.optimizer_specified = False
self.optimizer = None
self.loss_name = None
@@ -65,8 +70,8 @@ class Optimizer(object, metaclass=ABCMeta):
if criterion_string.lower() == 'mse':
self.criterion = torch.nn.MSELoss()
self.criterion_specified = True
elif criterion_string.lower() == 'ce':
self.criterion = torch.nn.CrossEntropyLoss()
elif criterion_string.lower() == 'nll':
self.criterion = torch.nn.functional.nll_loss
self.criterion_specified = True
else:
raise ValueError("Loss function is not recognized: currently only MSE and CE are allowed")
Loading