From 11715546eebc8ebdec6902f042dc69b8a52cf332 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 20 Feb 2025 17:59:17 +0100 Subject: [PATCH] refactor(qav3): move attributes init to constructor MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 14 ++++++-------- .../XGBoostRegressorQuickAdapterV35.py | 14 ++++++-------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index dee1983..a5e12c7 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -45,6 +45,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): def __init__(self, **kwargs): super().__init__(**kwargs) + self.pairs = self.config.get("exchange", {}).get("pair_whitelist") self.__optuna_config = self.freqai_info.get("optuna_hyperopt", {}) self.__optuna_hyperopt: bool = ( self.freqai_info.get("enabled", False) @@ -55,6 +56,11 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): self.__optuna_period_rmse: dict[str, float] = {} self.__optuna_hp_params: dict[str, dict] = {} self.__optuna_period_params: dict[str, dict] = {} + for pair in self.pairs: + self.__optuna_hp_rmse[pair] = -1 + self.__optuna_period_rmse[pair] = -1 + self.__optuna_hp_params[pair] = {} + self.__optuna_period_params[pair] = {} def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ @@ -81,16 +87,12 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): dk, X, y, train_weights, X_test, y_test, test_weights ) if optuna_hp_params: - if dk.pair not in self.__optuna_hp_params: - self.__optuna_hp_params[dk.pair] = {} self.__optuna_hp_params[dk.pair] = optuna_hp_params model_training_parameters = { **model_training_parameters, **self.__optuna_hp_params[dk.pair], } if optuna_hp_rmse: - if dk.pair not in self.__optuna_hp_rmse: - self.__optuna_hp_rmse[dk.pair] = -1 self.__optuna_hp_rmse[dk.pair] = optuna_hp_rmse optuna_period_params, optuna_period_rmse = self.optuna_period_optimize( @@ -104,12 +106,8 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): model_training_parameters, ) if optuna_period_params: - if dk.pair not in self.__optuna_period_params: - self.__optuna_period_params[dk.pair] = {} self.__optuna_period_params[dk.pair] = optuna_period_params if optuna_period_rmse: - if dk.pair not in self.__optuna_period_rmse: - self.__optuna_period_rmse[dk.pair] = -1 self.__optuna_period_rmse[dk.pair] = optuna_period_rmse if self.__optuna_period_params.get(dk.pair): diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 2568a5b..924790f 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -45,6 +45,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): def __init__(self, **kwargs): super().__init__(**kwargs) + self.pairs = self.config.get("exchange", {}).get("pair_whitelist") self.__optuna_config = self.freqai_info.get("optuna_hyperopt", {}) self.__optuna_hyperopt: bool = ( self.freqai_info.get("enabled", False) @@ -55,6 +56,11 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): self.__optuna_period_rmse: dict[str, float] = {} self.__optuna_hp_params: dict[str, dict] = {} self.__optuna_period_params: dict[str, dict] = {} + for pair in self.pairs: + self.__optuna_hp_rmse[pair] = -1 + self.__optuna_period_rmse[pair] = -1 + self.__optuna_hp_params[pair] = {} + self.__optuna_period_params[pair] = {} def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ @@ -81,16 +87,12 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): dk, X, y, train_weights, X_test, y_test, test_weights ) if optuna_hp_params: - if dk.pair not in self.__optuna_hp_params: - self.__optuna_hp_params[dk.pair] = {} self.__optuna_hp_params[dk.pair] = optuna_hp_params model_training_parameters = { **model_training_parameters, **self.__optuna_hp_params[dk.pair], } if optuna_hp_rmse: - if dk.pair not in self.__optuna_hp_rmse: - self.__optuna_hp_rmse[dk.pair] = -1 self.__optuna_hp_rmse[dk.pair] = optuna_hp_rmse optuna_period_params, optuna_period_rmse = self.optuna_period_optimize( @@ -104,12 +106,8 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): model_training_parameters, ) if optuna_period_params: - if dk.pair not in self.__optuna_period_params: - self.__optuna_period_params[dk.pair] = {} self.__optuna_period_params[dk.pair] = optuna_period_params if optuna_period_rmse: - if dk.pair not in self.__optuna_period_rmse: - self.__optuna_period_rmse[dk.pair] = -1 self.__optuna_period_rmse[dk.pair] = optuna_period_rmse if self.__optuna_period_params.get(dk.pair): -- 2.43.0