]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): factor out action validation and masking logic
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 19 Sep 2025 17:44:05 +0000 (19:44 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 19 Sep 2025 17:44:05 +0000 (19:44 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index c283257355ca60b80d110a8b876852ba1b47cd79..ac3cdaffa4a4597669d6f94eac3ea411ae897511 100644 (file)
@@ -154,40 +154,29 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.unset_unsupported()
 
     @staticmethod
-    def build_action_mask(
+    @lru_cache(maxsize=8)
+    def get_action_masks(
         position: Positions, force_action: Optional[ForceActions] = None
-    ) -> np.ndarray:
-        action_mask = np.zeros(len(Actions), dtype=bool)
+    ) -> NDArray[bool]:
+        action_masks = np.zeros(len(Actions), dtype=bool)
 
-        action_mask[Actions.Neutral.value] = True
+        if force_action is not None and position in (Positions.Long, Positions.Short):
+            if position == Positions.Long:
+                action_masks[Actions.Long_exit.value] = True
+            else:
+                action_masks[Actions.Short_exit.value] = True
+            return action_masks
 
+        action_masks[Actions.Neutral.value] = True
         if position == Positions.Neutral:
-            action_mask[Actions.Long_enter.value] = True
-            action_mask[Actions.Short_enter.value] = True
+            action_masks[Actions.Long_enter.value] = True
+            action_masks[Actions.Short_enter.value] = True
         elif position == Positions.Long:
-            action_mask[Actions.Long_exit.value] = True
+            action_masks[Actions.Long_exit.value] = True
         elif position == Positions.Short:
-            action_mask[Actions.Short_exit.value] = True
-
-        if force_action is not None and position in (Positions.Long, Positions.Short):
-            force_action_mask = np.zeros(len(Actions), dtype=bool)
-            try:
-                if position == Positions.Long:
-                    force_action_mask[Actions.Long_exit.value] = True
-                elif position == Positions.Short:
-                    force_action_mask[Actions.Short_exit.value] = True
-            except Exception:
-                return action_mask
-            if force_action_mask.any():
-                return force_action_mask
-            return action_mask
+            action_masks[Actions.Short_exit.value] = True
 
-        if not action_mask.any():
-            try:
-                action_mask[Actions.Neutral.value] = True
-            except Exception:
-                action_mask = np.ones_like(action_mask, dtype=bool)
-        return action_mask
+        return action_masks
 
     def unset_unsupported(self) -> None:
         """
@@ -668,7 +657,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 observations = np_observation.reshape(1, -1)
 
             if self.action_masking and self.inference_masking:
-                action_masks_param["action_masks"] = ReforceXY.build_action_mask(
+                action_masks_param["action_masks"] = ReforceXY.get_action_masks(
                     simulated_position, None
                 )
 
@@ -1045,6 +1034,11 @@ class ReforceXY(BaseReinforcementLearningModel):
                 low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
             )
 
+        def _is_valid(self, action: int) -> bool:
+            return ReforceXY.get_action_masks(self._position, self._force_action)[
+                action
+            ]
+
         def reset_env(
             self,
             df: DataFrame,
@@ -1109,11 +1103,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                     of weights in NN)
             """
             # first, penalize if the action is not valid
-            if (
-                not self.action_masking
-                and not self._force_action
-                and not self._is_valid(action)
-            ):
+            if not self.action_masking and not self._is_valid(action):
                 self.tensorboard_log("invalid", category="actions")
                 return self.rl_config.get("model_reward_parameters", {}).get(
                     "invalid_action", -2.0
@@ -1388,16 +1378,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                 )
             )
 
-        def action_masks(self):
-            try:
-                return ReforceXY.build_action_mask(self._position, self._force_action)
-            except Exception:
-                action_mask = np.zeros(len(Actions), dtype=bool)
-                try:
-                    action_mask[Actions.Neutral.value] = True
-                except Exception:
-                    action_mask = np.ones(len(Actions), dtype=bool)
-                return action_mask
+        def action_masks(self) -> NDArray[bool]:
+            return ReforceXY.get_action_masks(self._position, self._force_action)
 
         def get_feature_value(
             self,