]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(qav3): move attributes init to constructor
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 20 Feb 2025 16:59:17 +0000 (17:59 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 20 Feb 2025 16:59:17 +0000 (17:59 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py

index dee1983630246752d0cf5f12c7276ba019e95152..a5e12c798c11ef2ead5ea49220d53596cd310b0a 100644 (file)
@@ -45,6 +45,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
+        self.pairs = self.config.get("exchange", {}).get("pair_whitelist")
         self.__optuna_config = self.freqai_info.get("optuna_hyperopt", {})
         self.__optuna_hyperopt: bool = (
             self.freqai_info.get("enabled", False)
@@ -55,6 +56,11 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
         self.__optuna_period_rmse: dict[str, float] = {}
         self.__optuna_hp_params: dict[str, dict] = {}
         self.__optuna_period_params: dict[str, dict] = {}
+        for pair in self.pairs:
+            self.__optuna_hp_rmse[pair] = -1
+            self.__optuna_period_rmse[pair] = -1
+            self.__optuna_hp_params[pair] = {}
+            self.__optuna_period_params[pair] = {}
 
     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
         """
@@ -81,16 +87,12 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
                 dk, X, y, train_weights, X_test, y_test, test_weights
             )
             if optuna_hp_params:
-                if dk.pair not in self.__optuna_hp_params:
-                    self.__optuna_hp_params[dk.pair] = {}
                 self.__optuna_hp_params[dk.pair] = optuna_hp_params
                 model_training_parameters = {
                     **model_training_parameters,
                     **self.__optuna_hp_params[dk.pair],
                 }
             if optuna_hp_rmse:
-                if dk.pair not in self.__optuna_hp_rmse:
-                    self.__optuna_hp_rmse[dk.pair] = -1
                 self.__optuna_hp_rmse[dk.pair] = optuna_hp_rmse
 
             optuna_period_params, optuna_period_rmse = self.optuna_period_optimize(
@@ -104,12 +106,8 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
                 model_training_parameters,
             )
             if optuna_period_params:
-                if dk.pair not in self.__optuna_period_params:
-                    self.__optuna_period_params[dk.pair] = {}
                 self.__optuna_period_params[dk.pair] = optuna_period_params
             if optuna_period_rmse:
-                if dk.pair not in self.__optuna_period_rmse:
-                    self.__optuna_period_rmse[dk.pair] = -1
                 self.__optuna_period_rmse[dk.pair] = optuna_period_rmse
 
             if self.__optuna_period_params.get(dk.pair):
index 2568a5b10f08443de7b238696f6213c957789e45..924790fcb3edf9e5199c4074744a4638c00ee99d 100644 (file)
@@ -45,6 +45,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
+        self.pairs = self.config.get("exchange", {}).get("pair_whitelist")
         self.__optuna_config = self.freqai_info.get("optuna_hyperopt", {})
         self.__optuna_hyperopt: bool = (
             self.freqai_info.get("enabled", False)
@@ -55,6 +56,11 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
         self.__optuna_period_rmse: dict[str, float] = {}
         self.__optuna_hp_params: dict[str, dict] = {}
         self.__optuna_period_params: dict[str, dict] = {}
+        for pair in self.pairs:
+            self.__optuna_hp_rmse[pair] = -1
+            self.__optuna_period_rmse[pair] = -1
+            self.__optuna_hp_params[pair] = {}
+            self.__optuna_period_params[pair] = {}
 
     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
         """
@@ -81,16 +87,12 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
                 dk, X, y, train_weights, X_test, y_test, test_weights
             )
             if optuna_hp_params:
-                if dk.pair not in self.__optuna_hp_params:
-                    self.__optuna_hp_params[dk.pair] = {}
                 self.__optuna_hp_params[dk.pair] = optuna_hp_params
                 model_training_parameters = {
                     **model_training_parameters,
                     **self.__optuna_hp_params[dk.pair],
                 }
             if optuna_hp_rmse:
-                if dk.pair not in self.__optuna_hp_rmse:
-                    self.__optuna_hp_rmse[dk.pair] = -1
                 self.__optuna_hp_rmse[dk.pair] = optuna_hp_rmse
 
             optuna_period_params, optuna_period_rmse = self.optuna_period_optimize(
@@ -104,12 +106,8 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
                 model_training_parameters,
             )
             if optuna_period_params:
-                if dk.pair not in self.__optuna_period_params:
-                    self.__optuna_period_params[dk.pair] = {}
                 self.__optuna_period_params[dk.pair] = optuna_period_params
             if optuna_period_rmse:
-                if dk.pair not in self.__optuna_period_rmse:
-                    self.__optuna_period_rmse[dk.pair] = -1
                 self.__optuna_period_rmse[dk.pair] = optuna_period_rmse
 
             if self.__optuna_period_params.get(dk.pair):