Loading src/tgreft/train/train_gpt.py +3 −2 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ import os import torch import mlflow import logging import numpy as np from tqdm.auto import tqdm from torch.utils.data import DataLoader, random_split from tgreft.models.refl_gpt import REFL_GPT Loading Loading @@ -178,7 +179,7 @@ def train( logger.info("Done saving model.") # if loss becomes nan, break the loop if torch.isnan(test_loss): if np.isnan(test_loss): logger.info("Loss becomes nan, break the loop.") break # save model (final) Loading @@ -197,7 +198,7 @@ def train( ) logger.info(f"Final loss: {final_loss}") # if loss is nan, return large number if torch.isnan(final_loss): if np.isnan(final_loss): final_loss = 1e6 return final_loss Loading Loading
src/tgreft/train/train_gpt.py +3 −2 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ import os import torch import mlflow import logging import numpy as np from tqdm.auto import tqdm from torch.utils.data import DataLoader, random_split from tgreft.models.refl_gpt import REFL_GPT Loading Loading @@ -178,7 +179,7 @@ def train( logger.info("Done saving model.") # if loss becomes nan, break the loop if torch.isnan(test_loss): if np.isnan(test_loss): logger.info("Loss becomes nan, break the loop.") break # save model (final) Loading @@ -197,7 +198,7 @@ def train( ) logger.info(f"Final loss: {final_loss}") # if loss is nan, return large number if torch.isnan(final_loss): if np.isnan(final_loss): final_loss = 1e6 return final_loss Loading