From d4b3be3c1e32ed4a4d49e1dba429a82305f3af80 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 29 Sep 2025 19:46:05 +0200 Subject: [PATCH] refactor(reforcexy): cleanup spot support MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index b3c3a6e..3e385e2 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -171,15 +171,6 @@ class ReforceXY(BaseReinforcementLearningModel): self._model_params_cache: Optional[Dict[str, Any]] = None self.unset_unsupported() - @staticmethod - def _is_short_allowed(trading_mode: str) -> bool: - if trading_mode in {"margin", "futures"}: - return True - elif trading_mode == "spot": - return False - else: - raise ValueError(f"Invalid trading_mode: {trading_mode}") - @staticmethod def _normalize_position(position: Any) -> Positions: if isinstance(position, Positions): @@ -196,22 +187,20 @@ class ReforceXY(BaseReinforcementLearningModel): @staticmethod def get_action_masks( - trading_mode: str, + can_short: bool, position: Positions, force_action: Optional[ForceActions] = None, ) -> NDArray[np.bool_]: position = ReforceXY._normalize_position(position) cache_key = ( - trading_mode, + can_short, position.value, force_action.value if force_action else None, ) if cache_key in ReforceXY._action_masks_cache: return ReforceXY._action_masks_cache[cache_key] - is_short_allowed = ReforceXY._is_short_allowed(trading_mode) - action_masks = np.zeros(len(Actions), dtype=np.bool_) if force_action is not None and position in (Positions.Long, Positions.Short): @@ -225,7 +214,7 @@ class ReforceXY(BaseReinforcementLearningModel): action_masks[Actions.Neutral.value] = True if position == Positions.Neutral: action_masks[Actions.Long_enter.value] = True - if is_short_allowed: + if can_short: action_masks[Actions.Short_enter.value] = True elif position == Positions.Long: action_masks[Actions.Long_exit.value] = True @@ -771,7 +760,7 @@ class ReforceXY(BaseReinforcementLearningModel): if self.action_masking and self.inference_masking: action_masks_param["action_masks"] = ReforceXY.get_action_masks( - self.config.get("trading_mode"), virtual_position + self.can_short, virtual_position ) action, _ = model.predict( @@ -1271,7 +1260,7 @@ class MyRLEnv(Base5ActionRLEnv): def _is_valid(self, action: int) -> bool: return ReforceXY.get_action_masks( - self.config.get("trading_mode"), self._position, self._force_action + self.can_short, self._position, self._force_action )[action] def reset_env( @@ -1616,6 +1605,12 @@ class MyRLEnv(Base5ActionRLEnv): if not self.is_tradesignal(action): return None + if ( + action in (Actions.Short_enter.value, Actions.Short_exit.value) + and not self.can_short + ): + return None + # Enter trade based on action if action in (Actions.Long_enter.value, Actions.Short_enter.value): self._enter_trade(action) @@ -1717,7 +1712,7 @@ class MyRLEnv(Base5ActionRLEnv): def action_masks(self) -> NDArray[np.bool_]: return ReforceXY.get_action_masks( - self.config.get("trading_mode"), self._position, self._force_action + self.can_short, self._position, self._force_action ) def get_feature_value( -- 2.43.0