From 94783272845974c64b83f444456a11bdd4d7b820 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Fri, 19 Sep 2025 21:18:44 +0200 Subject: [PATCH] refactor(reforcexy): ensure position arg is normalized MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 2370f75..baf195c 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -153,11 +153,27 @@ class ReforceXY(BaseReinforcementLearningModel): self._model_params_cache: Optional[Dict[str, Any]] = None self.unset_unsupported() + @staticmethod + def _normalize_position(position: Any) -> Positions: + if isinstance(position, Positions): + return position + try: + position = float(position) + if position == float(Positions.Long.value): + return Positions.Long + if position == float(Positions.Short.value): + return Positions.Short + return Positions.Neutral + except Exception: + return Positions.Neutral + @staticmethod @lru_cache(maxsize=8) def get_action_masks( position: Positions, force_action: Optional[ForceActions] = None ) -> NDArray[bool]: + position = ReforceXY._normalize_position(position) + action_masks = np.zeros(len(Actions), dtype=bool) if force_action is not None and position in (Positions.Long, Positions.Short): @@ -592,19 +608,6 @@ class ReforceXY(BaseReinforcementLearningModel): :param model: Any = the trained model used to inference the features. """ - def _normalize_position(position: Any) -> Positions: - if isinstance(position, Positions): - return position - try: - position = float(position) - if position == float(Positions.Long.value): - return Positions.Long - if position == float(Positions.Short.value): - return Positions.Short - return Positions.Neutral - except Exception: - return Positions.Neutral - simulated_position: Positions = Positions.Neutral def _update_simulated_position(action: int, position: Positions) -> Positions: -- 2.43.0