From: Jérôme Benoit Date: Fri, 28 Feb 2025 18:36:07 +0000 (+0100) Subject: fix(qav3): ensure period hyperopt compute RMSE properly X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=75b18ac83dc86fdae9a39c02e0418bc489dccd33;p=freqai-strategies.git fix(qav3): ensure period hyperopt compute RMSE properly Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/strategies/RLAgentStrategy.py b/ReforceXY/user_data/strategies/RLAgentStrategy.py index 405a5a7..38053d7 100644 --- a/ReforceXY/user_data/strategies/RLAgentStrategy.py +++ b/ReforceXY/user_data/strategies/RLAgentStrategy.py @@ -23,7 +23,10 @@ class RLAgentStrategy(IStrategy): stoploss = -0.03 use_exit_signal = True startup_candle_count: int = 300 - can_short = False + + @property + def can_short(self): + return self.is_short_allowed() # def feature_engineering_expand_all( # self, dataframe: DataFrame, period: int, metadata: dict, **kwargs @@ -92,3 +95,12 @@ class RLAgentStrategy(IStrategy): df.loc[reduce(lambda x, y: x & y, exit_short_conditions), "exit_short"] = 1 return df + + def is_short_allowed(self) -> bool: + trading_mode = self.config.get("trading_mode") + if trading_mode == "futures": + return True + elif trading_mode == "spot": + return False + else: + raise ValueError(f"Invalid trading_mode: {trading_mode}") diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index 59ede95..84f0b64 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -612,8 +612,10 @@ def period_objective( y_pred_length = len(y_pred) y_test = y_test.tail(y_test_length - (y_test_length % label_period_candles)) y_pred = y_pred[-(y_pred_length - (y_pred_length % label_period_candles)) :] - y_test.reshape(len(y_test) // label_period_candles, label_period_candles) - y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles) + y_test = y_test.to_numpy().reshape( + len(y_test) // label_period_candles, label_period_candles + ) + y_pred = y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles) error = sklearn.metrics.root_mean_squared_error(y_test, y_pred) diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 3b02ebf..e14bf09 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -616,8 +616,10 @@ def period_objective( y_pred_length = len(y_pred) y_test = y_test.tail(y_test_length - (y_test_length % label_period_candles)) y_pred = y_pred[-(y_pred_length - (y_pred_length % label_period_candles)) :] - y_test.reshape(len(y_test) // label_period_candles, label_period_candles) - y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles) + y_test = y_test.to_numpy().reshape( + len(y_test) // label_period_candles, label_period_candles + ) + y_pred = y_pred.reshape(len(y_pred) // label_period_candles, label_period_candles) error = sklearn.metrics.root_mean_squared_error(y_test, y_pred) diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index 52e59c6..1c35c68 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -65,7 +65,9 @@ class QuickAdapterV3(IStrategy): process_only_new_candles = True - can_short = True + @property + def can_short(self): + return self.is_short_allowed() @property def plot_config(self):