]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): refine reward method logic
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 21 Feb 2025 20:35:19 +0000 (21:35 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 21 Feb 2025 20:35:19 +0000 (21:35 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 0d063da392a7560579cdac6be90fe478179ae124..e57270f212b6f6069ccb49e6e918abf017e13c5d 100644 (file)
@@ -370,6 +370,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         model_filename = dk.model_filename if dk.model_filename else "best"
         if Path(dk.data_path / f"{model_filename}_model.zip").is_file():
+            logger.info("Callback found a best model.")
             best_model = self.MODELCLASS.load(dk.data_path / f"{model_filename}_model")
             return best_model
 
@@ -675,6 +676,26 @@ class ReforceXY(BaseReinforcementLearningModel):
             self._non_profit_steps: int = 0
             return self._get_observation(), history
 
+        def get_reward_factor_at_trade_exit(
+            self,
+            factor: float,
+            pnl: float,
+            trade_duration: int,
+            max_trade_duration: int,
+        ) -> float:
+            """
+            Compute the reward factor at trade exit
+            """
+            if trade_duration <= max_trade_duration:
+                factor *= 1.5
+            elif trade_duration > max_trade_duration:
+                factor *= 0.5
+            if pnl > self.profit_aim * self.rr:
+                factor *= self.rl_config.get("model_reward_parameters", {}).get(
+                    "win_reward_factor", 2
+                )
+            return factor
+
         def calculate_reward(self, action) -> float:
             """
             An example reward function. This is the one function that users will likely
@@ -690,43 +711,28 @@ class ReforceXY(BaseReinforcementLearningModel):
             float = the reward to give to the agent for current step (used for optimization
                 of weights in NN)
             """
-            # first penalize if the action is not valid
-            if (
-                self.force_actions
-                and self._force_action is not None
-                and self._force_action
-                not in (
-                    ForceActions.Take_profit,
-                    ForceActions.Stop_loss,
-                    ForceActions.Timeout,
-                )
-            ) or not self._is_valid(action):
+            # first, penalize if the action is not valid
+            if not self._force_action and not self._is_valid(action):
                 return -2
 
             pnl = self.get_unrealized_profit()
             # mrr = self.get_most_recent_return()
             # mrp = self.get_most_recent_profit()
 
-            factor = 100.0
-
-            max_trade_duration = self.rl_config.get("max_trade_duration_candles", 300)
+            max_trade_duration = self.timeout
             trade_duration = self.get_trade_duration()
-            if trade_duration <= max_trade_duration:
-                factor *= 1.5
-            elif trade_duration > max_trade_duration:
-                factor *= 0.5
+
+            factor = 100.0
 
             # Force exits
-            if self.force_actions and self._force_action in (
+            if self._force_action in (
                 ForceActions.Take_profit,
                 ForceActions.Stop_loss,
                 ForceActions.Timeout,
             ):
-                if pnl > self.profit_aim * self.rr:
-                    factor *= self.rl_config.get("model_reward_parameters", {}).get(
-                        "win_reward_factor", 2
-                    )
-                return pnl * factor
+                return pnl * self.get_reward_factor_at_trade_exit(
+                    factor, pnl, trade_duration, max_trade_duration
+                )
 
             # # you can use feature values from dataframe
             # rsi_now = self.get_feature_value(
@@ -737,7 +743,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             #     raw=True
             # )
 
-            # # reward agent for entering trades
+            # # reward agent for entering trades when RSI is low
             # if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
             #         and self._position == Positions.Neutral):
             #     if rsi_now < 40:
@@ -746,46 +752,47 @@ class ReforceXY(BaseReinforcementLearningModel):
             #         factor = 1
             #     return 25 * factor
 
+            # reward agent for entering trades
+            if (
+                action == Actions.Long_enter.value
+                and self._position == Positions.Neutral
+            ):
+                return 25
+            if (
+                action == Actions.Short_enter.value
+                and self._position == Positions.Neutral
+            ):
+                return 25
+
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
                 return -1
 
-            # discourage sitting in non profitable position
+            # discourage sitting in position
             if (
                 self._position in (Positions.Short, Positions.Long)
                 and action == Actions.Neutral.value
             ):
                 if pnl < 0:
                     self._non_profit_steps += 1
-                else:
-                    self._non_profit_steps = 0
-                if self._non_profit_steps > 0:
                     return factor * (
                         pnl - (0.1 * (self._non_profit_steps**2) * abs(pnl))
                     )  # time aggressive (quadratic) and loss magnitude aware penalty
-
-            # discourage sitting in position
-            if (
-                self._position in (Positions.Short, Positions.Long)
-                and action == Actions.Neutral.value
-            ):
-                return -1 * trade_duration / max_trade_duration
+                else:
+                    self._non_profit_steps = 0
+                    return -1 * trade_duration / max_trade_duration
 
             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
-                if pnl > self.profit_aim * self.rr:
-                    factor *= self.rl_config.get("model_reward_parameters", {}).get(
-                        "win_reward_factor", 2
-                    )
-                return pnl * factor
+                return pnl * self.get_reward_factor_at_trade_exit(
+                    factor, pnl, trade_duration, max_trade_duration
+                )
 
             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
-                if pnl > self.profit_aim * self.rr:
-                    factor *= self.rl_config.get("model_reward_parameters", {}).get(
-                        "win_reward_factor", 2
-                    )
-                return pnl * factor
+                return pnl * self.get_reward_factor_at_trade_exit(
+                    factor, pnl, trade_duration, max_trade_duration
+                )
 
             return 0.0