Commit 3a518b5b authored by Zhang, Chen's avatar Zhang, Chen
Browse files

Add composite loss function

parent 7618b47a
Loading
Loading
Loading
Loading
+29 −0
Original line number Diff line number Diff line
# Pre-commit hook configuration: generic hygiene checks, then ruff (lint) and
# black (format). Generated/fixture files under tests/cis_tests are excluded
# from the formatting hooks so their byte content stays untouched.
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
  rev: v4.5.0
  hooks:
    - id: check-added-large-files
      # reject files larger than 8 MiB
      args: [--maxkb=8192]
    - id: check-merge-conflict
    - id: check-yaml
      # verbose-mode regex; conda recipe uses Jinja templating that is not
      # valid plain YAML, so it is skipped by the YAML syntax check
      exclude: |
        (?x)^(
        ^conda.recipe/meta.yaml
        )$

    - id: end-of-file-fixer
      exclude: "tests/cis_tests/.*"
    - id: trailing-whitespace
      exclude: "tests/cis_tests/.*"
- repo: https://github.com/astral-sh/ruff-pre-commit
  rev: v0.1.7
  hooks:
    - id: ruff
      # apply autofixes, but still fail the hook so the fixed files get re-staged
      args: [--fix, --exit-non-zero-on-fix]
      exclude: "tests/cis_tests/.*"
- repo: https://github.com/psf/black
  rev: 23.11.0
  hooks:
    - id: black
      # matches the project's 120-character line length
      args: ['--line-length=120']
      exclude: "tests/cis_tests/.*"
+3 −3
Original line number Diff line number Diff line
@@ -15,13 +15,13 @@ if __name__ == "__main__":
    train_params = {
        "cuda_id": 0,
        "n_epochs": 50,  # seems like 50 is the point where training and validation loss diverge
        "n_training": 1_500_000,
        "n_training": 50_000,
        "error": 0.07,
        "batch_size": 250,
        "learning_rate": 0.0057929655918116715,
        "weight_decay": 7.198921885462489e-07,
        "optimizer": "SGD",
        "loss": "huber",
        "loss": "composite",
        "cache_dir": "data",
        "experiment_name": "Train_REFL_GPT",
        "run_name": "gpt_d1024_h8_l4_newlog",

src/tgreft/nn/loss.py

0 → 100644
+50 −0
Original line number Diff line number Diff line
#!/usr/bin/env python
"""Extended loss functions for tgreft."""
import torch
import torch.nn as nn
import numpy as np
from tgreft.utils.data.data_loader import param_to_rcurve


class CompositeLoss(nn.Module):
    """Composite loss that considers both the parameter error and the r-curve error.

    The total loss is::

        lambda_param * SmoothL1(pred_params, true_params)
        + (1 - lambda_param) * MSE(pred_rcurve, true_rcurve)

    where the r-curves are reconstructed row-wise from the parameter vectors
    via ``param_to_rcurve``.

    NOTE(review): the r-curves are generated through NumPy on detached,
    CPU-copied parameters, so no gradient flows back through the r-curve
    term — it only shifts the reported loss value. If gradients through the
    curve reconstruction are intended, ``param_to_rcurve`` would need a
    differentiable (torch) implementation; confirm intent with the authors.
    """

    def __init__(
        self,
        lambda_param: float = 0.5,
        num_param: int = 17,
    ):
        """
        Args:
            lambda_param: weight of the parameter (Huber) loss; the r-curve
                (MSE) loss receives the complementary weight ``1 - lambda_param``.
            num_param: number of parameters per sample, used to reshape a flat
                batch into ``(batch, num_param)`` before curve generation.
        """
        super().__init__()
        # define the loss functions
        self.param_loss_huber = nn.SmoothL1Loss()
        self.rcurve_loss = nn.MSELoss()
        # weight for the param loss and the rcurve loss
        self.lambda_param = lambda_param
        self.lambda_curve = 1 - lambda_param
        # number of parameters
        self.num_param = num_param

    def _rcurve_tensor(self, params: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Reconstruct r-curves row-wise from a parameter tensor.

        Runs through NumPy on a detached CPU copy (non-differentiable) and
        moves the result back to ``device`` so it can be combined with
        GPU-resident loss terms without a device-mismatch error.
        """
        flat = params.view(-1, self.num_param)
        curves = np.apply_along_axis(param_to_rcurve, 1, flat.detach().cpu().numpy())
        return torch.from_numpy(curves).float().to(device)

    def forward(self, pred_params: torch.Tensor, true_params: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of the parameter loss and the r-curve loss."""
        loss_param = self.param_loss_huber(pred_params, true_params)
        device = pred_params.device
        # generate the r-curves from the predicted and reference params
        pred_rcurve = self._rcurve_tensor(pred_params, device)
        true_rcurve = self._rcurve_tensor(true_params, device)
        loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve)
        # combine the losses (loss_rcurve carries no gradient — see class note)
        return self.lambda_param * loss_param + self.lambda_curve * loss_rcurve


# Library module: nothing to run when invoked as a script.
if __name__ == "__main__":
    pass
+3 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from tgreft.utils.visualization import params_cmp_heatmap, plot_rcurve_cmp
from tgreft.nn.loss import CompositeLoss


def train_single_epoch(
@@ -215,5 +216,7 @@ def get_loss(name: str) -> torch.nn.Module:
        return torch.nn.L1Loss()
    elif name.lower() == "huber":
        return torch.nn.SmoothL1Loss()
    elif name.lower() == "composite":
        return CompositeLoss()
    else:
        raise ValueError(f"Unknown loss function: {name}")