]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): ensure position arg is normalized
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 19 Sep 2025 19:18:44 +0000 (21:18 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 19 Sep 2025 19:18:44 +0000 (21:18 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 2370f75f13711f8ed37b0a4bc517d4ec87b116e6..baf195cbd184b2d20952de7df707a05b8fc83068 100644 (file)
@@ -153,11 +153,27 @@ class ReforceXY(BaseReinforcementLearningModel):
         self._model_params_cache: Optional[Dict[str, Any]] = None
         self.unset_unsupported()
 
+    @staticmethod
+    def _normalize_position(position: Any) -> Positions:
+        if isinstance(position, Positions):
+            return position
+        try:
+            position = float(position)
+            if position == float(Positions.Long.value):
+                return Positions.Long
+            if position == float(Positions.Short.value):
+                return Positions.Short
+            return Positions.Neutral
+        except Exception:
+            return Positions.Neutral
+
     @staticmethod
     @lru_cache(maxsize=8)
     def get_action_masks(
         position: Positions, force_action: Optional[ForceActions] = None
     ) -> NDArray[bool]:
+        position = ReforceXY._normalize_position(position)
+
         action_masks = np.zeros(len(Actions), dtype=bool)
 
         if force_action is not None and position in (Positions.Long, Positions.Short):
@@ -592,19 +608,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param model: Any = the trained model used to inference the features.
         """
 
-        def _normalize_position(position: Any) -> Positions:
-            if isinstance(position, Positions):
-                return position
-            try:
-                position = float(position)
-                if position == float(Positions.Long.value):
-                    return Positions.Long
-                if position == float(Positions.Short.value):
-                    return Positions.Short
-                return Positions.Neutral
-            except Exception:
-                return Positions.Neutral
-
         simulated_position: Positions = Positions.Neutral
 
         def _update_simulated_position(action: int, position: Positions) -> Positions: