From 34ab139340d654201cc1c84c16e1f4fdba42e804 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Wed, 2 Apr 2025 21:31:36 +0200 Subject: [PATCH] refactor(qav3): refine typing MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../user_data/freqaimodels/QuickAdapterRegressorV3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 1a991ef..f5948e7 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -548,7 +548,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): return min_pred[EXTREMA_COLUMN], max_pred[EXTREMA_COLUMN] -def get_callbacks(trial: optuna.Trial, regressor: str) -> list: +def get_callbacks(trial: optuna.Trial, regressor: str) -> list[Callable]: if regressor == "xgboost": callbacks = [ optuna.integration.XGBoostPruningCallback(trial, "validation_0-rmse") @@ -569,7 +569,7 @@ def train_regressor( eval_weights: Optional[list[np.ndarray]], model_training_parameters: dict, init_model: Any = None, - callbacks: list = None, + callbacks: list[Callable] = None, ) -> Any: if regressor == "xgboost": from xgboost import XGBRegressor -- 2.43.0