Commit 18090c9b authored by Zhang, Chen's avatar Zhang, Chen
Browse files

use np to check for nan

parent fbc5a0c4
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import os
import torch
import mlflow
import logging
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, random_split
from tgreft.models.refl_gpt import REFL_GPT
@@ -178,7 +179,7 @@ def train(
                    logger.info("Done saving model.")

            # if loss becomes nan, break the loop
            if torch.isnan(test_loss):
            if np.isnan(test_loss):
                logger.info("Loss becomes nan, break the loop.")
                break
        # save model (final)
@@ -197,7 +198,7 @@ def train(
        )
        logger.info(f"Final loss: {final_loss}")
        # if loss is nan, return large number
        if torch.isnan(final_loss):
        if np.isnan(final_loss):
            final_loss = 1e6

    return final_loss