From 75b18ac83dc86fdae9a39c02e0418bc489dccd33 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Fri, 28 Feb 2025 19:36:07 +0100 Subject: [PATCH] fix(qav3): ensure period hyperopt compute RMSE properly MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/strategies/RLAgentStrategy.py | 14 +++++++++++++- .../LightGBMRegressorQuickAdapterV35.py | 6 ++++-- .../XGBoostRegressorQuickAdapterV35.py | 6 ++++-- .../user_data/strategies/QuickAdapterV3.py | 4 +++- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/ReforceXY/user_data/strategies/RLAgentStrategy.py b/ReforceXY/user_data/strategies/RLAgentStrategy.py index 405a5a7..38053d7 100644 --- a/ReforceXY/user_data/strategies/RLAgentStrategy.py +++ b/ReforceXY/user_data/strategies/RLAgentStrategy.py @@ -23,7 +23,10 @@ class RLAgentStrategy(IStrategy): stoploss = -0.03 use_exit_signal = True startup_candle_count: int = 300 - can_short = False + + @property + def can_short(self): + return self.is_short_allowed() # def feature_engineering_expand_all( # self, dataframe: DataFrame, period: int, metadata: dict, **kwargs @@ -92,3 +95,12 @@ class RLAgentStrategy(IStrategy): df.loc[reduce(lambda x, y: x & y, exit_short_conditions), "exit_short"] = 1 return df + + def is_short_allowed(self) -> bool: + trading_mode = self.config.get("trading_mode") + if trading_mode == "futures": + return True + elif trading_mode == "spot": + return False + else: + raise ValueError(f"Invalid trading_mode: {trading_mode}") diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index 59ede95..84f0b64 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -612,8 +612,10 @@ def period_objective( y_pred_length = len(y_pred) y_test = y_test.tail(y_test_length - (y_test_length % label_period_candles)) y_pred = y_pred[-(y_pred_length - (y_pred_length % label_period_candles)) :] - y_test.reshape(len(y_test) // label_period_candles, label_period_candles) - y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles) + y_test = y_test.to_numpy().reshape( + len(y_test) // label_period_candles, label_period_candles + ) + y_pred = y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles) error = sklearn.metrics.root_mean_squared_error(y_test, y_pred) diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 3b02ebf..e14bf09 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -616,8 +616,10 @@ def period_objective( y_pred_length = len(y_pred) y_test = y_test.tail(y_test_length - (y_test_length % label_period_candles)) y_pred = y_pred[-(y_pred_length - (y_pred_length % label_period_candles)) :] - y_test.reshape(len(y_test) // label_period_candles, label_period_candles) - y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles) + y_test = y_test.to_numpy().reshape( + len(y_test) // label_period_candles, label_period_candles + ) + y_pred = y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles) error = sklearn.metrics.root_mean_squared_error(y_test, y_pred) diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index 52e59c6..1c35c68 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -65,7 +65,9 @@ class QuickAdapterV3(IStrategy): process_only_new_candles = True - can_short = True + @property + def can_short(self): + return self.is_short_allowed() @property def plot_config(self): -- 2.43.0