From 9fa2ec3850d2635dcc5845e328e418e1f6368af2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Tue, 11 Feb 2025 14:51:20 +0100 Subject: [PATCH] refactor(qav3): factor out optuna storage building code MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 36 +++++++++++-------- .../XGBoostRegressorQuickAdapterV35.py | 36 +++++++++++-------- 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index f1eaffa..8889ebb 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -70,18 +70,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): start = time.time() if self.__optuna_hyperopt: study_name = str(dk.pair) - storage_dir = str(dk.full_path) - storage_type = self.__optuna_config.get("storage_type", "sqlite") - if storage_type == "sqlite": - storage = ( - f"sqlite:///{storage_dir}/optuna-{sanitize_path(study_name)}.sqlite" - ) - elif storage_type == "file": - storage = optuna.storages.JournalStorage( - optuna.storages.journal.JournalFileBackend( - f"{storage_dir}/optuna-{sanitize_path(study_name)}.log" - ) - ) + storage = self.get_optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() study = optuna.create_study( study_name=study_name, @@ -235,6 +224,21 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): return eval_set, eval_weights + def get_optuna_storage(self, dk: FreqaiDataKitchen): + storage_dir = str(dk.full_path) + storage_type = self.__optuna_config.get("storage_type", "sqlite") + if storage_type == "sqlite": + storage = ( + f"sqlite:///{storage_dir}/optuna-{sanitize_path(str(dk.pair))}.sqlite" + ) + elif storage_type == "file": + storage = optuna.storages.JournalStorage( + optuna.storages.journal.JournalFileBackend( + f"{storage_dir}/optuna-{sanitize_path(str(dk.pair))}.log" + ) + ) + return storage + def min_max_pred( pred_df: pd.DataFrame, fit_live_predictions_candles: int, label_period_candles: int @@ -256,9 +260,11 @@ def __min_max_pred( .apply(lambda col: col.sort_values(ascending=False, ignore_index=True)) ) - frequency = fit_live_predictions_candles / label_period_candles - min_pred = pred_df_sorted.iloc[-int(frequency) :].median() - max_pred = pred_df_sorted.iloc[: int(frequency)].median() + label_period_frequency: int = int( + fit_live_predictions_candles / label_period_candles + ) + min_pred = pred_df_sorted.iloc[-label_period_frequency:].median() + max_pred = pred_df_sorted.iloc[:label_period_frequency].median() return min_pred["&s-extrema"], max_pred["&s-extrema"] diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index e910125..0ba2305 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -74,18 +74,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): start = time.time() if self.__optuna_hyperopt: study_name = str(dk.pair) - storage_dir = str(dk.full_path) - storage_type = self.__optuna_config.get("storage_type", "sqlite") - if storage_type == "sqlite": - storage = ( - f"sqlite:///{storage_dir}/optuna-{sanitize_path(study_name)}.sqlite" - ) - elif storage_type == "file": - storage = optuna.storages.JournalStorage( - optuna.storages.journal.JournalFileBackend( - f"{storage_dir}/optuna-{sanitize_path(study_name)}.log" - ) - ) + storage = self.get_optuna_storage(dk) pruner = optuna.pruners.HyperbandPruner() study = optuna.create_study( study_name=study_name, @@ -238,6 +227,21 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): return eval_set, eval_weights + def get_optuna_storage(self, dk: FreqaiDataKitchen): + storage_dir = str(dk.full_path) + storage_type = self.__optuna_config.get("storage_type", "sqlite") + if storage_type == "sqlite": + storage = ( + f"sqlite:///{storage_dir}/optuna-{sanitize_path(str(dk.pair))}.sqlite" + ) + elif storage_type == "file": + storage = optuna.storages.JournalStorage( + optuna.storages.journal.JournalFileBackend( + f"{storage_dir}/optuna-{sanitize_path(str(dk.pair))}.log" + ) + ) + return storage + def min_max_pred( pred_df: pd.DataFrame, fit_live_predictions_candles: int, label_period_candles: int @@ -259,9 +263,11 @@ def __min_max_pred( .apply(lambda col: col.sort_values(ascending=False, ignore_index=True)) ) - frequency = fit_live_predictions_candles / label_period_candles - min_pred = pred_df_sorted.iloc[-int(frequency) :].median() - max_pred = pred_df_sorted.iloc[: int(frequency)].median() + label_period_frequency: int = int( + fit_live_predictions_candles / label_period_candles + ) + min_pred = pred_df_sorted.iloc[-label_period_frequency:].median() + max_pred = pred_df_sorted.iloc[:label_period_frequency].median() return min_pred["&s-extrema"], max_pred["&s-extrema"] -- 2.43.0