def get_eval_freq(
self, train_timesteps: int, model_params: Optional[Dict[str, Any]] = None
) -> int:
+ if train_timesteps <= 0:
+ return 1
if "PPO" in self.model_type:
+ eval_freq = None
if model_params:
n_steps = model_params.get("n_steps")
if isinstance(n_steps, int) and n_steps > 0:
- return n_steps
- for step in sorted(PPO_N_STEPS, reverse=True):
- if step <= train_timesteps:
- return step
- return PPO_N_STEPS[0]
- return max(1, train_timesteps // max(1, self.n_envs))
+ eval_freq = n_steps
+ if eval_freq is None:
+ eval_freq = next(
+ (
+ step
+ for step in sorted(PPO_N_STEPS, reverse=True)
+ if step <= train_timesteps
+ ),
+ PPO_N_STEPS[0],
+ )
+ else:
+ eval_freq = max(1, train_timesteps // self.n_envs)
+
+ return max(1, min(eval_freq, train_timesteps))
def get_callbacks(
self, eval_freq: int, data_path: str, trial: Optional[Trial] = None
except Exception:
return Positions.Neutral
- def _is_valid(action: int, position: Any) -> bool:
- """
- Determine if the action is valid for the step
- """
- position = _normalize_position(position)
- # Agent should only try to exit if it is in position
- if action in (Actions.Short_exit.value, Actions.Long_exit.value):
- if position not in (Positions.Short, Positions.Long):
- return False
-
- # Agent should only try to enter if it is not in position
- if action in (Actions.Short_enter.value, Actions.Long_enter.value):
- if position != Positions.Neutral:
- return False
-
- return True
-
simulated_position: Positions = Positions.Neutral
def _update_simulated_position(action: int, position: Positions) -> Positions:
self._last_closed_trade_tick: int = 0
return observation, history
- def _get_reward_factor_at_trade_exit(
+ def _get_exit_reward_factor(
self,
factor: float,
pnl: float,
ForceActions.Stop_loss,
ForceActions.Timeout,
):
- return pnl * self._get_reward_factor_at_trade_exit(
+ return pnl * self._get_exit_reward_factor(
factor, pnl, trade_duration, max_trade_duration
)
# close long
if action == Actions.Long_exit.value and self._position == Positions.Long:
- return pnl * self._get_reward_factor_at_trade_exit(
+ return pnl * self._get_exit_reward_factor(
factor, pnl, trade_duration, max_trade_duration
)
# close short
if action == Actions.Short_exit.value and self._position == Positions.Short:
- return pnl * self._get_reward_factor_at_trade_exit(
+ return pnl * self._get_exit_reward_factor(
factor, pnl, trade_duration, max_trade_duration
)