Skip to content
Snippets Groups Projects
Commit 49d5cedf authored by Massimiliano Lupo Pasini's avatar Massimiliano Lupo Pasini
Browse files

nll loss function used in place of cross entropy for classification problems

parent 21a88c0b
No related branches found
No related tags found
1 merge request!11Subsampler and qr
This commit is part of merge request !11. Comments created here will be created in the context of that merge request.
......@@ -10,18 +10,21 @@ from AccelerationModule import AccelerationModule
# Abstract class that provides basic guidelines to implement an acceleration
class Optimizer(object, metaclass=ABCMeta):
def __init__(self, data_loader: torch.utils.data.dataloader.DataLoader, learning_rate: float,
weight_decay: float = 0.0, verbose: bool = False):
def __init__(self, training_dataloader: torch.utils.data.dataloader.DataLoader, validation_dataloader: torch.utils.data.dataloader.DataLoader,
learning_rate: float, weight_decay: float = 0.0, verbose: bool = False):
"""
:type data_loader: torch.utils.data.dataloader.DataLoader
:type training_dataloader: torch.utils.data.dataloader.DataLoader
:type validation_dataloader: torch.utils.data.dataloader.DataLoader
:type learning_rate: float
:type weight_decay: float
"""
self.iteration_counter = 0
assert isinstance(data_loader, torch.utils.data.dataloader.DataLoader)
self.data_loader = data_loader
assert isinstance(training_dataloader, torch.utils.data.dataloader.DataLoader)
assert isinstance(validation_dataloader, torch.utils.data.dataloader.DataLoader)
self.training_dataloader = training_dataloader
self.validation_dataloader = validation_dataloader
assert isinstance(learning_rate, float)
self.lr = learning_rate
......@@ -33,8 +36,10 @@ class Optimizer(object, metaclass=ABCMeta):
self.model = None
self.training_loss_history = []
self.validation_loss_history = []
self.criterion_specified = False
self.criterion = None
self.optimizer_str = None
self.optimizer_specified = False
self.optimizer = None
self.loss_name = None
......@@ -65,8 +70,8 @@ class Optimizer(object, metaclass=ABCMeta):
if criterion_string.lower() == 'mse':
self.criterion = torch.nn.MSELoss()
self.criterion_specified = True
elif criterion_string.lower() == 'ce':
self.criterion = torch.nn.CrossEntropyLoss()
elif criterion_string.lower() == 'nll':
self.criterion = torch.nn.functional.nll_loss
self.criterion_specified = True
else:
raise ValueError("Loss function is not recognized: currently only MSE and CE are allowed")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment