From 1646767bdab8897f0e6f66324c9cd7d53593fa86 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Wed, 24 Sep 2025 19:40:10 +0200 Subject: [PATCH] refactor(reforcexy): ensure trade efficiency logic is applied at holding MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 66 +++++++++++-------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index fd62ba7..d79e6a8 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -1202,37 +1202,45 @@ class MyRLEnv(Base5ActionRLEnv): elif exit_factor_mode == "linear": factor /= 1.0 + duration_ratio - factor *= self._get_pnl_factor(self.profit_aim * self.rr, pnl) - - eff_weight = float(model_reward_parameters.get("exit_efficiency_weight", 0.75)) - eff_center = float(model_reward_parameters.get("exit_efficiency_center", 0.5)) - if eff_weight != 0.0: - max_pnl = self.get_max_unrealized_profit() - min_pnl = self.get_min_unrealized_profit() - range_pnl = max_pnl - min_pnl - if not np.isclose(range_pnl, 0.0): - eff_ratio = (pnl - min_pnl) / range_pnl - factor *= 1.0 + eff_weight * (eff_ratio - eff_center) - factor = max(0.0, factor) + factor *= self._get_pnl_factor(pnl, self.profit_aim * self.rr) return factor - def _get_pnl_factor(self, profit_target: float, pnl: float) -> float: - if not np.isfinite(profit_target) or profit_target <= 0.0: - return 1.0 - if not np.isfinite(pnl) or pnl <= profit_target: - return 1.0 - + def _get_pnl_factor(self, pnl: float, pnl_target: float) -> float: model_reward_parameters = self.rl_config.get("model_reward_parameters", {}) - win_reward_factor = float(model_reward_parameters.get("win_reward_factor", 2.0)) - profit_factor_beta = float( - model_reward_parameters.get("profit_factor_beta", 0.5) - ) - profit_ratio = pnl / profit_target - return 1.0 + win_reward_factor * math.tanh( - profit_factor_beta * (profit_ratio - 1.0) + pnl_target_factor = 1.0 + if ( + np.isfinite(pnl_target) + and np.isfinite(pnl) + and pnl_target > 0.0 + and pnl > pnl_target + ): + win_reward_factor = float( + model_reward_parameters.get("win_reward_factor", 2.0) + ) + pnl_factor_beta = float(model_reward_parameters.get("pnl_factor_beta", 0.5)) + pnl_ratio = pnl / pnl_target + pnl_target_factor = 1.0 + win_reward_factor * math.tanh( + pnl_factor_beta * (pnl_ratio - 1.0) + ) + + efficiency_factor = 1.0 + efficiency_weight = float( + model_reward_parameters.get("efficiency_weight", 0.75) ) + efficiency_center = float(model_reward_parameters.get("efficiency_center", 0.5)) + if efficiency_weight != 0.0 and pnl >= 0.0: + max_pnl = max(self.get_max_unrealized_profit(), pnl) + min_pnl = min(self.get_min_unrealized_profit(), pnl) + range_pnl = max_pnl - min_pnl + if np.isfinite(range_pnl) and not np.isclose(range_pnl, 0.0): + efficiency_ratio = (pnl - min_pnl) / range_pnl + efficiency_factor = 1.0 + efficiency_weight * ( + efficiency_ratio - efficiency_center + ) + + return max(0.0, pnl_target_factor * efficiency_factor) def calculate_reward(self, action: int) -> float: """ @@ -1265,9 +1273,9 @@ class MyRLEnv(Base5ActionRLEnv): duration_ratio = trade_duration / max_trade_duration factor = 100.0 - profit_target = self.profit_aim * self.rr - idle_factor = factor * profit_target / 3.0 - holding_factor = idle_factor * self._get_pnl_factor(profit_target, pnl) + pnl_target = self.profit_aim * self.rr + idle_factor = factor * pnl_target / 3.0 + holding_factor = idle_factor * self._get_pnl_factor(pnl, pnl_target) # Force exits if self._force_action in ( @@ -1333,7 +1341,7 @@ class MyRLEnv(Base5ActionRLEnv): model_reward_parameters.get("holding_overage_power", 1.1) ) duration_overage_ratio = max(0.0, duration_ratio - holding_duration_grace) - if duration_overage_ratio > 0.0 or pnl > profit_target: + if duration_overage_ratio > 0.0 or pnl > pnl_target: return ( -holding_factor * holding_overage_scale -- 2.43.0