]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(qav3): cleanup label_period_candles handling
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 13 Mar 2025 18:41:38 +0000 (19:41 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 13 Mar 2025 18:41:38 +0000 (19:41 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 35f40b6114552d76cb8c85e2bff6c644d363996c..291f3fefe03b58fdda30c5355984c7680a8a0904 100644 (file)
@@ -152,6 +152,11 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
 
         return model
 
+    def get_label_period_candles(self, pair: str) -> int:
+        if self.__optuna_period_params.get(pair, {}).get("label_period_candles"):
+            return self.__optuna_period_params[pair]["label_period_candles"]
+        return self.ft_params["label_period_candles"]
+
     def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None:
         warmed_up = True
 
@@ -176,9 +181,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
             dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = -2
             dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = 2
         else:
-            label_period_candles = self.__optuna_period_params.get(pair, {}).get(
-                "label_period_candles", self.ft_params["label_period_candles"]
-            )
+            label_period_candles = self.get_label_period_candles(pair)
             min_pred, max_pred = self.min_max_pred(
                 pred_df_full,
                 num_candles,
@@ -217,9 +220,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
         dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff
 
         dk.data["extra_returns_per_train"]["label_period_candles"] = (
-            self.__optuna_period_params.get(pair, {}).get(
-                "label_period_candles", self.ft_params["label_period_candles"]
-            )
+            self.get_label_period_candles(pair)
         )
         dk.data["extra_returns_per_train"]["hp_rmse"] = self.__optuna_hp_rmse.get(
             pair, -1
index 29a5efa4cbeb8ed00837401a5f073e047f3df719..3aa9abb8b87ed8630a943f564ec9d2010f8e23d1 100644 (file)
@@ -155,6 +155,11 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
 
         return model
 
+    def get_label_period_candles(self, pair: str) -> int:
+        if self.__optuna_period_params.get(pair, {}).get("label_period_candles"):
+            return self.__optuna_period_params[pair]["label_period_candles"]
+        return self.ft_params["label_period_candles"]
+
     def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None:
         warmed_up = True
 
@@ -179,9 +184,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
             dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = -2
             dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = 2
         else:
-            label_period_candles = self.__optuna_period_params.get(pair, {}).get(
-                "label_period_candles", self.ft_params["label_period_candles"]
-            )
+            label_period_candles = self.get_label_period_candles(pair)
             min_pred, max_pred = self.min_max_pred(
                 pred_df_full,
                 num_candles,
@@ -220,9 +223,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
         dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff
 
         dk.data["extra_returns_per_train"]["label_period_candles"] = (
-            self.__optuna_period_params.get(pair, {}).get(
-                "label_period_candles", self.ft_params["label_period_candles"]
-            )
+            self.get_label_period_candles(pair)
         )
         dk.data["extra_returns_per_train"]["hp_rmse"] = self.__optuna_hp_rmse.get(
             pair, -1
index eea84f80cea4f1d0fe79006cbb8b823784c3d90d..3b2d02afed49900a2e05809bb79f6df9df9f7b09 100644 (file)
@@ -265,24 +265,13 @@ class QuickAdapterV3(IStrategy):
         dataframe["%-hour_of_day"] = (dataframe["date"].dt.hour + 1) / 25
         return dataframe
 
-    def get_label_period_candles(self, metadata, **kwargs) -> int:
-        pair = str(metadata.get("pair"))
+    def get_label_period_candles(self, pair: str) -> int:
         if self.__period_params.get(pair, {}).get("label_period_candles"):
-            label_period_candles = self.__period_params.get(pair, {}).get(
-                "label_period_candles",
-            )
-        else:
-            label_period_candles = self.freqai_info["feature_parameters"][
-                "label_period_candles"
-            ]
-        if label_period_candles < 1:
-            raise ValueError(
-                f"label_period_candles must be greater than 0, got {label_period_candles}"
-            )
-        return label_period_candles
+            return self.__period_params[pair]["label_period_candles"]
+        return self.freqai_info["feature_parameters"]["label_period_candles"]
 
     def set_freqai_targets(self, dataframe, metadata, **kwargs):
-        label_period_candles = self.get_label_period_candles(metadata, **kwargs)
+        label_period_candles = self.get_label_period_candles(str(metadata.get("pair")))
         min_peaks = argrelmin(
             dataframe["low"].values,
             order=label_period_candles,