torch_single_1fc.py
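"""Train and evaluate a reconstruction-based IDS on SynCAN: a single
Transformer encoder with a fully connected decoder head, trained on
random-start sliding windows and tested on the SynCAN attack traces."""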
import logging
import os

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader

from t_ids.data_loaders import (
    NonOverlappingSlidingWindowDataset,
    RandomStartSlidingWindowDataset,
    RandomStartSlidingWindowDataLoader,
    load_syn_training_dataset,
    load_syn_testing_dataset,
)
from t_ids.encoders import TorchTransformerEncoder
from t_ids.evaluation import inference, test_augmented_batched
from t_ids.helpers import default_device, generate_uuid_suffix, setup_logging
from t_ids.models import GenericIDS
from t_ids.training import train_and_validate_augmented
from t_ids.transformers import SingleTransformerEncoderFcDecoder

setup_logging()
logging.info(f"Running script '{os.path.basename(__file__)}'")
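# The per-run UUID suffix keeps checkpoint files from colliding between runs;
# the SynCAN CSVs are expected under ./data/syncan/. default_device() is
# presumably a helper that picks CUDA when available, else CPU.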
SUFFIX = generate_uuid_suffix()
TRAIN_DATA_PATH = "./data/syncan/train/"
TEST_DATA_PATH = "./data/syncan/test/"
CHECKPOINT_FILE_PATH = f"models/checkpoints-encoder-only-rec_{SUFFIX}"
DEVICE = default_device()
HP = {
    "use_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],  # label-encoded SynCAN message IDs to use
    "iterations": 4096,        # number of training iterations
    "window_size": 1024,       # messages per sliding window
    "batch_size": 32,
    "learning_rate": 1e-6,
    "d_model": 64,             # transformer embedding dimension
    "num_encoder_heads": 8,
    "num_encoder_layers": 2,
    "dim_feedforward": 512,
    "embedding_dropout": 0.3,
    "encoder_dropout": 0.3,
    "num_decoder_layers": 1,   # depth of the fully connected decoder head
    "random_mask_p": 0.7,      # presumably the probability of random input masking (augmentation)
    "c_in": 20,                # input feature channels
    "c_out": 20,               # reconstructed output channels
}
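# Encoder backbone: by its name, presumably a thin t_ids wrapper around a
# stack of standard torch.nn.TransformerEncoder layers.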
base_encoder = TorchTransformerEncoder(
    d_model=HP["d_model"],
    num_encoder_heads=HP["num_encoder_heads"],
    num_encoder_layers=HP["num_encoder_layers"],
    dim_feedforward=HP["dim_feedforward"],
    encoder_dropout=HP["encoder_dropout"],
    device=DEVICE,
)
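# Full model: input embedding + encoder + fully connected decoder, presumably
# reconstructing the (randomly masked) input signals. decoder_activation=None
# leaves the decoder outputs as raw linear values.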
base_model = SingleTransformerEncoderFcDecoder(
    c_in=HP["c_in"],
    d_model=HP["d_model"],
    max_len=HP["window_size"],
    embedding_dropout=HP["embedding_dropout"],
    base_encoder=base_encoder,
    out_feats=HP["c_out"],
    random_mask_p=HP["random_mask_p"],
    num_decoder_layers=HP["num_decoder_layers"],
    decoder_activation=None,
    device=DEVICE,
)
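# Training/validation, inference, and test routines injected into GenericIDS.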
methods = {
    "train_and_validate": train_and_validate_augmented,
    "inference": inference,
    "test": test_augmented_batched,
}
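# Wrapper tying together hyperparameters, model, checkpointing, and the
# injected routines above.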
ids = GenericIDS(
    hyperparams=HP,
    model=base_model,
    checkpoint_file_path=CHECKPOINT_FILE_PATH,
    methods=methods,
    device=DEVICE,
)
train_df = load_syn_training_dataset(TRAIN_DATA_PATH)
le = LabelEncoder()
train_df["ID"] = le.fit_transform(train_df["ID"])
train_df, val_df = train_test_split(train_df, test_size=0.01, shuffle=False)
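# Sliding-window datasets that, by their name, draw windows at random start
# offsets; num_batches mirrors the batch size here, as in the original config.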
train_sw = RandomStartSlidingWindowDataset(
    data=train_df, window_size=HP["window_size"], num_batches=HP["batch_size"]
)
val_sw = RandomStartSlidingWindowDataset(
    data=val_df, window_size=HP["window_size"], num_batches=HP["batch_size"]
)
dl_train_sw = RandomStartSlidingWindowDataLoader(dataset=train_sw, shuffle=True)
dl_val_sw = RandomStartSlidingWindowDataLoader(dataset=val_sw, shuffle=False)
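# Train and validate; GenericIDS presumably checkpoints to CHECKPOINT_FILE_PATH.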
train_losses, val_losses = ids.train_and_validate(dl_train_sw, dl_val_sw)
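# Evaluate on the clean trace and the five SynCAN attack types. Non-overlapping
# windows (with labels) mean each test message is scored in exactly one window.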
for test_file in [
    "test_normal.csv",
    "test_flooding.csv",
    "test_plateau.csv",
    "test_playback.csv",
    "test_continuous.csv",
    "test_suppress.csv",
]:
    test_df = load_syn_testing_dataset(TEST_DATA_PATH, filename=test_file)
    # Reuse the encoder fitted on the training data; re-fitting here could map
    # IDs to different integers than the ones the model saw during training.
    test_df["ID"] = le.transform(test_df["ID"])
    test_sw = NonOverlappingSlidingWindowDataset(
        data=test_df, window_size=HP["window_size"], need_labels=True
    )
    dl_test_sw = DataLoader(test_sw, shuffle=False, batch_size=HP["batch_size"])
    test_name = test_file.split(".")[0]  # e.g. "test_flooding"
    ids.test(dl_test_sw, 0.01, test_name)  # 0.01 is presumably the detection threshold