Loading
src/tgreft/train/train_gpt.py +8 −0
Original line number Diff line number Diff line
Loading
@@ -176,6 +176,11 @@ def train(
    torch.save(model.state_dict(), model_name)
    mlflow.log_artifact(model_name, "models")
    logger.info("Done saving model.")
    # if loss becomes nan, break the loop
    if torch.isnan(test_loss):
        logger.info("Loss becomes nan, break the loop.")
        break
    # save model (final)
    logger.info("Training complete, saving model...")
    model_name_final = "model_gpt_final.pt"
Loading
@@ -191,6 +196,9 @@ def train(
        device=device,
    )
    logger.info(f"Final loss: {final_loss}")
    # if loss is nan, return large number
    if torch.isnan(final_loss):
        final_loss = 1e6
    return final_loss
Loading
Loading
src/tgreft/train/train_gpt.py +8 −0
Original line number Diff line number Diff line
Loading
@@ -176,6 +176,11 @@ def train(
    torch.save(model.state_dict(), model_name)
    mlflow.log_artifact(model_name, "models")
    logger.info("Done saving model.")
    # if loss becomes nan, break the loop
    if torch.isnan(test_loss):
        logger.info("Loss becomes nan, break the loop.")
        break
    # save model (final)
    logger.info("Training complete, saving model...")
    model_name_final = "model_gpt_final.pt"
Loading
@@ -191,6 +196,9 @@ def train(
        device=device,
    )
    logger.info(f"Final loss: {final_loss}")
    # if loss is nan, return large number
    if torch.isnan(final_loss):
        final_loss = 1e6
    return final_loss
Loading