# torch_single_1fc.py — train a single Transformer-encoder + FC-decoder IDS on
# the SynCAN training data and evaluate it on each SynCAN test file.
from t_ids.encoders import TorchTransformerEncoder
from t_ids.models import GenericIDS
from t_ids.data_loaders import (
    NonOverlappingSlidingWindowDataset,
    RandomStartSlidingWindowDataLoader,
    load_syn_training_dataset,
    load_syn_testing_dataset,
    RandomStartSlidingWindowDataset,
)
from t_ids.transformers import SingleTransformerEncoderFcDecoder
from t_ids.helpers import default_device, generate_uuid_suffix, setup_logging
from t_ids.training import train_and_validate_augmented
from t_ids.evaluation import inference, test_augmented_batched
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

import os
import logging

# Project logging setup runs first so the log call below goes through it.
setup_logging()
logging.info(f"Running script '{os.path.basename(__file__)}'")

# Run-specific suffix appended to the checkpoint path; presumably unique per
# run so separate runs do not overwrite each other's checkpoints — confirm
# against generate_uuid_suffix.
SUFFIX = generate_uuid_suffix()

# SynCAN data directories, relative to the current working directory.
TRAIN_DATA_PATH, TEST_DATA_PATH = "./data/syncan/train/", "./data/syncan/test/"
CHECKPOINT_FILE_PATH = "models/checkpoints-encoder-only-rec_" + SUFFIX

# Device picked by the project helper; shared by the encoder, model and IDS.
DEVICE = default_device()

# Hyperparameters for this run. The whole dict is handed to GenericIDS as
# `hyperparams`; several entries are also read directly by this script when
# building the encoder, model, datasets and loaders.
# NOTE(review): "use_ids", "iterations" and "learning_rate" are not read
# anywhere in this script — presumably consumed by GenericIDS or the training
# method; confirm.
HP = dict(
    use_ids=list(range(10)),
    iterations=4096,
    window_size=1024,  # also passed to the model as max_len
    batch_size=32,
    learning_rate=1e-6,
    d_model=64,
    num_encoder_heads=8,
    num_encoder_layers=2,
    dim_feedforward=512,
    embedding_dropout=0.3,
    encoder_dropout=0.3,
    num_decoder_layers=1,
    random_mask_p=0.7,
    c_in=20,  # passed to the model as c_in
    c_out=20,  # passed to the model as out_feats
)

# HP entries forwarded under the same keyword name to the encoder constructor,
# in the same order the original call listed them.
_ENCODER_HP_KEYS = (
    "d_model",
    "num_encoder_heads",
    "num_encoder_layers",
    "dim_feedforward",
    "encoder_dropout",
)

# Transformer encoder stack; it is handed to the full model below as
# `base_encoder`.
base_encoder = TorchTransformerEncoder(
    **{key: HP[key] for key in _ENCODER_HP_KEYS},
    device=DEVICE,
)

# Full model: the encoder above plus a decoder head (per the class name, a
# fully-connected one). `decoder_activation=None` is passed explicitly.
# NOTE(review): random_mask_p presumably controls random input masking during
# training — confirm in SingleTransformerEncoderFcDecoder.
base_model = SingleTransformerEncoderFcDecoder(
    c_in=HP["c_in"],
    d_model=HP["d_model"],
    max_len=HP["window_size"],
    embedding_dropout=HP["embedding_dropout"],
    base_encoder=base_encoder,
    out_feats=HP["c_out"],
    random_mask_p=HP["random_mask_p"],
    num_decoder_layers=HP["num_decoder_layers"],
    decoder_activation=None,
    device=DEVICE,
)

# Routine table handed to GenericIDS. The script later calls
# ids.train_and_validate(...) and ids.test(...), so these keys appear to be
# exposed as methods on the IDS object — confirm in GenericIDS.
methods = dict(
    train_and_validate=train_and_validate_augmented,
    inference=inference,
    test=test_augmented_batched,
)

# Wrap the model, hyperparameters, checkpoint location and routines into the
# IDS object that drives training and testing.
ids = GenericIDS(
    hyperparams=HP,
    model=base_model,
    checkpoint_file_path=CHECKPOINT_FILE_PATH,
    methods=methods,
    device=DEVICE,
)

train_df = load_syn_training_dataset(TRAIN_DATA_PATH)

# Map raw CAN IDs to integer codes. The fitted encoder `le` is reused later
# in the script for the test files.
le = LabelEncoder()
train_df["ID"] = le.fit_transform(train_df["ID"])

# Hold out the trailing ~1% of rows as validation data. shuffle=False keeps
# the rows in their original order.
train_df, val_df = train_test_split(train_df, test_size=0.01, shuffle=False)

# Both splits are windowed with identical settings.
# NOTE(review): `num_batches` receives HP["batch_size"]; confirm that mapping
# is intended.
_window_kwargs = dict(window_size=HP["window_size"], num_batches=HP["batch_size"])
train_sw = RandomStartSlidingWindowDataset(data=train_df, **_window_kwargs)
val_sw = RandomStartSlidingWindowDataset(data=val_df, **_window_kwargs)

# Only the training loader shuffles.
dl_train_sw = RandomStartSlidingWindowDataLoader(dataset=train_sw, shuffle=True)
dl_val_sw = RandomStartSlidingWindowDataLoader(dataset=val_sw, shuffle=False)

train_losses, val_losses = ids.train_and_validate(dl_train_sw, dl_val_sw)

# SynCAN test scenarios, evaluated one file at a time.
TEST_FILES = (
    "test_normal.csv",
    "test_flooding.csv",
    "test_plateau.csv",
    "test_playback.csv",
    "test_continuous.csv",
    "test_suppress.csv",
)

for test_file in TEST_FILES:
    test_df = load_syn_testing_dataset(TEST_DATA_PATH, filename=test_file)

    # Apply the ID -> integer mapping learned on the TRAINING data.
    # Refitting here (fit_transform) would assign different codes to the same
    # CAN ID whenever a test file's set of IDs differs from the training set.
    # It would also overwrite `le`'s learned classes. `transform` raises
    # ValueError on IDs never seen during training instead of silently
    # remapping them.
    test_df["ID"] = le.transform(test_df["ID"])
    test_sw = NonOverlappingSlidingWindowDataset(
        data=test_df, window_size=HP["window_size"], need_labels=True
    )
    dl_test_sw = DataLoader(test_sw, shuffle=False, batch_size=HP["batch_size"])

    # File stem used as the run's test name, e.g. "test_flooding".
    test_name = os.path.splitext(test_file)[0]
    # NOTE(review): the meaning of the 0.01 argument is defined by
    # test_augmented_batched (a threshold?) — confirm and consider naming it.
    ids.test(dl_test_sw, 0.01, test_name)