Commit 9c00d888 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

remove hard coded constants

parent e33f260b
Loading
Loading
Loading
Loading
+12 −9
Original line number Diff line number Diff line
@@ -5,23 +5,26 @@ from tgreft.train.train_gpt import train
if __name__ == "__main__":
    train_params = {
        "cuda_id": 0,
        "n_epochs": 500,
        "n_training": 1_000_000,
        "n_epochs": 50,  # seems like 50 is the point where training and validation loss diverge
        "n_training": 1_500_000,
        "error": 0.07,
        "batch_size": 512,
        "learning_rate": 1e-3,
        "weight_decay": 1e-6,
        "optimizer": "adam",
        "batch_size": 200,
        "learning_rate": 0.0057929655918116715,
        "weight_decay": 7.198921885462489e-07,
        "optimizer": "SGD",
        "loss": "huber",
        "cache_dir": "data",
        "experiment_name": "Train_REFL_GPT",
        "run_name": "gpt_d1024_h4_l4",
    }

    model_params = {
        "d_model": 256,
        "n_head": 16,
        "d_model": 1024,
        "n_head": 4,
        "num_encoder_layers": 4,
        "input_dim": 150,
        "output_dim": 13,
        "output_dim": 17,
        "to_log": True,
    }

    train(
+7 −0
Original line number Diff line number Diff line
@@ -37,6 +37,13 @@ class REFL_GPT(nn.Module):

        self.to_log = to_log

        # record the model configuration
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.input_dim = input_dim
        self.output_dim = output_dim

    def forward(self, src):
        """Forward pass."""
        if self.to_log:
+2 −2
Original line number Diff line number Diff line
@@ -115,8 +115,8 @@ def visualize_single_epoch(
    preds = np.concatenate(preds, axis=0)
    refs = np.concatenate(refs, axis=0)
    # reshape
    preds = preds.reshape(-1, 13)
    refs = refs.reshape(-1, 13)
    preds = preds.reshape(-1, model.output_dim)
    refs = refs.reshape(-1, model.output_dim)
    #
    labels = [
        "electolyte_sld",