From 925019b1409b2caf0bf5501152d29f26ff0782a0 Mon Sep 17 00:00:00 2001
From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?=
Date: Sat, 13 Sep 2025 02:18:06 +0200
Subject: [PATCH] perf(reforcexy): fine tune reward calculation
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jérôme Benoit
---
 ReforceXY/user_data/freqaimodels/ReforceXY.py | 23 +++++++++----------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 003e2b2..730ef81 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -933,7 +933,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         # mrr = self.get_most_recent_return()
         # mrp = self.get_most_recent_profit()
 
-        max_trade_duration = self.timeout
+        max_trade_duration = max(1, self.timeout)
         trade_duration = self.get_trade_duration()
 
         factor = 100.0
@@ -974,7 +974,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             and self._position == Positions.Neutral
         ):
             return self.rl_config.get("model_reward_parameters", {}).get(
-                "enter_action", 25.0
+                "enter_action", 1.0
             )
 
         # discourage agent from not entering trades
@@ -990,27 +990,26 @@ class ReforceXY(BaseReinforcementLearningModel):
             self._position in (Positions.Short, Positions.Long)
             and action == Actions.Neutral.value
         ):
+            duration_fraction = trade_duration / max_trade_duration
             max_pnl = max(self.get_most_recent_max_pnl(), pnl)
             if max_pnl > 0:
-                drawdown_penalty = 0.01 * factor * (max_pnl - pnl)
+                drawdown_penalty = 0.0025 * factor * (max_pnl - pnl) * duration_fraction
             else:
                 drawdown_penalty = 0.0
             lambda1 = 0.05
             lambda2 = 0.1
             if pnl >= 0:
-                return (
-                    factor
-                    * pnl
-                    * np.exp(-lambda1 * (trade_duration / max_trade_duration))
-                    - lambda2 * (trade_duration / max_trade_duration)
-                    - drawdown_penalty
-                )
+                if duration_fraction < 0.75:
+                    duration_penalty_factor = 1.0
+                else:
+                    duration_penalty_factor = 1.0 / (1.0 + lambda1 * duration_fraction)
+                return factor * pnl * duration_penalty_factor - lambda2 * duration_fraction - drawdown_penalty
             else:
                 return (
                     factor
                     * pnl
-                    * (1 + lambda1 * (trade_duration / max_trade_duration))
-                    - 2 * lambda2 * (trade_duration / max_trade_duration)
+                    * (1 + lambda1 * duration_fraction)
+                    - 2 * lambda2 * duration_fraction
                     - drawdown_penalty
                 )
 
-- 
2.43.0