From: Jérôme Benoit
Date: Thu, 2 Oct 2025 11:27:35 +0000 (+0200)
Subject: fix(reforcexy): ensure virtual position in sync in live state
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=b6af9ef7fa614eef043a6f85faef00a1de8e3474;p=freqai-strategies.git

fix(reforcexy): ensure virtual position in sync in live state

Signed-off-by: Jérôme Benoit
---

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index ae74d82..240f150 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -39,7 +39,7 @@ from optuna.storages import (
 )
 from optuna.storages.journal import JournalFileBackend
 from optuna.study import Study, StudyDirection
-from pandas import DataFrame, concat, merge
+from pandas import DataFrame, merge
 from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
 from sb3_contrib.common.maskable.utils import is_masking_supported
 from stable_baselines3.common.callbacks import (
@@ -721,8 +721,13 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param dk: FreqaiDatakitchen = data kitchen for the current pair
         :param model: Any = the trained model used to inference the features.
         """
+        add_state_info = self.rl_config.get("add_state_info", False)
         virtual_position: Positions = Positions.Neutral
         virtual_trade_duration: int = 0
+        if add_state_info and self.live:
+            position, _, trade_duration = self.get_state_info(dk.pair)
+            virtual_position = ReforceXY._normalize_position(position)
+            virtual_trade_duration = trade_duration
         np_dataframe: NDArray[np.float32] = dataframe.to_numpy(
             dtype=np.float32, copy=False
         )
@@ -731,7 +736,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         frame_stacking = self.frame_stacking
         frame_stacking_activated = bool(frame_stacking) and frame_stacking > 1
         inference_masking = self.action_masking and self.inference_masking
-        add_state_info = self.rl_config.get("add_state_info", False)
 
         def _update_virtual_position(action: int, position: Positions) -> Positions:
             if action == Actions.Long_enter.value and position == Positions.Neutral:
@@ -745,7 +749,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             return position
 
         def _update_virtual_trade_duration(
-            action: int,
             virtual_position: Positions,
             previous_virtual_position: Positions,
             current_virtual_trade_duration: int,
@@ -829,7 +832,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             previous_virtual_position = virtual_position
             virtual_position = _update_virtual_position(action, virtual_position)
             virtual_trade_duration = _update_virtual_trade_duration(
-                action,
                 virtual_position,
                 previous_virtual_position,
                 virtual_trade_duration,
@@ -1616,29 +1618,28 @@ class MyRLEnv(Base5ActionRLEnv):
         This may or may not be independent of action types, user can inherit
         this in their custom "MyRLEnv"
         """
-        start_idx = max(0, self._current_tick - self.window_size)
-        end_idx = self._current_tick
+        start_idx = max(self._start_tick, self._current_tick - self.window_size)
+        end_idx = min(self._current_tick, len(self.signal_features))
         features_window = self.signal_features.iloc[start_idx:end_idx]
-        if len(features_window) < self.window_size:
-            pad_size = self.window_size - len(features_window)
-            pad_df = DataFrame(
-                np.zeros((pad_size, features_window.shape[1]), dtype=np.float32),
-                columns=features_window.columns,
+        features_window_array = features_window.to_numpy(dtype=np.float32, copy=False)
+        if features_window_array.shape[0] < self.window_size:
+            pad_size = self.window_size - features_window_array.shape[0]
+            pad_array = np.zeros(
+                (pad_size, features_window_array.shape[1]), dtype=np.float32
             )
-            features_window = concat(
-                [pad_df, features_window], axis=0, ignore_index=True
+            features_window_array = np.concatenate(
+                [pad_array, features_window_array], axis=0
             )
-        features_window_array = features_window.to_numpy(dtype=np.float32)
         if self.add_state_info:
-            return np.concatenate(
+            observations = np.concatenate(
                 [
                     features_window_array,
                     np.tile(
                         np.array(
                             [
-                                self.get_unrealized_profit(),
-                                self._position.value,
-                                self.get_trade_duration(),
+                                float(self.get_unrealized_profit()),
+                                float(self._position.value),
+                                float(self.get_trade_duration()),
                             ],
                             dtype=np.float32,
                         ),
@@ -1648,7 +1649,9 @@ class MyRLEnv(Base5ActionRLEnv):
                 axis=1,
             )
         else:
-            return features_window_array
+            observations = features_window_array
+
+        return np.ascontiguousarray(observations)
 
     def _get_force_action(self) -> Optional[ForceActions]:
         if not self.force_actions or self._position == Positions.Neutral:
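
Reviewer note, not part of the commit: the @@ -721 hunk makes the inference path seed the replayed "virtual" position and trade duration from the actual open trade when self.live is set and add_state_info is enabled, instead of always starting from Neutral. Below is a minimal stand-alone sketch of that seeding step; it assumes a get_state_info(pair) helper returning (position, profit, trade_duration) in the spirit of FreqAI's BaseReinforcementLearningModel and a normalizer onto the Positions enum, and all names are illustrative rather than the real ReforceXY API.

# Reviewer sketch only; helper names are illustrative stand-ins for
# ReforceXY.get_state_info() / ReforceXY._normalize_position().
from enum import Enum
from typing import Callable, Tuple


class Positions(Enum):
    Short = 0
    Long = 1
    Neutral = 0.5


def seed_virtual_state(
    pair: str,
    live: bool,
    add_state_info: bool,
    get_state_info: Callable[[str], Tuple[float, float, float]],
    normalize_position: Callable[[float], Positions],
) -> Tuple[Positions, int]:
    """Return the starting (position, trade_duration) for the inference replay."""
    if add_state_info and live:
        # Live path: start the replay from the currently open trade, if any.
        position, _, trade_duration = get_state_info(pair)
        return normalize_position(position), int(trade_duration)
    # Backtest path keeps the previous behaviour: start flat.
    return Positions.Neutral, 0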
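
Reviewer note, not part of the commit: the @@ -1616 and @@ -1648 hunks rebuild the observation window purely in NumPy (zero left-padding replaces the padded DataFrame + concat), tile the optional state info across the window, and return a single contiguous float32 array from both branches. A rough stand-alone sketch of the same construction follows; the function name and parameters are illustrative, not the real _get_observation() signature.

# Reviewer sketch of the reworked observation construction (illustrative names).
from typing import Optional, Tuple

import numpy as np
import pandas as pd


def build_observation(
    signal_features: pd.DataFrame,
    current_tick: int,
    start_tick: int,
    window_size: int,
    state_info: Optional[Tuple[float, float, float]] = None,
) -> np.ndarray:
    start_idx = max(start_tick, current_tick - window_size)
    end_idx = min(current_tick, len(signal_features))
    window = signal_features.iloc[start_idx:end_idx].to_numpy(
        dtype=np.float32, copy=False
    )
    if window.shape[0] < window_size:
        # Left-pad with zeros so the observation is always (window_size, n_features).
        pad = np.zeros(
            (window_size - window.shape[0], window.shape[1]), dtype=np.float32
        )
        window = np.concatenate([pad, window], axis=0)
    if state_info is not None:
        # Repeat (unrealized_profit, position, trade_duration) on every row.
        state = np.tile(np.array(state_info, dtype=np.float32), (window_size, 1))
        window = np.concatenate([window, state], axis=1)
    # Single contiguous float32 array, mirroring the np.ascontiguousarray return.
    return np.ascontiguousarray(window)

For example, build_observation(df, current_tick=42, start_tick=0, window_size=8, state_info=(0.01, 1.0, 12.0)) should yield an (8, n_features + 3) float32 array.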