]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): uniformize rewarding behavior
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 19 Feb 2025 00:46:45 +0000 (01:46 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 19 Feb 2025 00:46:45 +0000 (01:46 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 0da09cb5ee7abdd95e9d34c32a3d64e0f612b056..0d9a5768bd185eca0acac95ada7d743d87b99208 100644 (file)
@@ -402,22 +402,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param model: Any = the trained model used to inference the features.
         """
 
-        def _is_valid(action: int, position: float) -> bool:
-            return not (
-                (
-                    action in (Actions.Short_enter.value, Actions.Long_enter.value)
-                    and position != Positions.Neutral.value
-                )
-                or (
-                    action == Actions.Long_exit.value
-                    and position != Positions.Long.value
-                )
-                or (
-                    action == Actions.Short_exit.value
-                    and position != Positions.Short.value
-                )
-            )
-
         def _action_masks(position: float):
             return [_is_valid(action.value, position) for action in Actions]
 
@@ -451,12 +435,10 @@ class ReforceXY(BaseReinforcementLearningModel):
         output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
         return output
 
-    def study(self, train_df, total_timesteps: int, dk: FreqaiDataKitchen) -> Dict:
+    def get_storage(self, dk: FreqaiDataKitchen):
         """
-        Runs hyperparameter optimization using Optuna and
-        returns the best hyperparameters found
+        Get the storage for Optuna
         """
-        study_name = str(dk.pair)
         storage_dir = str(dk.full_path)
         storage_backend = self.rl_config_optuna.get("storage", "sqlite")
         if storage_backend == "sqlite":
@@ -465,6 +447,15 @@ class ReforceXY(BaseReinforcementLearningModel):
             storage = JournalStorage(
                 JournalFileBackend(f"{storage_dir}/optuna-{dk.pair.split('/')[0]}.log")
             )
+        return storage
+
+    def study(self, train_df, total_timesteps: int, dk: FreqaiDataKitchen) -> Dict:
+        """
+        Runs hyperparameter optimization using Optuna and
+        returns the best hyperparameters found
+        """
+        study_name = str(dk.pair)
+        storage = self.get_storage(dk)
         study: Study = create_study(
             study_name=study_name,
             sampler=TPESampler(
@@ -662,14 +653,34 @@ class ReforceXY(BaseReinforcementLearningModel):
             float = the reward to give to the agent for current step (used for optimization
                 of weights in NN)
             """
+            # first penalize if the action is not valid
+            if (
+                self.force_actions
+                and self._force_action is not None
+                and self._force_action
+                not in (
+                    ForceActions.Take_profit,
+                    ForceActions.Stop_loss,
+                    ForceActions.Timeout,
+                )
+            ) or not self._is_valid(action):
+                return -2
+
             pnl = self.get_unrealized_profit()
             # mrr = self.get_most_recent_return()
             # mrp = self.get_most_recent_profit()
 
             factor = 100.0
 
+            max_trade_duration = self.rl_config.get("max_trade_duration_candles", 300)
+            trade_duration = self.get_trade_duration()
+            if trade_duration <= max_trade_duration:
+                factor *= 1.5
+            elif trade_duration > max_trade_duration:
+                factor *= 0.5
+
             # Force exits
-            if self._force_action in (
+            if self.force_actions and self._force_action in (
                 ForceActions.Take_profit,
                 ForceActions.Stop_loss,
                 ForceActions.Timeout,
@@ -680,10 +691,6 @@ class ReforceXY(BaseReinforcementLearningModel):
                     )
                 return pnl * factor
 
-            # first, penalize if the action is not valid
-            if not self._is_valid(action):
-                return -2
-
             # # you can use feature values from dataframe
             # rsi_now = self.get_feature_value(
             #     name="%-rsi",
@@ -715,19 +722,12 @@ class ReforceXY(BaseReinforcementLearningModel):
                     self._non_profit_steps += 1
                 else:
                     self._non_profit_steps = 0
-            if self._non_profit_steps > 0:
-                return pnl - (
-                    0.1 * (self._non_profit_steps**2) * max(0, pnl)
-                )  # time aggressive (quadratic) and loss magnitude aware penalty
+                if self._non_profit_steps > 0:
+                    return factor * (
+                        pnl - (0.1 * (self._non_profit_steps**2) * abs(pnl))
+                    )  # time aggressive (quadratic) and loss magnitude aware penalty
 
             # discourage sitting in position
-            max_trade_duration = self.rl_config.get("max_trade_duration_candles", 300)
-            trade_duration = self.get_trade_duration()
-            if trade_duration <= max_trade_duration:
-                factor *= 1.5
-            elif trade_duration > max_trade_duration:
-                factor *= 0.5
-
             if (
                 self._position in (Positions.Short, Positions.Long)
                 and action == Actions.Neutral.value
@@ -781,7 +781,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             else:
                 return features_window.to_numpy(dtype=np.float32)
 
-        def _get_force_action(self):
+        def _get_force_action(self) -> Optional[ForceActions]:
             if not self.force_actions or self._position == Positions.Neutral:
                 return None
 
@@ -813,6 +813,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             self._position = Positions.Neutral
             self._last_closed_trade_tick = self._last_trade_tick
             self._last_trade_tick = None
+            self._non_profit_steps = 0
 
         def execute_trade(self, action: int) -> None:
             """
@@ -913,25 +914,6 @@ class ReforceXY(BaseReinforcementLearningModel):
                 )
             )
 
-        def _is_valid(self, action: int) -> bool:
-            """
-            Determine if the action is valid for the step
-            """
-            return not (
-                (
-                    action in (Actions.Short_enter.value, Actions.Long_enter.value)
-                    and self._position != Positions.Neutral
-                )
-                or (
-                    action == Actions.Long_exit.value
-                    and self._position != Positions.Long
-                )
-                or (
-                    action == Actions.Short_exit.value
-                    and self._position != Positions.Short
-                )
-            )
-
         def get_feature_value(
             self,
             name: str,
@@ -1392,3 +1374,20 @@ def sample_params_qrdqn(trial: Trial) -> Dict[str, Any]:
     n_quantiles = trial.suggest_int("n_quantiles", 5, 200)
     hyperparams["policy_kwargs"].update({"n_quantiles": n_quantiles})
     return hyperparams
+
+
+def _is_valid(action: int, position: float) -> bool:
+    """
+    Determine if the action is valid for the step
+    """
+    # Agent should only try to exit if it is in position
+    if action in (Actions.Short_exit.value, Actions.Long_exit.value):
+        if position not in (Positions.Short, Positions.Long):
+            return False
+
+    # Agent should only try to enter if it is not in position
+    if action in (Actions.Short_enter.value, Actions.Long_enter.value):
+        if position != Positions.Neutral:
+            return False
+
+    return True