From 519faf5029b0333d4bf189d57ccb1dff45ec0a37 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Fri, 26 Sep 2025 20:29:59 +0200 Subject: [PATCH] fix(reforcexy): spot support at model training MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 66 ++++++++++--------- .../user_data/strategies/RLAgentStrategy.py | 2 +- .../user_data/strategies/QuickAdapterV3.py | 2 +- 3 files changed, 37 insertions(+), 33 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 9fa7d48..60cd86a 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -159,6 +159,15 @@ class ReforceXY(BaseReinforcementLearningModel): self._model_params_cache: Optional[Dict[str, Any]] = None self.unset_unsupported() + @staticmethod + def is_short_allowed(trading_mode: str) -> bool: + if trading_mode in {"margin", "futures"}: + return True + elif trading_mode == "spot": + return False + else: + raise ValueError(f"Invalid trading_mode: {trading_mode}") + @staticmethod def _normalize_position(position: Any) -> Positions: if isinstance(position, Positions): @@ -175,8 +184,11 @@ class ReforceXY(BaseReinforcementLearningModel): @staticmethod def get_action_masks( - position: Positions, force_action: Optional[ForceActions] = None + trading_mode: str, + position: Positions, + force_action: Optional[ForceActions] = None, ) -> NDArray[np.bool_]: + is_short_allowed = ReforceXY.is_short_allowed(trading_mode) position = ReforceXY._normalize_position(position) action_masks = np.zeros(len(Actions), dtype=np.bool_) @@ -191,7 +203,8 @@ class ReforceXY(BaseReinforcementLearningModel): action_masks[Actions.Neutral.value] = True if position == Positions.Neutral: action_masks[Actions.Long_enter.value] = True - action_masks[Actions.Short_enter.value] = True + if is_short_allowed: + action_masks[Actions.Short_enter.value] = True elif position == Positions.Long: action_masks[Actions.Long_exit.value] = True elif position == Positions.Short: @@ -615,9 +628,9 @@ class ReforceXY(BaseReinforcementLearningModel): :param model: Any = the trained model used to inference the features. """ - simulated_position: Positions = Positions.Neutral + virtual_position: Positions = Positions.Neutral - def _update_simulated_position(action: int, position: Positions) -> Positions: + def _update_virtual_position(action: int, position: Positions) -> Positions: if action == Actions.Long_enter.value and position == Positions.Neutral: return Positions.Long if action == Actions.Short_enter.value and position == Positions.Neutral: @@ -668,7 +681,7 @@ class ReforceXY(BaseReinforcementLearningModel): if self.action_masking and self.inference_masking: action_masks_param["action_masks"] = ReforceXY.get_action_masks( - simulated_position + self.config.get("trading_mode"), virtual_position ) action, _ = model.predict( @@ -681,7 +694,7 @@ class ReforceXY(BaseReinforcementLearningModel): window = dataframe.iloc[window_end - self.CONV_WIDTH : window_end] action = _predict(window) predicted_actions.append(action) - simulated_position = _update_simulated_position(action, simulated_position) + virtual_position = _update_virtual_position(action, virtual_position) pad = [np.nan] * (self.CONV_WIDTH - 1) actions_list = pad + predicted_actions @@ -740,7 +753,7 @@ class ReforceXY(BaseReinforcementLearningModel): else self.get_storage() ) if "PPO" in self.model_type: - resource_eval_freq = max(PPO_N_STEPS) + resource_eval_freq = min(PPO_N_STEPS) else: resource_eval_freq = self.get_eval_freq(total_timesteps, hyperopt=True) reduction_factor = 3 @@ -955,13 +968,13 @@ class ReforceXY(BaseReinforcementLearningModel): train_env = DummyVecEnv(train_fns) eval_env = DummyVecEnv(eval_fns) - train_env = VecMonitor(train_env) - eval_env = VecMonitor(eval_env) - if self.frame_stacking: train_env = VecFrameStack(train_env, n_stack=self.frame_stacking) eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking) + train_env = VecMonitor(train_env) + eval_env = VecMonitor(eval_env) + return train_env, eval_env def objective( @@ -1159,7 +1172,9 @@ class MyRLEnv(Base5ActionRLEnv): ) def _is_valid(self, action: int) -> bool: - return ReforceXY.get_action_masks(self._position, self._force_action)[action] + return ReforceXY.get_action_masks( + self.config.get("trading_mode"), self._position, self._force_action + )[action] def reset_env( self, @@ -1558,7 +1573,9 @@ class MyRLEnv(Base5ActionRLEnv): ) def action_masks(self) -> NDArray[np.bool_]: - return ReforceXY.get_action_masks(self._position, self._force_action) + return ReforceXY.get_action_masks( + self.config.get("trading_mode"), self._position, self._force_action + ) def get_feature_value( self, @@ -1719,25 +1736,12 @@ class MyRLEnv(Base5ActionRLEnv): right_index=True, how="left", ) - except Exception: - try: - _price_history = ( - self.prices.iloc[_rollout_history.tick] - .copy() - .reset_index(drop=True) - ) - history = merge( - _rollout_history, - _price_history, - left_index=True, - right_index=True, - ) - except Exception as e: - logger.error( - f"Failed to merge history with prices: {repr(e)}", - exc_info=True, - ) - return DataFrame() + except Exception as e: + logger.error( + f"Failed to merge history with prices: {repr(e)}", + exc_info=True, + ) + return DataFrame() return history def get_env_plot(self) -> plt.Figure: diff --git a/ReforceXY/user_data/strategies/RLAgentStrategy.py b/ReforceXY/user_data/strategies/RLAgentStrategy.py index 4139cb3..4555f1b 100644 --- a/ReforceXY/user_data/strategies/RLAgentStrategy.py +++ b/ReforceXY/user_data/strategies/RLAgentStrategy.py @@ -109,7 +109,7 @@ class RLAgentStrategy(IStrategy): def is_short_allowed(self) -> bool: trading_mode = self.config.get("trading_mode") - if trading_mode == "margin" or trading_mode == "futures": + if trading_mode in {"margin", "futures"}: return True elif trading_mode == "spot": return False diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index dbac8dd..e48c2a3 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -1574,7 +1574,7 @@ class QuickAdapterV3(IStrategy): def is_short_allowed(self) -> bool: trading_mode = self.config.get("trading_mode") - if trading_mode == "margin" or trading_mode == "futures": + if trading_mode in {"margin", "futures"}: return True elif trading_mode == "spot": return False -- 2.43.0