From b2b494e4de0ca621cef792da09ae286dc5cfb5a6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sat, 15 Feb 2025 21:54:32 +0100 Subject: [PATCH] fix(qav3): handle more optuna loaded study corner cases MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 18 ++++++++++++++++-- .../XGBoostRegressorQuickAdapterV35.py | 18 ++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index 0718c84..8820b80 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -285,7 +285,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): direction=optuna.study.StudyDirection.MINIMIZE, storage=storage, ) - if previous_study and hasattr(previous_study, "best_params"): + if self.optuna_study_has_best_params(previous_study): study.enqueue_trial(previous_study.best_params) start = time.time() try: @@ -342,7 +342,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): direction=optuna.study.StudyDirection.MINIMIZE, storage=storage, ) - if previous_study and hasattr(previous_study, "best_params"): + if self.optuna_study_has_best_params(previous_study): study.enqueue_trial(previous_study.best_params) start = time.time() try: @@ -390,6 +390,20 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): pass return previous_study + def optuna_study_has_best_params(self, study: optuna.study.Study | None) -> bool: + if not study: + return False + try: + # Check if there are completed trials + if len(study.trials) == 0: + return False + + # Check if best_params exists (raises ValueError if no trials succeeded) + _ = study.best_params + return True + except ValueError: + return False + def log_sum_exp_min_max_pred( pred_df: pd.DataFrame, fit_live_predictions_candles: int, label_period_candles: int diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 3493740..fa3fedf 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -286,7 +286,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): direction=optuna.study.StudyDirection.MINIMIZE, storage=storage, ) - if previous_study and hasattr(previous_study, "best_params"): + if self.optuna_study_has_best_params(previous_study): study.enqueue_trial(previous_study.best_params) start = time.time() try: @@ -343,7 +343,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): direction=optuna.study.StudyDirection.MINIMIZE, storage=storage, ) - if previous_study and hasattr(previous_study, "best_params"): + if self.optuna_study_has_best_params(previous_study): study.enqueue_trial(previous_study.best_params) start = time.time() try: @@ -391,6 +391,20 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): pass return previous_study + def optuna_study_has_best_params(self, study: optuna.study.Study | None) -> bool: + if not study: + return False + try: + # Check if there are completed trials + if len(study.trials) == 0: + return False + + # Check if best_params exists (raises ValueError if no trials succeeded) + _ = study.best_params + return True + except ValueError: + return False + def log_sum_exp_min_max_pred( pred_df: pd.DataFrame, fit_live_predictions_candles: int, label_period_candles: int -- 2.43.0