Commit 17d9ca1f authored by Zhang, Chen

switch back to static lambda for stability
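The composite loss previously treated its mixing weights as learnable nn.Parameter values and renormalized them on every forward pass, which destabilized training; this commit pins them back to fixed floats. The core of the change in CompositeLoss.__init__ (full diff below):

    # before: learnable weights, renormalized each step
    self.lambda_param = nn.Parameter(torch.tensor(lambda_param))
    self.lambda_curve = nn.Parameter(torch.tensor(1.0 - lambda_param))

    # after: static weights, fixed at construction
    self.lambda_param = lambda_param
    self.lambda_curve = 1.0 - lambda_param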

parent c3420ba4
+13 −12
@@ -5,7 +5,7 @@ import os
import logging
import torch
import optuna
-from tgreft.train.train_gpt import train
+from tgreft.train.train_ntrace import train

logger = logging.getLogger("optuna_gpt")
logger.setLevel(logging.INFO)
@@ -30,25 +30,25 @@ def objective(trial: optuna.Trial) -> float:
    # training parameters
    train_params = {
        "cuda_id": 0,
        "n_epochs": 30,  # should be sufficient for hyperparameter tuning
        "n_epochs": 50,  # should be sufficient for hyperparameter tuning
        "n_training": 1_500_000,
        "error": 0.07,
        "batch_size": trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024, 2048]),
        "batch_size": trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024, 2048, 4096]),
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-8, 1e-4, log=True),
        "optimizer": trial.suggest_categorical("optimizer", ["Adam", "SGD"]),
        "loss": "Huber",
        "cache_dir": "data",
        "experiment_name": "optuna_gpt",
        "experiment_name": "optuna_gpt_output_dim-17",
    }
    # model parameters
    model_params = {
        "d_model": trial.suggest_categorical("d_model", [16, 32, 64, 128, 256, 512, 1024]),
        "n_head": trial.suggest_categorical("n_head", [2, 4, 8, 16, 32]),
        "num_encoder_layers": trial.suggest_categorical("num_encoder_layers", [2, 4, 8, 16, 32]),
        "d_model": trial.suggest_categorical("d_model", [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]),
        "n_head": trial.suggest_categorical("n_head", [2, 4, 8, 16, 32, 64, 128, 256, 512]),
        "num_encoder_layers": trial.suggest_categorical("num_encoder_layers", [2, 4, 8, 16, 32, 64, 128, 256, 512]),
        "input_dim": 150,
        "output_dim": 13,
        "to_log": trial.suggest_categorical("to_log", [True, False]),
        "output_dim": 17,
        "to_log": True,
    }
    # train the model
    try:
@@ -83,11 +83,12 @@ def objective(trial: optuna.Trial) -> float:


if __name__ == "__main__":
    study_name = "gpt"
    storage_name = "sqlite:///optuna.db"
    study_name = "gpt_refl"
    db_file_name = "optuna_gpt_refl.db" 
    storage_name = f"sqlite:///{db_file_name}.db"

    # try to resume the study if it exists
    if os.path.exists("optuna.db"):
    if os.path.exists(db_file_name):
        logger.info("Found existing study, resuming...")
        study = optuna.load_study(
            study_name=study_name,
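Side note on the resume logic above: Optuna can fold this create-or-resume branching into a single call with load_if_exists. A minimal sketch, assuming the study minimizes its objective (Optuna's default direction) and using an arbitrary trial budget:

    import optuna

    db_file_name = "optuna_gpt_refl.db"
    study = optuna.create_study(
        study_name="gpt_refl",
        storage=f"sqlite:///{db_file_name}",
        direction="minimize",
        load_if_exists=True,  # reuse the study if it already exists in the database
    )
    study.optimize(objective, n_trials=100)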
+15 −9
#!/usr/bin/env python
"""Extended loss functions for tgreft."""
import logging
import torch
import torch.nn as nn
import numpy as np
from multiprocessing import Pool
from tgreft.utils.data.data_loader import param_to_rcurve

logger = logging.getLogger("LOSS")
logger.setLevel(logging.DEBUG)
# create a file handler
handler = logging.FileHandler("loss.log")
logger.addHandler(handler)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
logger.handlers[0].setFormatter(formatter)


def parallel_param_to_rcurve(params):
    with Pool() as p:
@@ -23,10 +32,10 @@ class CompositeLoss(nn.Module):
        super().__init__()
        # define the loss functions
        self.param_loss_huber = nn.SmoothL1Loss()
-        self.rcurve_loss = nn.MSELoss()
-        # use learnable weights
-        self.lambda_param = nn.Parameter(torch.tensor(lambda_param))
-        self.lambda_curve = nn.Parameter(torch.tensor(1.0 - lambda_param))
+        self.rcurve_loss = nn.SmoothL1Loss()
+        # define the lambda parameters
+        self.lambda_param = lambda_param
+        self.lambda_curve = 1.0 - lambda_param
        # number of parameters
        self.num_param = num_param

@@ -43,18 +52,15 @@ class CompositeLoss(nn.Module):
        true_params = true_params.view(-1, self.num_param)
        # 4. generate the r-curves
        true_rcurve = torch.from_numpy(
-            np.array(parallel_param_to_rcurve(pred_params.detach().cpu().numpy()))
+            np.array(parallel_param_to_rcurve(true_params.detach().cpu().numpy()))
        ).float()
        # calculate the loss
        # cast to log first
        pred_rcurve = torch.log(pred_rcurve)
        true_rcurve = torch.log(true_rcurve)
        loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve)
        # combine the losses
-        # - normalize the parameters
-        self.lambda_param.data = self.lambda_param.data.abs() / (self.lambda_param.data.abs() + self.lambda_curve.data.abs())
-        self.lambda_curve.data = 1.0 - self.lambda_param.data
        # - calculate the combined loss
        logger.debug(f"loss_param: {loss_param};\tloss_rcurve: {loss_rcurve}")
        loss = self.lambda_param * loss_param + self.lambda_curve * loss_rcurve
        return loss
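Putting the hunks above together, the post-commit CompositeLoss reduces to roughly the sketch below. The r-curve regeneration step is elided (the real forward rebuilds true_rcurve from true_params via parallel_param_to_rcurve); here both curves are taken as inputs, and the num_param default of 17 mirrors the new output_dim but is only an assumption:

    import torch
    import torch.nn as nn

    class CompositeLoss(nn.Module):
        """Weighted sum of a parameter loss and an r-curve loss with static weights."""

        def __init__(self, lambda_param: float = 0.5, num_param: int = 17):
            super().__init__()
            self.param_loss_huber = nn.SmoothL1Loss()
            self.rcurve_loss = nn.SmoothL1Loss()
            # static weights: fixed at construction, never updated by the optimizer
            self.lambda_param = lambda_param
            self.lambda_curve = 1.0 - lambda_param
            self.num_param = num_param

        def forward(self, pred_params, true_params, pred_rcurve, true_rcurve):
            pred_params = pred_params.view(-1, self.num_param)
            true_params = true_params.view(-1, self.num_param)
            loss_param = self.param_loss_huber(pred_params, true_params)
            # compare curves in log space, as in the diff above
            loss_rcurve = self.rcurve_loss(torch.log(pred_rcurve), torch.log(true_rcurve))
            return self.lambda_param * loss_param + self.lambda_curve * loss_rcurve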

+8 −23
@@ -18,7 +18,7 @@ from tgreft.train.generic import (
from tgreft.utils.data.data_loader import get_dataset


logger = logging.getLogger("GPT_trainer")
logger = logging.getLogger("mTRACE_trainer")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())
formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
@@ -125,22 +125,11 @@ def train(

    # prepare optimizer
    logger.info("Preparing optimizer...")
    if loss.lower() == "composite":
        # use composite loss
    optimizer = get_optimizer(
        name=optimizer,
        model=model,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
            loss_fn=loss_fn,
        )
    else:
        # use normal loss   
        optimizer = get_optimizer(
            name=optimizer,
            model=model,  # pass in the combined parameters if using composite loss
            learning_rate=learning_rate,
            weight_decay=weight_decay,
    )

    # auto advance?
@@ -160,7 +149,7 @@ def train(
        mlflow.log_artifact(__file__, "training_script")

        # main loop
-        for epoch in tqdm(range(n_epochs + 2)):
+        for epoch in tqdm(range(n_epochs + 1)):
            logger.info(f"Epoch {epoch}")
            # train
            train_loss = train_single_epoch(
@@ -181,10 +170,6 @@ def train(
            )
            logger.info(f"Testing loss: {test_loss}")
            mlflow.log_metric("test_loss", test_loss, step=epoch)
-            # capture the loss lambda if using composite loss
-            if loss.lower() == "composite":
-                lambda_param = loss_fn.lambda_param.data.item()
-                mlflow.log_metric("lambda_param", lambda_param, step=epoch)
            # visualize every 5 epochs
            if epoch % 5 == 0:
                logger.info("Visualizing model...")