]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): ensure the same reward is applied with force_actions
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 18 Feb 2025 22:17:50 +0000 (23:17 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 18 Feb 2025 22:17:50 +0000 (23:17 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 8a7a1d80f915e7aeaa8d97262cef2a6cc13e7454..0da09cb5ee7abdd95e9d34c32a3d64e0f612b056 100644 (file)
@@ -674,6 +674,10 @@ class ReforceXY(BaseReinforcementLearningModel):
                 ForceActions.Stop_loss,
                 ForceActions.Timeout,
             ):
+                if pnl > self.profit_aim * self.rr:
+                    factor *= self.rl_config.get("model_reward_parameters", {}).get(
+                        "win_reward_factor", 2
+                    )
                 return pnl * factor
 
             # first, penalize if the action is not valid
@@ -736,7 +740,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                     factor *= self.rl_config.get("model_reward_parameters", {}).get(
                         "win_reward_factor", 2
                     )
-                return float(pnl * factor)
+                return pnl * factor
 
             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
@@ -744,7 +748,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                     factor *= self.rl_config.get("model_reward_parameters", {}).get(
                         "win_reward_factor", 2
                     )
-                return float(pnl * factor)
+                return pnl * factor
 
             return 0.0
 
@@ -850,7 +854,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 "tick": self._current_tick,
                 "position": self._position.value,
                 "action": action,
-                "force_action": self._get_force_action(),
+                "force_action": self._force_action,
                 "pnl": self.get_unrealized_profit(),
                 "reward": round(reward, 5),
                 "total_reward": round(self.total_reward, 5),