Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): remove dead code
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 19 Sep 2025 16:03:49 +0000 (18:03 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 19 Sep 2025 16:03:49 +0000 (18:03 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 08e998279ae76067f787dd90dda8e4d025101f16..c283257355ca60b80d110a8b876852ba1b47cd79 100644
@@ -412,16 +412,27 @@ class ReforceXY(BaseReinforcementLearningModel):
     def get_eval_freq(
         self, train_timesteps: int, model_params: Optional[Dict[str, Any]] = None
     ) -> int:
+        if train_timesteps <= 0:
+            return 1
         if "PPO" in self.model_type:
+            eval_freq = None
             if model_params:
                 n_steps = model_params.get("n_steps")
                 if isinstance(n_steps, int) and n_steps > 0:
-                    return n_steps
-            for step in sorted(PPO_N_STEPS, reverse=True):
-                if step <= train_timesteps:
-                    return step
-            return PPO_N_STEPS[0]
-        return max(1, train_timesteps // max(1, self.n_envs))
+                    eval_freq = n_steps
+            if eval_freq is None:
+                eval_freq = next(
+                    (
+                        step
+                        for step in sorted(PPO_N_STEPS, reverse=True)
+                        if step <= train_timesteps
+                    ),
+                    PPO_N_STEPS[0],
+                )
+        else:
+            eval_freq = max(1, train_timesteps // self.n_envs)
+
+        return max(1, min(eval_freq, train_timesteps))
 
     def get_callbacks(
         self, eval_freq: int, data_path: str, trial: Optional[Trial] = None
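
A minimal standalone sketch of the eval-frequency selection introduced above, assuming PPO_N_STEPS holds candidate rollout sizes such as (512, 1024, 2048, 4096) and that n_envs defaults to 1; the helper name pick_eval_freq is hypothetical and only the control flow mirrors the hunk.

from typing import Any, Dict, Optional

PPO_N_STEPS = (512, 1024, 2048, 4096)  # assumed candidate rollout sizes

def pick_eval_freq(
    train_timesteps: int,
    model_type: str = "PPO",
    model_params: Optional[Dict[str, Any]] = None,
    n_envs: int = 1,
) -> int:
    # Degenerate training budget: evaluate every step.
    if train_timesteps <= 0:
        return 1
    if "PPO" in model_type:
        eval_freq = None
        # Prefer an explicit, positive integer n_steps from the model params.
        if model_params:
            n_steps = model_params.get("n_steps")
            if isinstance(n_steps, int) and n_steps > 0:
                eval_freq = n_steps
        # Otherwise take the largest candidate that fits the budget,
        # falling back to PPO_N_STEPS[0] when none does.
        if eval_freq is None:
            eval_freq = next(
                (
                    step
                    for step in sorted(PPO_N_STEPS, reverse=True)
                    if step <= train_timesteps
                ),
                PPO_N_STEPS[0],
            )
    else:
        eval_freq = max(1, train_timesteps // n_envs)
    # Clamp into [1, train_timesteps] so evaluation triggers at least once.
    return max(1, min(eval_freq, train_timesteps))

assert pick_eval_freq(10_000) == 4096                # largest fitting candidate
assert pick_eval_freq(100) == 100                    # fallback clamped to the budget
assert pick_eval_freq(10_000, "DQN", None, 4) == 2500
assert pick_eval_freq(0) == 1

Compared with the removed early-return version, this keeps the same PPO candidate selection but never returns a frequency larger than the training budget.
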
@@ -605,23 +616,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             except Exception:
                 return Positions.Neutral
 
-        def _is_valid(action: int, position: Any) -> bool:
-            """
-            Determine if the action is valid for the step
-            """
-            position = _normalize_position(position)
-            # Agent should only try to exit if it is in position
-            if action in (Actions.Short_exit.value, Actions.Long_exit.value):
-                if position not in (Positions.Short, Positions.Long):
-                    return False
-
-            # Agent should only try to enter if it is not in position
-            if action in (Actions.Short_enter.value, Actions.Long_enter.value):
-                if position != Positions.Neutral:
-                    return False
-
-            return True
-
         simulated_position: Positions = Positions.Neutral
 
         def _update_simulated_position(action: int, position: Positions) -> Positions:
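
For reference, the rules of the removed _is_valid helper, reproduced standalone: the Actions and Positions classes below are minimal stand-ins, not the project's enums, and the position normalization via _normalize_position is omitted.

from enum import Enum

class Actions(Enum):  # stand-in for the project's Actions enum
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4

class Positions(Enum):  # stand-in for the project's Positions enum
    Short = 0
    Long = 1
    Neutral = 0.5

def is_valid(action: int, position: Positions) -> bool:
    # Exits are only valid while in a position.
    if action in (Actions.Short_exit.value, Actions.Long_exit.value):
        return position in (Positions.Short, Positions.Long)
    # Entries are only valid while flat.
    if action in (Actions.Short_enter.value, Actions.Long_enter.value):
        return position == Positions.Neutral
    return True

assert is_valid(Actions.Long_enter.value, Positions.Neutral)
assert not is_valid(Actions.Long_exit.value, Positions.Neutral)
assert is_valid(Actions.Short_exit.value, Positions.Short)
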
@@ -1077,7 +1071,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             self._last_closed_trade_tick: int = 0
             return observation, history
 
-        def _get_reward_factor_at_trade_exit(
+        def _get_exit_reward_factor(
             self,
             factor: float,
             pnl: float,
@@ -1140,7 +1134,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 ForceActions.Stop_loss,
                 ForceActions.Timeout,
             ):
-                return pnl * self._get_reward_factor_at_trade_exit(
+                return pnl * self._get_exit_reward_factor(
                     factor, pnl, trade_duration, max_trade_duration
                 )
 
@@ -1206,13 +1200,13 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
-                return pnl * self._get_reward_factor_at_trade_exit(
+                return pnl * self._get_exit_reward_factor(
                     factor, pnl, trade_duration, max_trade_duration
                 )
 
             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
-                return pnl * self._get_reward_factor_at_trade_exit(
+                return pnl * self._get_exit_reward_factor(
                     factor, pnl, trade_duration, max_trade_duration
                 )