Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): cleanup spot support
author: Jérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 29 Sep 2025 17:46:05 +0000 (19:46 +0200)
committer: Jérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 29 Sep 2025 17:46:05 +0000 (19:46 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index b3c3a6e58165a105c5fc3d9a225cd8e8d245d49e..3e385e2f305146a3d017f8c00ec7ef033f670d16 100644 (file)
@@ -171,15 +171,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         self._model_params_cache: Optional[Dict[str, Any]] = None
         self.unset_unsupported()
 
-    @staticmethod
-    def _is_short_allowed(trading_mode: str) -> bool:
-        if trading_mode in {"margin", "futures"}:
-            return True
-        elif trading_mode == "spot":
-            return False
-        else:
-            raise ValueError(f"Invalid trading_mode: {trading_mode}")
-
     @staticmethod
     def _normalize_position(position: Any) -> Positions:
         if isinstance(position, Positions):
@@ -196,22 +187,20 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     @staticmethod
     def get_action_masks(
-        trading_mode: str,
+        can_short: bool,
         position: Positions,
         force_action: Optional[ForceActions] = None,
     ) -> NDArray[np.bool_]:
         position = ReforceXY._normalize_position(position)
 
         cache_key = (
-            trading_mode,
+            can_short,
             position.value,
             force_action.value if force_action else None,
         )
         if cache_key in ReforceXY._action_masks_cache:
             return ReforceXY._action_masks_cache[cache_key]
 
-        is_short_allowed = ReforceXY._is_short_allowed(trading_mode)
-
         action_masks = np.zeros(len(Actions), dtype=np.bool_)
 
         if force_action is not None and position in (Positions.Long, Positions.Short):
@@ -225,7 +214,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         action_masks[Actions.Neutral.value] = True
         if position == Positions.Neutral:
             action_masks[Actions.Long_enter.value] = True
-            if is_short_allowed:
+            if can_short:
                 action_masks[Actions.Short_enter.value] = True
         elif position == Positions.Long:
             action_masks[Actions.Long_exit.value] = True
@@ -771,7 +760,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             if self.action_masking and self.inference_masking:
                 action_masks_param["action_masks"] = ReforceXY.get_action_masks(
-                    self.config.get("trading_mode"), virtual_position
+                    self.can_short, virtual_position
                 )
 
             action, _ = model.predict(
@@ -1271,7 +1260,7 @@ class MyRLEnv(Base5ActionRLEnv):
 
     def _is_valid(self, action: int) -> bool:
         return ReforceXY.get_action_masks(
-            self.config.get("trading_mode"), self._position, self._force_action
+            self.can_short, self._position, self._force_action
         )[action]
 
     def reset_env(
@@ -1616,6 +1605,12 @@ class MyRLEnv(Base5ActionRLEnv):
         if not self.is_tradesignal(action):
             return None
 
+        if (
+            action in (Actions.Short_enter.value, Actions.Short_exit.value)
+            and not self.can_short
+        ):
+            return None
+
         # Enter trade based on action
         if action in (Actions.Long_enter.value, Actions.Short_enter.value):
             self._enter_trade(action)
@@ -1717,7 +1712,7 @@ class MyRLEnv(Base5ActionRLEnv):
 
     def action_masks(self) -> NDArray[np.bool_]:
         return ReforceXY.get_action_masks(
-            self.config.get("trading_mode"), self._position, self._force_action
+            self.can_short, self._position, self._force_action
         )
 
     def get_feature_value(