From 4875f54001cd9bb769c92268278817ae7c07a3d8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Fri, 14 Feb 2025 19:23:38 +0100 Subject: [PATCH] fix(qav3): ensure optuna is doing live optimization MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../freqaimodels/LightGBMRegressorQuickAdapterV35.py | 11 +++++++---- .../freqaimodels/XGBoostRegressorQuickAdapterV35.py | 11 +++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index c1e0166..f4b0d46 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -184,9 +184,9 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff dk.data["extra_returns_per_train"]["label_period_candles"] = ( - self.__optuna_hp.get( - pair, {} - ).get("label_period_candles", self.ft_params["label_period_candles"]) + self.__optuna_hp.get(pair, {}).get( + "label_period_candles", self.ft_params["label_period_candles"] + ) ) dk.data["extra_returns_per_train"]["rmse"] = self.__optuna_hp.get(pair, {}).get( "rmse", 0 @@ -250,6 +250,10 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): study_name = dk.pair storage = self.get_optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() + try: + optuna.delete_study(study_name=study_name, storage=storage) + except optuna.exceptions.StudyNotFound: + pass study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -259,7 +263,6 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): pruner=pruner, direction=optuna.study.StudyDirection.MINIMIZE, storage=storage, - load_if_exists=True, ) hyperopt_failed = False try: diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 53cee6f..ce0b8a9 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -187,9 +187,9 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff dk.data["extra_returns_per_train"]["label_period_candles"] = ( - self.__optuna_hp.get( - pair, {} - ).get("label_period_candles", self.ft_params["label_period_candles"]) + self.__optuna_hp.get(pair, {}).get( + "label_period_candles", self.ft_params["label_period_candles"] + ) ) dk.data["extra_returns_per_train"]["rmse"] = self.__optuna_hp.get(pair, {}).get( "rmse", 0 @@ -253,6 +253,10 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): study_name = dk.pair storage = self.get_optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() + try: + optuna.delete_study(study_name=study_name, storage=storage) + except optuna.exceptions.StudyNotFound: + pass study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -262,7 +266,6 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): pruner=pruner, direction=optuna.study.StudyDirection.MINIMIZE, storage=storage, - load_if_exists=True, ) hyperopt_failed = False try: -- 2.43.0