]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): move trade signal validation logic in the same
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 11 Oct 2025 19:01:47 +0000 (21:01 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 11 Oct 2025 19:01:47 +0000 (21:01 +0200)
     helper

Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index b66313ea79aa0f0a3d36bd70098e7d34ee9a1480..503551ecd34df8d7b8296c8f6b250609186dfabb 100644 (file)
@@ -1684,12 +1684,6 @@ class MyRLEnv(Base5ActionRLEnv):
         if not self.is_tradesignal(action):
             return None
 
-        if (
-            action in (Actions.Short_enter.value, Actions.Short_exit.value)
-            and not self.can_short
-        ):
-            return None
-
         # Enter trade based on action
         if action in (Actions.Long_enter.value, Actions.Short_enter.value):
             self._enter_trade(action)
@@ -1774,18 +1768,22 @@ class MyRLEnv(Base5ActionRLEnv):
 
     def is_tradesignal(self, action: int) -> bool:
         """
-        Determine if the action is entry or exit
+        Determine if the action is a valid entry or exit
         """
-        return (
-            (
-                action in (Actions.Short_enter.value, Actions.Long_enter.value)
-                and self._position == Positions.Neutral
-            )
-            or (action == Actions.Long_exit.value and self._position == Positions.Long)
-            or (
-                action == Actions.Short_exit.value and self._position == Positions.Short
-            )
-        )
+        position = self._position
+
+        action_rules = {
+            Actions.Long_enter.value: (Positions.Neutral, False),
+            Actions.Short_enter.value: (Positions.Neutral, True),
+            Actions.Long_exit.value: (Positions.Long, False),
+            Actions.Short_exit.value: (Positions.Short, True),
+        }
+
+        if action not in action_rules:
+            return False
+
+        required_position, requires_short = action_rules[action]
+        return position == required_position and (not requires_short or self.can_short)
 
     def action_masks(self) -> NDArray[np.bool_]:
         return ReforceXY.get_action_masks(self.can_short, self._position)