Commit 3e5a7c49 authored by Zhang, Chen
Browse files

handle transformer dim mismatch during tuning

parent 18090c9b
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -64,6 +64,13 @@ def objective(trial: optuna.Trial) -> float:
            raise optuna.TrialPruned(f"OOM error: {e}")
        else:
            raise e
    except AssertionError as e:
        # Catch the embed_dim / num_heads mismatch and prune the trial instead of failing
        if "embed_dim must be divisible by num_heads" in str(e):
            logger.error("embed_dim must be divisible by num_heads")
            raise optuna.TrialPruned(f"embed_dim must be divisible by num_heads: {e}")
        else:
            raise e
    except Exception as e:
        loss = 1e6
        raise e