From: Jérôme Benoit Date: Fri, 19 Sep 2025 16:03:49 +0000 (+0200) Subject: refactor(reforcexy): remove dead code X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=ba0f3dcdd11ab1efb816e1661b1965b9c3444ff6;p=freqai-strategies.git refactor(reforcexy): remove dead code Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 08e9982..c283257 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -412,16 +412,27 @@ class ReforceXY(BaseReinforcementLearningModel): def get_eval_freq( self, train_timesteps: int, model_params: Optional[Dict[str, Any]] = None ) -> int: + if train_timesteps <= 0: + return 1 if "PPO" in self.model_type: + eval_freq = None if model_params: n_steps = model_params.get("n_steps") if isinstance(n_steps, int) and n_steps > 0: - return n_steps - for step in sorted(PPO_N_STEPS, reverse=True): - if step <= train_timesteps: - return step - return PPO_N_STEPS[0] - return max(1, train_timesteps // max(1, self.n_envs)) + eval_freq = n_steps + if eval_freq is None: + eval_freq = next( + ( + step + for step in sorted(PPO_N_STEPS, reverse=True) + if step <= train_timesteps + ), + PPO_N_STEPS[0], + ) + else: + eval_freq = max(1, train_timesteps // self.n_envs) + + return max(1, min(eval_freq, train_timesteps)) def get_callbacks( self, eval_freq: int, data_path: str, trial: Optional[Trial] = None @@ -605,23 +616,6 @@ class ReforceXY(BaseReinforcementLearningModel): except Exception: return Positions.Neutral - def _is_valid(action: int, position: Any) -> bool: - """ - Determine if the action is valid for the step - """ - position = _normalize_position(position) - # Agent should only try to exit if it is in position - if action in (Actions.Short_exit.value, Actions.Long_exit.value): - if position not in (Positions.Short, Positions.Long): - return False - - # Agent should only try to enter if it is not in position - if action in (Actions.Short_enter.value, Actions.Long_enter.value): - if position != Positions.Neutral: - return False - - return True - simulated_position: Positions = Positions.Neutral def _update_simulated_position(action: int, position: Positions) -> Positions: @@ -1077,7 +1071,7 @@ class ReforceXY(BaseReinforcementLearningModel): self._last_closed_trade_tick: int = 0 return observation, history - def _get_reward_factor_at_trade_exit( + def _get_exit_reward_factor( self, factor: float, pnl: float, @@ -1140,7 +1134,7 @@ class ReforceXY(BaseReinforcementLearningModel): ForceActions.Stop_loss, ForceActions.Timeout, ): - return pnl * self._get_reward_factor_at_trade_exit( + return pnl * self._get_exit_reward_factor( factor, pnl, trade_duration, max_trade_duration ) @@ -1206,13 +1200,13 @@ class ReforceXY(BaseReinforcementLearningModel): # close long if action == Actions.Long_exit.value and self._position == Positions.Long: - return pnl * self._get_reward_factor_at_trade_exit( + return pnl * self._get_exit_reward_factor( factor, pnl, trade_duration, max_trade_duration ) # close short if action == Actions.Short_exit.value and self._position == Positions.Short: - return pnl * self._get_reward_factor_at_trade_exit( + return pnl * self._get_exit_reward_factor( factor, pnl, trade_duration, max_trade_duration )