Commit fbc5a0c4 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

Handle NaN loss during training: break the training loop when the test loss becomes NaN, and return a large sentinel value (1e6) when the final loss is NaN.

parent d34a525f
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -176,6 +176,11 @@ def train(
                    torch.save(model.state_dict(), model_name)
                    mlflow.log_artifact(model_name, "models")
                    logger.info("Done saving model.")

            # if loss becomes nan, break the loop
            if torch.isnan(test_loss):
                logger.info("Loss becomes nan, break the loop.")
                break
        # save model (final)
        logger.info("Training complete, saving model...")
        model_name_final = "model_gpt_final.pt"
@@ -191,6 +196,9 @@ def train(
            device=device,
        )
        logger.info(f"Final loss: {final_loss}")
        # if loss is nan, return large number
        if torch.isnan(final_loss):
            final_loss = 1e6

    return final_loss