Loading .pre-commit-config.yaml 0 → 100644 +29 −0 Original line number Diff line number Diff line repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: check-added-large-files args: [--maxkb=8192] - id: check-merge-conflict - id: check-yaml exclude: | (?x)^( ^conda.recipe/meta.yaml )$ - id: end-of-file-fixer exclude: "tests/cis_tests/.*" - id: trailing-whitespace exclude: "tests/cis_tests/.*" - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.7 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] exclude: "tests/cis_tests/.*" - repo: https://github.com/psf/black rev: 23.11.0 hooks: - id: black args: ['--line-length=120'] exclude: "tests/cis_tests/.*" scripts/train_gpt.py +3 −3 Original line number Diff line number Diff line Loading @@ -15,13 +15,13 @@ if __name__ == "__main__": train_params = { "cuda_id": 0, "n_epochs": 50, # seems like 50 is the point where training and validation loss diverge "n_training": 1_500_000, "n_training": 50_000, "error": 0.07, "batch_size": 250, "learning_rate": 0.0057929655918116715, "weight_decay": 7.198921885462489e-07, "optimizer": "SGD", "loss": "huber", "loss": "composite", "cache_dir": "data", "experiment_name": "Train_REFL_GPT", "run_name": "gpt_d1024_h8_l4_newlog", Loading src/tgreft/nn/loss.py 0 → 100644 +50 −0 Original line number Diff line number Diff line #!/usr/bin/env python """Extended loss functions for tgreft.""" import torch import torch.nn as nn import numpy as np from tgreft.utils.data.data_loader import param_to_rcurve class CompositeLoss(nn.Module): """Composite loss function that consider both the difference in params and the r-curves.""" def __init__( self, lambda_param: float = 0.5, num_param: int = 17, ): super().__init__() # define the loss functions self.param_loss_huber = nn.SmoothL1Loss() self.rcurve_loss = nn.MSELoss() # weight for the param loss and the rcurve loss self.lambda_param = lambda_param self.lambda_curve = 1 - lambda_param # number of parameters 
self.num_param = num_param def forward(self, pred_params, true_params): loss_param = self.param_loss_huber(pred_params, true_params) # generate the r-curves from the predicted params # 1. reshape the predicted params pred_params = pred_params.view(-1, self.num_param) # 2. generate the r-curves pred_rcurve = torch.from_numpy( np.apply_along_axis(param_to_rcurve, 1, pred_params.detach().cpu().numpy()) ).float() # 3. reshape the reference params true_params = true_params.view(-1, self.num_param) # 4. generate the r-curves true_rcurve = torch.from_numpy( np.apply_along_axis(param_to_rcurve, 1, true_params.detach().cpu().numpy()) ).float() # calculate the loss loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve) # combine the losses loss = self.lambda_param * loss_param + self.lambda_curve * loss_rcurve return loss if __name__ == "__main__": pass src/tgreft/train/generic.py +3 −0 Original line number Diff line number Diff line Loading @@ -5,6 +5,7 @@ import torch from tqdm.auto import tqdm from torch.utils.data import DataLoader from tgreft.utils.visualization import params_cmp_heatmap, plot_rcurve_cmp from tgreft.nn.loss import CompositeLoss def train_single_epoch( Loading Loading @@ -215,5 +216,7 @@ def get_loss(name: str) -> torch.nn.Module: return torch.nn.L1Loss() elif name.lower() == "huber": return torch.nn.SmoothL1Loss() elif name.lower() == "composite": return CompositeLoss() else: raise ValueError(f"Unknown loss function: {name}") Loading
.pre-commit-config.yaml 0 → 100644 +29 −0 Original line number Diff line number Diff line repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: check-added-large-files args: [--maxkb=8192] - id: check-merge-conflict - id: check-yaml exclude: | (?x)^( ^conda.recipe/meta.yaml )$ - id: end-of-file-fixer exclude: "tests/cis_tests/.*" - id: trailing-whitespace exclude: "tests/cis_tests/.*" - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.7 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] exclude: "tests/cis_tests/.*" - repo: https://github.com/psf/black rev: 23.11.0 hooks: - id: black args: ['--line-length=120'] exclude: "tests/cis_tests/.*"
scripts/train_gpt.py +3 −3 Original line number Diff line number Diff line Loading @@ -15,13 +15,13 @@ if __name__ == "__main__": train_params = { "cuda_id": 0, "n_epochs": 50, # seems like 50 is the point where training and validation loss diverge "n_training": 1_500_000, "n_training": 50_000, "error": 0.07, "batch_size": 250, "learning_rate": 0.0057929655918116715, "weight_decay": 7.198921885462489e-07, "optimizer": "SGD", "loss": "huber", "loss": "composite", "cache_dir": "data", "experiment_name": "Train_REFL_GPT", "run_name": "gpt_d1024_h8_l4_newlog", Loading
#!/usr/bin/env python
"""Extended loss functions for tgreft."""
import numpy as np
import torch
import torch.nn as nn

from tgreft.utils.data.data_loader import param_to_rcurve


class CompositeLoss(nn.Module):
    """Composite loss over both the raw parameters and the derived r-curves.

    The total loss is a convex combination::

        loss = lambda_param * SmoothL1(pred, true)
             + (1 - lambda_param) * MSE(rcurve(pred), rcurve(true))

    NOTE(review): ``param_to_rcurve`` is a NumPy function, so the curves are
    computed on detached CPU copies of the parameters. No gradient flows
    through the r-curve term — only the Huber (parameter) term trains the
    model; the curve term acts as a fixed-weight penalty on the reported
    loss value. Confirm this is the intended behavior; a differentiable
    torch implementation of ``param_to_rcurve`` would be needed otherwise.
    """

    def __init__(
        self,
        lambda_param: float = 0.5,
        num_param: int = 17,
    ):
        """
        Args:
            lambda_param: weight of the parameter (Huber) term, expected in
                ``[0, 1]``; the r-curve (MSE) term gets ``1 - lambda_param``.
            num_param: number of physical parameters per sample, used to
                reshape flat inputs to ``(batch, num_param)``.
        """
        super().__init__()
        # Component loss functions.
        self.param_loss_huber = nn.SmoothL1Loss()
        self.rcurve_loss = nn.MSELoss()
        # Weights for the param term and the rcurve term (sum to 1).
        self.lambda_param = lambda_param
        self.lambda_curve = 1.0 - lambda_param
        # Number of parameters per sample.
        self.num_param = num_param

    def _params_to_rcurve(self, params: torch.Tensor) -> torch.Tensor:
        """Generate r-curves (CPU float tensor) from ``(batch, num_param)`` params.

        Goes through NumPy row-by-row via ``param_to_rcurve``; the input is
        detached, so this path carries no gradient.
        """
        arr = params.detach().cpu().numpy()
        return torch.from_numpy(np.apply_along_axis(param_to_rcurve, 1, arr)).float()

    def forward(self, pred_params: torch.Tensor, true_params: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of the parameter loss and the r-curve loss."""
        # Differentiable parameter-space term.
        loss_param = self.param_loss_huber(pred_params, true_params)

        # Curve-space term: reshape to (batch, num_param) and generate curves.
        # reshape (not view) so non-contiguous inputs don't crash.
        pred_rcurve = self._params_to_rcurve(pred_params.reshape(-1, self.num_param))
        true_rcurve = self._params_to_rcurve(true_params.reshape(-1, self.num_param))
        # The curves live on CPU; move the scalar onto loss_param's device so
        # the combination below is explicitly same-device (robust under CUDA).
        loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve).to(loss_param.device)

        # Weighted combination of the two terms.
        return self.lambda_param * loss_param + self.lambda_curve * loss_rcurve


if __name__ == "__main__":
    pass
src/tgreft/train/generic.py +3 −0 Original line number Diff line number Diff line Loading @@ -5,6 +5,7 @@ import torch from tqdm.auto import tqdm from torch.utils.data import DataLoader from tgreft.utils.visualization import params_cmp_heatmap, plot_rcurve_cmp from tgreft.nn.loss import CompositeLoss def train_single_epoch( Loading Loading @@ -215,5 +216,7 @@ def get_loss(name: str) -> torch.nn.Module: return torch.nn.L1Loss() elif name.lower() == "huber": return torch.nn.SmoothL1Loss() elif name.lower() == "composite": return CompositeLoss() else: raise ValueError(f"Unknown loss function: {name}")