src/tgreft/nn/loss.py (+15 −5)

@@ -3,9 +3,15 @@
 import torch
 import torch.nn as nn
 import numpy as np
+from multiprocessing import Pool
 from tgreft.utils.data.data_loader import param_to_rcurve


+def parallel_param_to_rcurve(params):
+    with Pool() as p:
+        return p.map(param_to_rcurve, params)
+
+
 class CompositeLoss(nn.Module):
     """Composite loss function that considers both the difference in params and the r-curves."""

@@ -18,9 +24,9 @@ class CompositeLoss(nn.Module):
         # define the loss functions
         self.param_loss_huber = nn.SmoothL1Loss()
         self.rcurve_loss = nn.MSELoss()
-        # weight for the param loss and the rcurve loss
-        self.lambda_param = lambda_param
-        self.lambda_curve = 1 - lambda_param
+        # use learnable weights
+        self.lambda_param = nn.Parameter(torch.tensor(lambda_param))
+        self.lambda_curve = nn.Parameter(torch.tensor(1.0 - lambda_param))
         # number of parameters
         self.num_param = num_param

@@ -31,13 +37,13 @@ class CompositeLoss(nn.Module):
         pred_params = pred_params.view(-1, self.num_param)
         # 2. generate the r-curves
         pred_rcurve = torch.from_numpy(
-            np.apply_along_axis(param_to_rcurve, 1, pred_params.detach().cpu().numpy())
+            np.array(parallel_param_to_rcurve(pred_params.detach().cpu().numpy()))
         ).float()
         # 3. reshape the reference params
         true_params = true_params.view(-1, self.num_param)
         # 4. generate the r-curves
         true_rcurve = torch.from_numpy(
-            np.apply_along_axis(param_to_rcurve, 1, true_params.detach().cpu().numpy())
+            np.array(parallel_param_to_rcurve(true_params.detach().cpu().numpy()))
         ).float()
         # calculate the loss
         # cast to log first

@@ -45,6 +51,10 @@ class CompositeLoss(nn.Module):
         true_rcurve = torch.log(true_rcurve)
         loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve)
         # combine the losses
+        # - normalize the loss weights so they stay positive and sum to one
+        self.lambda_param.data = self.lambda_param.data.abs() / (self.lambda_param.data.abs() + self.lambda_curve.data.abs())
+        self.lambda_curve.data = 1.0 - self.lambda_param.data
+        # - calculate the combined loss
         loss = self.lambda_param * loss_param + self.lambda_curve * loss_rcurve
         return loss
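Two things are worth noting about this change. First, the renormalization mutates `.data`, so it runs outside autograd; the weights only receive gradients through the weighted sum itself, and they only learn at all if they are registered with the optimizer, which is what the `get_optimizer` change below wires up. Second, `parallel_param_to_rcurve` spawns a fresh `Pool` on every forward pass; a long-lived pool created once at module level would avoid the repeated process start-up cost. A minimal, self-contained sketch of the learnable-weight pattern (the class and argument names here are illustrative, not from the repo):

import torch
import torch.nn as nn


class TwoTermLoss(nn.Module):
    # Sketch of the CompositeLoss weighting pattern: two scalar weights are
    # registered as nn.Parameters so the optimizer can update them, and each
    # forward pass renormalizes them in-place to stay positive and sum to 1.
    def __init__(self, lambda_a: float = 0.5):
        super().__init__()
        self.lambda_a = nn.Parameter(torch.tensor(lambda_a))
        self.lambda_b = nn.Parameter(torch.tensor(1.0 - lambda_a))

    def forward(self, loss_a: torch.Tensor, loss_b: torch.Tensor) -> torch.Tensor:
        # In-place renormalization on .data happens outside autograd;
        # gradients flow only through the weighted sum below.
        total = self.lambda_a.data.abs() + self.lambda_b.data.abs()
        self.lambda_a.data = self.lambda_a.data.abs() / total
        self.lambda_b.data = 1.0 - self.lambda_a.data
        return self.lambda_a * loss_a + self.lambda_b * loss_b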
src/tgreft/train/generic.py (+12 −2)

@@ -2,6 +2,7 @@
 """Helper functions for modeling training and evaluation."""
 import numpy as np
 import torch
+from typing import Optional
 from tqdm.auto import tqdm
 from torch.utils.data import DataLoader
 from tgreft.utils.visualization import params_cmp_heatmap, plot_rcurve_cmp, plot_sld_cmp

@@ -170,6 +171,7 @@ def get_optimizer(
     model: torch.nn.Module,
     learning_rate: float,
     weight_decay: float,
+    loss_fn: Optional[torch.nn.Module] = None,
 ) -> torch.optim.Optimizer:
     """Get the optimizer.

@@ -183,21 +185,29 @@ def get_optimizer(
         The learning rate.
     weight_decay : float
         The weight decay.
+    loss_fn : Optional[torch.nn.Module], optional
+        The loss function; its parameters are added to the optimizer when given, by default None.

     Returns
     -------
     torch.optim.Optimizer
         The optimizer.
     """
+    params = model.parameters()
+    if loss_fn is not None:
+        params = [
+            {"params": model.parameters()},
+            {"params": loss_fn.parameters()},
+        ]
     if name.lower() == "adam":
         return torch.optim.Adam(
-            params=model.parameters(),
+            params=params,
             lr=learning_rate,
             weight_decay=weight_decay,
         )
     elif name.lower() == "sgd":
         return torch.optim.SGD(
-            params=model.parameters(),
+            params=params,
             lr=learning_rate,
             weight_decay=weight_decay,
         )
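With the optional `loss_fn` argument, the loss weights join the model weights in the optimizer's parameter groups, so a single `optimizer.step()` updates both. A hedged usage sketch; the keyword names follow the diff above, but the placeholder model, `num_param=4`, and the hyperparameter values are assumptions:

import torch
from tgreft.nn.loss import CompositeLoss
from tgreft.train.generic import get_optimizer

model = torch.nn.Linear(128, 4)  # placeholder network predicting 4 params
loss_fn = CompositeLoss(lambda_param=0.5, num_param=4)

optimizer = get_optimizer(
    name="adam",
    model=model,
    learning_rate=1e-3,
    weight_decay=1e-5,
    loss_fn=loss_fn,  # lambda_param / lambda_curve now receive updates too
)

Using a separate parameter group for the loss also leaves room to override its hyperparameters later, e.g. zero weight decay for the lambdas, since weight decay pulls them toward zero before each renormalization.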