From 5de92b120cfa972712a89f7817556a8e462ebc20 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 20 Feb 2025 16:16:16 +0100 Subject: [PATCH] refactor(qav3): factor out optuna storage handling ops MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 20 ++++++++++--------- .../XGBoostRegressorQuickAdapterV35.py | 20 ++++++++++--------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index 34a3ace..dee1983 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -305,7 +305,8 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): study_name = f"hp-{dk.pair}" storage = self.optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() - previous_study = self.optuna_study_load_and_cleanup(study_name, storage) + previous_study = self.optuna_study_load(study_name, storage) + self.optuna_study_delete(study_name, storage) study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -375,7 +376,8 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): study_name = f"period-{dk.pair}" storage = self.optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() - previous_study = self.optuna_study_load_and_cleanup(study_name, storage) + previous_study = self.optuna_study_load(study_name, storage) + self.optuna_study_delete(study_name, storage) study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -443,17 +445,17 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): return json.load(read_file) return None - def optuna_study_load_and_cleanup( - self, study_name: str, storage - ) -> optuna.study.Study | None: - try: - study = optuna.load_study(study_name=study_name, storage=storage) - except Exception: - study = None + def optuna_study_delete(self, study_name: str, storage) -> None: try: optuna.delete_study(study_name=study_name, storage=storage) except Exception: pass + + def optuna_study_load(self, study_name: str, storage) -> optuna.study.Study | None: + try: + study = optuna.load_study(study_name=study_name, storage=storage) + except Exception: + study = None return study def optuna_study_has_best_params(self, study: optuna.study.Study | None) -> bool: diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index b22cbe1..2568a5b 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -306,7 +306,8 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): study_name = f"hp-{dk.pair}" storage = self.optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() - previous_study = self.optuna_study_load_and_cleanup(study_name, storage) + previous_study = self.optuna_study_load(study_name, storage) + self.optuna_study_delete(study_name, storage) study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -376,7 +377,8 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): study_name = f"period-{dk.pair}" storage = self.optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() - previous_study = self.optuna_study_load_and_cleanup(study_name, storage) + previous_study = self.optuna_study_load(study_name, storage) + self.optuna_study_delete(study_name, storage) study = optuna.create_study( study_name=study_name, sampler=optuna.samplers.TPESampler( @@ -444,17 +446,17 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): return json.load(read_file) return None - def optuna_study_load_and_cleanup( - self, study_name: str, storage - ) -> optuna.study.Study | None: - try: - study = optuna.load_study(study_name=study_name, storage=storage) - except Exception: - study = None + def optuna_study_delete(self, study_name: str, storage) -> None: try: optuna.delete_study(study_name=study_name, storage=storage) except Exception: pass + + def optuna_study_load(self, study_name: str, storage) -> optuna.study.Study | None: + try: + study = optuna.load_study(study_name=study_name, storage=storage) + except Exception: + study = None return study def optuna_study_has_best_params(self, study: optuna.study.Study | None) -> bool: -- 2.43.0