From fbea717305f66a20f1924d05628873bd7940f895 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sun, 16 Feb 2025 12:11:20 +0100 Subject: [PATCH] refactor(qav3): add a runtime fallback for optuna warm start MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 18 ++++++++++++++---- .../XGBoostRegressorQuickAdapterV35.py | 18 ++++++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index 602c870..5438566 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -50,6 +50,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): and self.__optuna_config.get("enabled", False) and self.data_split_parameters.get("test_size", TEST_SIZE) > 0 ) + self.__optuna_hp_params: dict[str, dict] = {} self.__optuna_period_params: dict[str, dict] = {} def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: @@ -77,9 +78,12 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): dk, X, y, train_weights, X_test, y_test, test_weights ) if optuna_hp_params: + if dk.pair not in self.__optuna_hp_params: + self.__optuna_hp_params[dk.pair] = {} + self.__optuna_hp_params[dk.pair] = optuna_hp_params model_training_parameters = { **model_training_parameters, - **optuna_hp_params, + **self.__optuna_hp_params[dk.pair], } optuna_period_params = self.optuna_period_optimize( @@ -277,7 +281,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): study_name = f"hp-{dk.pair}" storage = self.optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() - previous_study = self.optuna_previous_study(study_name, storage) + previous_study = self.optuna_study_load_and_cleanup(study_name, storage) study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -290,6 +294,8 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): ) if self.optuna_study_has_best_params(previous_study): study.enqueue_trial(previous_study.best_params) + elif self.__optuna_hp_params.get(dk.pair): + study.enqueue_trial(self.__optuna_hp_params[dk.pair]) start = time.time() try: study.optimize( @@ -334,7 +340,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): study_name = f"period-{dk.pair}" storage = self.optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() - previous_study = self.optuna_previous_study(study_name, storage) + previous_study = self.optuna_study_load_and_cleanup(study_name, storage) study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -347,6 +353,10 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): ) if self.optuna_study_has_best_params(previous_study): study.enqueue_trial(previous_study.best_params) + elif self.__optuna_period_params.get(dk.pair): + previous_best_params = self.__optuna_period_params[dk.pair].copy() + del previous_best_params["rmse"] + study.enqueue_trial(previous_best_params) start = time.time() try: study.optimize( @@ -380,7 +390,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): logger.info(f"Optuna period hyperopt | {key:>20s} : {value}") return params - def optuna_previous_study( + def optuna_study_load_and_cleanup( self, study_name: str, storage ) -> optuna.study.Study | None: try: diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index a43d915..7038dff 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -50,6 +50,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): and self.__optuna_config.get("enabled", False) and self.data_split_parameters.get("test_size", TEST_SIZE) > 0 ) + self.__optuna_hp_params: dict[str, dict] = {} self.__optuna_period_params: dict[str, dict] = {} def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: @@ -77,9 +78,12 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): dk, X, y, train_weights, X_test, y_test, test_weights ) if optuna_hp_params: + if dk.pair not in self.__optuna_hp_params: + self.__optuna_hp_params[dk.pair] = {} + self.__optuna_hp_params[dk.pair] = optuna_hp_params model_training_parameters = { **model_training_parameters, - **optuna_hp_params, + **self.__optuna_hp_params[dk.pair], } optuna_period_params = self.optuna_period_optimize( @@ -278,7 +282,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): study_name = f"hp-{dk.pair}" storage = self.optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() - previous_study = self.optuna_previous_study(study_name, storage) + previous_study = self.optuna_study_load_and_cleanup(study_name, storage) study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -291,6 +295,8 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): ) if self.optuna_study_has_best_params(previous_study): study.enqueue_trial(previous_study.best_params) + elif self.__optuna_hp_params.get(dk.pair): + study.enqueue_trial(self.__optuna_hp_params[dk.pair]) start = time.time() try: study.optimize( @@ -335,7 +341,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): study_name = f"period-{dk.pair}" storage = self.optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() - previous_study = self.optuna_previous_study(study_name, storage) + previous_study = self.optuna_study_load_and_cleanup(study_name, storage) study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -348,6 +354,10 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): ) if self.optuna_study_has_best_params(previous_study): study.enqueue_trial(previous_study.best_params) + elif self.__optuna_period_params.get(dk.pair): + previous_best_params = self.__optuna_period_params[dk.pair].copy() + del previous_best_params["rmse"] + study.enqueue_trial(previous_best_params) start = time.time() try: study.optimize( @@ -381,7 +391,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): logger.info(f"Optuna period hyperopt | {key:>20s} : {value}") return params - def optuna_previous_study( + def optuna_study_load_and_cleanup( self, study_name: str, storage ) -> optuna.study.Study | None: try: -- 2.43.0