Commit c3420ba4 authored by Zhang, Chen

allow lambda to be updated

parent 330f5cc9
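
In short: the fixed scalar weights lambda_param / lambda_curve in CompositeLoss
become nn.Parameter instances, and get_optimizer gains an optional loss_fn
argument so those weights join the optimizer's parameter groups and are updated
during training. A minimal sketch of the mechanism, separate from this repo
(the toy model, data, and values are illustrative):

    import torch
    import torch.nn as nn

    # a loss weight registered as nn.Parameter receives a gradient like any
    # model weight, so optimizer.step() can adjust it
    lam = nn.Parameter(torch.tensor(0.5))
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": [lam]}], lr=1e-3
    )

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = lam * nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # updates both the model weights and lam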
+15 −5
@@ -3,9 +3,15 @@
 import torch
 import torch.nn as nn
 import numpy as np
+from multiprocessing import Pool
 from tgreft.utils.data.data_loader import param_to_rcurve


+def parallel_param_to_rcurve(params):
+    with Pool() as p:
+        return p.map(param_to_rcurve, params)
+
+
 class CompositeLoss(nn.Module):
     """Composite loss function that considers both the difference in params and the r-curves."""
@@ -18,9 +24,9 @@ class CompositeLoss(nn.Module):
         # define the loss functions
         self.param_loss_huber = nn.SmoothL1Loss()
         self.rcurve_loss = nn.MSELoss()
-        # weight for the param loss and the rcurve loss
-        self.lambda_param = lambda_param
-        self.lambda_curve = 1 - lambda_param
+        # use learnable weights
+        self.lambda_param = nn.Parameter(torch.tensor(lambda_param))
+        self.lambda_curve = nn.Parameter(torch.tensor(1.0 - lambda_param))
         # number of parameters
         self.num_param = num_param

@@ -31,13 +37,13 @@ class CompositeLoss(nn.Module):
         pred_params = pred_params.view(-1, self.num_param)
         # 2. generate the r-curves
         pred_rcurve = torch.from_numpy(
-            np.apply_along_axis(param_to_rcurve, 1, pred_params.detach().cpu().numpy())
+            np.array(parallel_param_to_rcurve(pred_params.detach().cpu().numpy()))
         ).float()
         # 3. reshape the reference params
         true_params = true_params.view(-1, self.num_param)
         # 4. generate the r-curves
         true_rcurve = torch.from_numpy(
-            np.apply_along_axis(param_to_rcurve, 1, true_params.detach().cpu().numpy())
+            np.array(parallel_param_to_rcurve(true_params.detach().cpu().numpy()))
         ).float()
         # calculate the loss
         # cast to log first
@@ -45,6 +51,10 @@ class CompositeLoss(nn.Module):
         true_rcurve = torch.log(true_rcurve)
         loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve)
         # combine the losses
+        # - normalize the lambda weights so they stay positive and sum to one
+        self.lambda_param.data = self.lambda_param.data.abs() / (self.lambda_param.data.abs() + self.lambda_curve.data.abs())
+        self.lambda_curve.data = 1.0 - self.lambda_param.data
+        # - calculate the combined loss
         loss = self.lambda_param * loss_param + self.lambda_curve * loss_rcurve
         return loss

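A note on the new helper: iterating a 2-D NumPy array yields its rows, so
Pool.map in parallel_param_to_rcurve applies param_to_rcurve to one row of
params at a time, exactly like the np.apply_along_axis call it replaces, only
spread across worker processes. A toy equivalence check (toy_curve is a
stand-in for param_to_rcurve, which is not reproduced here):

    import numpy as np
    from multiprocessing import Pool

    def toy_curve(row):
        return row ** 2

    if __name__ == "__main__":
        params = np.random.rand(8, 3)
        serial = np.apply_along_axis(toy_curve, 1, params)
        with Pool() as p:
            parallel = np.array(p.map(toy_curve, params))
        assert np.allclose(serial, parallel)
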
+12 −2
@@ -2,6 +2,7 @@
 """Helper functions for modeling training and evaluation."""
 import numpy as np
 import torch
+from typing import Optional
 from tqdm.auto import tqdm
 from torch.utils.data import DataLoader
 from tgreft.utils.visualization import params_cmp_heatmap, plot_rcurve_cmp, plot_sld_cmp
@@ -170,6 +171,7 @@ def get_optimizer(
     model: torch.nn.Module,
     learning_rate: float,
     weight_decay: float,
+    loss_fn: Optional[torch.nn.Module] = None,
 ) -> torch.optim.Optimizer:
     """Get the optimizer.

@@ -183,21 +185,29 @@ def get_optimizer(
         The learning rate.
     weight_decay : float
         The weight decay.
+    loss_fn : Optional[torch.nn.Module], optional
+        The loss function whose parameters should also be optimized, by default None.

     Returns
     -------
     torch.optim.Optimizer
         The optimizer.
     """
+    params = model.parameters()
+    if loss_fn is not None:
+        params = [
+            {"params": model.parameters()},
+            {"params": loss_fn.parameters()},
+        ]
     if name.lower() == "adam":
         return torch.optim.Adam(
-            params=model.parameters(),
+            params=params,
             lr=learning_rate,
             weight_decay=weight_decay,
         )
     elif name.lower() == "sgd":
         return torch.optim.SGD(
-            params=model.parameters(),
+            params=params,
             lr=learning_rate,
             weight_decay=weight_decay,
         )
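
With the new loss_fn argument, a training script can hand the loss module to
get_optimizer so the lambda weights are updated together with the model. A
sketch of the intended call (the model variable and the hyperparameter values
are assumptions, not taken from this commit):

    loss_fn = CompositeLoss(lambda_param=0.5, num_param=6)
    optimizer = get_optimizer(
        name="adam",
        model=model,
        learning_rate=1e-3,
        weight_decay=1e-4,
        loss_fn=loss_fn,
    )
    # after loss.backward(), optimizer.step() now also updates
    # loss_fn.lambda_param and loss_fn.lambda_curve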