From 1c6c5df24b5d28fc551ea583e3d150f8d3b6029b Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Fri, 19 Sep 2025 19:44:05 +0200 Subject: [PATCH] refactor(reforcexy): factor out action validation and masking logic MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 68 +++++++------------ 1 file changed, 25 insertions(+), 43 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index c283257..ac3cdaf 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -154,40 +154,29 @@ class ReforceXY(BaseReinforcementLearningModel): self.unset_unsupported() @staticmethod - def build_action_mask( + @lru_cache(maxsize=8) + def get_action_masks( position: Positions, force_action: Optional[ForceActions] = None - ) -> np.ndarray: - action_mask = np.zeros(len(Actions), dtype=bool) + ) -> NDArray[bool]: + action_masks = np.zeros(len(Actions), dtype=bool) - action_mask[Actions.Neutral.value] = True + if force_action is not None and position in (Positions.Long, Positions.Short): + if position == Positions.Long: + action_masks[Actions.Long_exit.value] = True + else: + action_masks[Actions.Short_exit.value] = True + return action_masks + action_masks[Actions.Neutral.value] = True if position == Positions.Neutral: - action_mask[Actions.Long_enter.value] = True - action_mask[Actions.Short_enter.value] = True + action_masks[Actions.Long_enter.value] = True + action_masks[Actions.Short_enter.value] = True elif position == Positions.Long: - action_mask[Actions.Long_exit.value] = True + action_masks[Actions.Long_exit.value] = True elif position == Positions.Short: - action_mask[Actions.Short_exit.value] = True - - if force_action is not None and position in (Positions.Long, Positions.Short): - force_action_mask = np.zeros(len(Actions), dtype=bool) - try: - if position == Positions.Long: - force_action_mask[Actions.Long_exit.value] = True - elif position == Positions.Short: - force_action_mask[Actions.Short_exit.value] = True - except Exception: - return action_mask - if force_action_mask.any(): - return force_action_mask - return action_mask + action_masks[Actions.Short_exit.value] = True - if not action_mask.any(): - try: - action_mask[Actions.Neutral.value] = True - except Exception: - action_mask = np.ones_like(action_mask, dtype=bool) - return action_mask + return action_masks def unset_unsupported(self) -> None: """ @@ -668,7 +657,7 @@ class ReforceXY(BaseReinforcementLearningModel): observations = np_observation.reshape(1, -1) if self.action_masking and self.inference_masking: - action_masks_param["action_masks"] = ReforceXY.build_action_mask( + action_masks_param["action_masks"] = ReforceXY.get_action_masks( simulated_position, None ) @@ -1045,6 +1034,11 @@ class ReforceXY(BaseReinforcementLearningModel): low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32 ) + def _is_valid(self, action: int) -> bool: + return ReforceXY.get_action_masks(self._position, self._force_action)[ + action + ] + def reset_env( self, df: DataFrame, @@ -1109,11 +1103,7 @@ class ReforceXY(BaseReinforcementLearningModel): of weights in NN) """ # first, penalize if the action is not valid - if ( - not self.action_masking - and not self._force_action - and not self._is_valid(action) - ): + if not self.action_masking and not self._is_valid(action): self.tensorboard_log("invalid", category="actions") return self.rl_config.get("model_reward_parameters", {}).get( "invalid_action", -2.0 @@ -1388,16 +1378,8 @@ class ReforceXY(BaseReinforcementLearningModel): ) ) - def action_masks(self): - try: - return ReforceXY.build_action_mask(self._position, self._force_action) - except Exception: - action_mask = np.zeros(len(Actions), dtype=bool) - try: - action_mask[Actions.Neutral.value] = True - except Exception: - action_mask = np.ones(len(Actions), dtype=bool) - return action_mask + def action_masks(self) -> NDArray[bool]: + return ReforceXY.get_action_masks(self._position, self._force_action) def get_feature_value( self, -- 2.43.0