Commit 330f5cc9 authored by Zhang, Chen

update for new model name

parent 6cd8022a
+11 −11
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
 """Train script for the transformer model."""
-from tgreft.train.train_gpt import train
+from tgreft.train.train_ntrace import train
 
 if __name__ == "__main__":
     model_params = {
-        "d_model": 1024,
-        "n_head": 8,
-        "num_encoder_layers": 4,
+        "d_model": 2048,
+        "n_head": 32,
+        "num_encoder_layers": 6,
         "input_dim": 150,
         "output_dim": 17,
         "to_log": True,
@@ -15,16 +15,16 @@ if __name__ == "__main__":
     train_params = {
         "cuda_id": 0,
         "n_epochs": 50,  # seems like 50 is the point where training and validation loss diverge
-        "n_training": 50_000,
+        "n_training": 1_500_000,
         "error": 0.07,
-        "batch_size": 250,
-        "learning_rate": 0.0057929655918116715,
-        "weight_decay": 7.198921885462489e-07,
-        "optimizer": "SGD",
+        "batch_size": 180,
+        "learning_rate": 0.005,
+        "weight_decay": 1e-6,
+        "optimizer": "Adam",
         "loss": "composite",
         "cache_dir": "data",
-        "experiment_name": "Train_REFL_GPT",
-        "run_name": "gpt_d1024_h8_l4_newlog",
+        "experiment_name": "Train_nTRACE",
+        "run_name": "d2048_h32_l6_cmloss",
     }
 
     train(
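For scale: the new config roughly sextuples the encoder. A back-of-the-envelope estimate, assuming a standard transformer encoder layer with feed-forward width 4 * d_model (the actual tgreft architecture is not shown in this diff):

# Hypothetical estimate: attention (Q, K, V, output projections) plus the
# two feed-forward projections; biases, norms and embeddings ignored.
def encoder_params(d_model: int, n_layers: int, ff_mult: int = 4) -> int:
    attn = 4 * d_model * d_model
    ffn = 2 * ff_mult * d_model * d_model
    return n_layers * (attn + ffn)

print(encoder_params(1024, 4))  # old config: ~50M parameters
print(encoder_params(2048, 6))  # new config: ~302M parameters

A roughly 6x larger encoder pairs naturally with the 30x larger training set (50_000 -> 1_500_000 samples).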
+24 −9
@@ -119,18 +119,29 @@ def train(
     # log the number of parameters
     logger.info(f"Number of parameters: {n_params}")
 
+    # prepare loss function (moved before the optimizer, which may need loss_fn)
+    logger.info("Preparing loss function...")
+    loss_fn = get_loss(name=loss)
+
     # prepare optimizer
     logger.info("Preparing optimizer...")
-    optimizer = get_optimizer(
-        name=optimizer,
-        model=model,
-        learning_rate=learning_rate,
-        weight_decay=weight_decay,
-    )
+    if loss.lower() == "composite":
+        # composite loss: pass in loss_fn so the optimizer gets the combined
+        # parameters (the model's plus the loss's learnable ones)
+        optimizer = get_optimizer(
+            name=optimizer,
+            model=model,
+            learning_rate=learning_rate,
+            weight_decay=weight_decay,
+            loss_fn=loss_fn,
+        )
+    else:
+        # use normal loss; only the model's parameters are optimized
+        optimizer = get_optimizer(
+            name=optimizer,
+            model=model,
+            learning_rate=learning_rate,
+            weight_decay=weight_decay,
+        )
 
-    # prepare loss function
-    logger.info("Preparing loss function...")
-    loss_fn = get_loss(name=loss)
-
     # auto advance?
     if not auto_advance:
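The reason for the branch: with the composite loss, the optimizer also has to update the loss function's own learnable parameters (see lambda_param below), not just the model's. A minimal sketch of what get_optimizer might do when loss_fn is supplied; the real tgreft helper is not shown here and may differ:

import itertools
import torch

def get_optimizer(name, model, learning_rate, weight_decay, loss_fn=None):
    # chain the model's parameters with the loss's learnable parameters, if any
    params = (
        itertools.chain(model.parameters(), loss_fn.parameters())
        if loss_fn is not None
        else model.parameters()
    )
    optim_cls = {"Adam": torch.optim.Adam, "SGD": torch.optim.SGD}[name]
    return optim_cls(params, lr=learning_rate, weight_decay=weight_decay)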
@@ -170,6 +181,10 @@ def train(
             )
             logger.info(f"Testing loss: {test_loss}")
             mlflow.log_metric("test_loss", test_loss, step=epoch)
+            # capture the loss lambda if using composite loss
+            if loss.lower() == "composite":
+                lambda_param = loss_fn.lambda_param.data.item()
+                mlflow.log_metric("lambda_param", lambda_param, step=epoch)
             # visualize every 5 epochs
             if epoch % 5 == 0:
                 logger.info("Visualizing model...")