From: Jérôme Benoit Date: Fri, 19 Sep 2025 19:18:44 +0000 (+0200) Subject: refactor(reforcexy): ensure position arg is normalized X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=94783272845974c64b83f444456a11bdd4d7b820;p=freqai-strategies.git refactor(reforcexy): ensure position arg is normalized Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 2370f75..baf195c 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -153,11 +153,27 @@ class ReforceXY(BaseReinforcementLearningModel): self._model_params_cache: Optional[Dict[str, Any]] = None self.unset_unsupported() + @staticmethod + def _normalize_position(position: Any) -> Positions: + if isinstance(position, Positions): + return position + try: + position = float(position) + if position == float(Positions.Long.value): + return Positions.Long + if position == float(Positions.Short.value): + return Positions.Short + return Positions.Neutral + except Exception: + return Positions.Neutral + @staticmethod @lru_cache(maxsize=8) def get_action_masks( position: Positions, force_action: Optional[ForceActions] = None ) -> NDArray[bool]: + position = ReforceXY._normalize_position(position) + action_masks = np.zeros(len(Actions), dtype=bool) if force_action is not None and position in (Positions.Long, Positions.Short): @@ -592,19 +608,6 @@ class ReforceXY(BaseReinforcementLearningModel): :param model: Any = the trained model used to inference the features. """ - def _normalize_position(position: Any) -> Positions: - if isinstance(position, Positions): - return position - try: - position = float(position) - if position == float(Positions.Long.value): - return Positions.Long - if position == float(Positions.Short.value): - return Positions.Short - return Positions.Neutral - except Exception: - return Positions.Neutral - simulated_position: Positions = Positions.Neutral def _update_simulated_position(action: int, position: Positions) -> Positions: