if not self.is_tradesignal(action):
return None
- if (
- action in (Actions.Short_enter.value, Actions.Short_exit.value)
- and not self.can_short
- ):
- return None
-
# Enter trade based on action
if action in (Actions.Long_enter.value, Actions.Short_enter.value):
self._enter_trade(action)
def is_tradesignal(self, action: int) -> bool:
"""
- Determine if the action is entry or exit
+ Determine if the action is a valid entry or exit
"""
- return (
- (
- action in (Actions.Short_enter.value, Actions.Long_enter.value)
- and self._position == Positions.Neutral
- )
- or (action == Actions.Long_exit.value and self._position == Positions.Long)
- or (
- action == Actions.Short_exit.value and self._position == Positions.Short
- )
- )
+ position = self._position
+
+ action_rules = {
+ Actions.Long_enter.value: (Positions.Neutral, False),
+ Actions.Short_enter.value: (Positions.Neutral, True),
+ Actions.Long_exit.value: (Positions.Long, False),
+ Actions.Short_exit.value: (Positions.Short, True),
+ }
+
+ if action not in action_rules:
+ return False
+
+ required_position, requires_short = action_rules[action]
+ return position == required_position and (not requires_short or self.can_short)
def action_masks(self) -> NDArray[np.bool_]:
return ReforceXY.get_action_masks(self.can_short, self._position)