]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(qav3): avoid train and test set disperancy
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 6 Feb 2025 11:15:37 +0000 (12:15 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 6 Feb 2025 11:15:37 +0000 (12:15 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 65f6b166233501b46ba1fa37e68ed723f93dbc4a..c4f274f06d3a8757123ac4ec3f7131ea0e2a305b 100644 (file)
@@ -103,6 +103,12 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
             y_test = y_test.tail(test_window)
             test_weights = test_weights[-test_window:]
 
+            if dk.pair not in self.freqai_info["feature_parameters"]:
+                self.freqai_info["feature_parameters"][dk.pair] = {}
+            self.freqai_info["feature_parameters"][dk.pair]["label_period_candles"] = (
+                self.__optuna_hp.get("label_period_candles")
+            )
+
         eval_set, eval_weights = self.eval_set_and_weights(X_test, y_test, test_weights)
 
         model.fit(
@@ -146,9 +152,6 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
                 label_period_candles = self.__optuna_hp.get(
                     "label_period_candles", self.ft_params["label_period_candles"]
                 )
-                self.freqai_info["feature_parameters"]["label_period_candles"] = (
-                    label_period_candles
-                )
             else:
                 label_period_candles = self.ft_params["label_period_candles"]
             min_pred, max_pred = min_max_pred(
@@ -231,19 +234,25 @@ def objective(
     candles_step,
     params,
 ):
-    if (len(X) != len(y)) or (len(X) != len(train_weights)):
-        raise ValueError("Training sets must have the same length")
-    if (len(X_test) != len(y_test)) or (len(X_test) != len(test_weights)):
-        raise ValueError("Test sets must have the same length")
+    min_train_window: int = 10
+    max_train_window: int = (
+        len(X) if len(X) > min_train_window else (min_train_window + len(X))
+    )
     train_window = trial.suggest_int(
-        "train_period_candles", 0, len(X), step=candles_step
+        "train_period_candles", min_train_window, max_train_window, step=candles_step
     )
     X = X.tail(train_window)
     y = y.tail(train_window)
     train_weights = train_weights[-train_window:]
 
+    min_test_window: int = 10
+    max_test_window: int = (
+        len(X_test)
+        if len(X_test) > min_test_window
+        else (min_test_window + len(X_test))
+    )
     test_window = trial.suggest_int(
-        "test_period_candles", 0, len(X_test), step=candles_step
+        "test_period_candles", min_test_window, max_test_window, step=candles_step
     )
     X_test = X_test.tail(test_window)
     y_test = y_test.tail(test_window)
index 6a2e8203e8db43a0e93ceb9af3a97c20344c2dc0..d66ec073f999b7867aef2df692a84ab0f9cbc8b2 100644 (file)
@@ -103,6 +103,12 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
             y_test = y_test.tail(test_window)
             test_weights = test_weights[-test_window:]
 
+            if dk.pair not in self.freqai_info["feature_parameters"]:
+                self.freqai_info["feature_parameters"][dk.pair] = {}
+            self.freqai_info["feature_parameters"][dk.pair]["label_period_candles"] = (
+                self.__optuna_hp.get("label_period_candles")
+            )
+
         eval_set, eval_weights = self.eval_set_and_weights(X_test, y_test, test_weights)
 
         model.fit(
@@ -146,9 +152,6 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
                 label_period_candles = self.__optuna_hp.get(
                     "label_period_candles", self.ft_params["label_period_candles"]
                 )
-                self.freqai_info["feature_parameters"]["label_period_candles"] = (
-                    label_period_candles
-                )
             else:
                 label_period_candles = self.ft_params["label_period_candles"]
             min_pred, max_pred = min_max_pred(
@@ -231,19 +234,25 @@ def objective(
     candles_step,
     params,
 ):
-    if (len(X) != len(y)) or (len(X) != len(train_weights)):
-        raise ValueError("Training sets must have the same length")
-    if (len(X_test) != len(y_test)) or (len(X_test) != len(test_weights)):
-        raise ValueError("Test sets must have the same length")
+    min_train_window: int = 10
+    max_train_window: int = (
+        len(X) if len(X) > min_train_window else (min_train_window + len(X))
+    )
     train_window = trial.suggest_int(
-        "train_period_candles", 0, len(X), step=candles_step
+        "train_period_candles", min_train_window, max_train_window, step=candles_step
     )
     X = X.tail(train_window)
     y = y.tail(train_window)
     train_weights = train_weights[-train_window:]
 
+    min_test_window: int = 10
+    max_test_window: int = (
+        len(X_test)
+        if len(X_test) > min_test_window
+        else (min_test_window + len(X_test))
+    )
     test_window = trial.suggest_int(
-        "test_period_candles", 0, len(X_test), step=candles_step
+        "test_period_candles", min_test_window, max_test_window, step=candles_step
     )
     X_test = X_test.tail(test_window)
     y_test = y_test.tail(test_window)
index 30ba56a2f9074ca67cc852c3330997604016e9fb..a2c1f441135715b337ab40a34af94293697fb0a7 100644 (file)
@@ -225,17 +225,26 @@ class QuickAdapterV3(IStrategy):
         dataframe["%-hour_of_day"] = (dataframe["date"].dt.hour + 1) / 25
         return dataframe
 
-    def set_freqai_targets(self, dataframe, **kwargs):
+    def set_freqai_targets(self, dataframe, metadata, **kwargs):
+        pair = str(metadata.get("pair"))
+        label_period_candles = (
+            self.freqai_info["feature_parameters"]
+            .get(pair, {})
+            .get(
+                "label_period_candles",
+                self.freqai_info["feature_parameters"]["label_period_candles"],
+            )
+        )
         dataframe["&s-extrema"] = 0
         min_peaks = argrelextrema(
             dataframe["low"].values,
             np.less,
-            order=self.freqai_info["feature_parameters"]["label_period_candles"],
+            order=label_period_candles,
         )
         max_peaks = argrelextrema(
             dataframe["high"].values,
             np.greater,
-            order=self.freqai_info["feature_parameters"]["label_period_candles"],
+            order=label_period_candles,
         )
         for mp in min_peaks[0]:
             dataframe.at[mp, "&s-extrema"] = -1