]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(qav3): ensure period hyperopt compute RMSE properly
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 28 Feb 2025 18:36:07 +0000 (19:36 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 28 Feb 2025 18:36:07 +0000 (19:36 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/strategies/RLAgentStrategy.py
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 405a5a71e3752c0ac6a7bc98913d9db68f1b3e7e..38053d72759ed4331cd9f1d86f32ff92ebbb6952 100644 (file)
@@ -23,7 +23,10 @@ class RLAgentStrategy(IStrategy):
     stoploss = -0.03
     use_exit_signal = True
     startup_candle_count: int = 300
-    can_short = False
+
+    @property
+    def can_short(self):
+        return self.is_short_allowed()
 
     # def feature_engineering_expand_all(
     #     self, dataframe: DataFrame, period: int, metadata: dict, **kwargs
@@ -92,3 +95,12 @@ class RLAgentStrategy(IStrategy):
             df.loc[reduce(lambda x, y: x & y, exit_short_conditions), "exit_short"] = 1
 
         return df
+
+    def is_short_allowed(self) -> bool:
+        trading_mode = self.config.get("trading_mode")
+        if trading_mode == "futures":
+            return True
+        elif trading_mode == "spot":
+            return False
+        else:
+            raise ValueError(f"Invalid trading_mode: {trading_mode}")
index 59ede95839f41458068fd521e5a29a21619e34ee..84f0b6411f1c3a4c1c157bc9d114bc59fe725caa 100644 (file)
@@ -612,8 +612,10 @@ def period_objective(
     y_pred_length = len(y_pred)
     y_test = y_test.tail(y_test_length - (y_test_length % label_period_candles))
     y_pred = y_pred[-(y_pred_length - (y_pred_length % label_period_candles)) :]
-    y_test.reshape(len(y_test) // label_period_candles, label_period_candles)
-    y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles)
+    y_test = y_test.to_numpy().reshape(
+        len(y_test) // label_period_candles, label_period_candles
+    )
+    y_pred = y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles)
 
     error = sklearn.metrics.root_mean_squared_error(y_test, y_pred)
 
index 3b02ebf6d6bced444598f2a85a4f85d034f1ebb7..e14bf098245ba18421deee5863ff4f6ab7b3c682 100644 (file)
@@ -616,8 +616,10 @@ def period_objective(
     y_pred_length = len(y_pred)
     y_test = y_test.tail(y_test_length - (y_test_length % label_period_candles))
     y_pred = y_pred[-(y_pred_length - (y_pred_length % label_period_candles)) :]
-    y_test.reshape(len(y_test) // label_period_candles, label_period_candles)
-    y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles)
+    y_test = y_test.to_numpy().reshape(
+        len(y_test) // label_period_candles, label_period_candles
+    )
+    y_pred = y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles)
 
     error = sklearn.metrics.root_mean_squared_error(y_test, y_pred)
 
index 52e59c63f49475635c7a792f6f43573126c17252..1c35c6827828df9d0f2fc31fe1d478715197243f 100644 (file)
@@ -65,7 +65,9 @@ class QuickAdapterV3(IStrategy):
 
     process_only_new_candles = True
 
-    can_short = True
+    @property
+    def can_short(self):
+        return self.is_short_allowed()
 
     @property
     def plot_config(self):