Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): ensure virtual position is in sync in live state
author    Jérôme Benoit <jerome.benoit@piment-noir.org>
          Thu, 2 Oct 2025 11:27:35 +0000 (13:27 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
          Thu, 2 Oct 2025 11:27:35 +0000 (13:27 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
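
In short, this commit seeds the virtual position and trade duration used during prediction from the live trade state (via get_state_info()) when add_state_info is enabled and the model runs live, instead of always starting from Positions.Neutral. Below is a minimal, self-contained sketch of that seeding logic; the simplified Positions enum and the fetch_live_state() callable are illustrative stand-ins for freqtrade's Positions and FreqAI's get_state_info(), which (as in the diff) yields position, PnL and trade duration.

from enum import Enum
from typing import Callable, Tuple


class Positions(Enum):
    # Simplified stand-in for freqtrade's Positions enum (values are illustrative).
    Short = 0.0
    Neutral = 0.5
    Long = 1.0


def seed_virtual_state(
    add_state_info: bool,
    live: bool,
    fetch_live_state: Callable[[], Tuple[float, float, float]],
) -> Tuple[Positions, int]:
    """Return the (position, trade_duration) pair the prediction loop starts from."""
    virtual_position = Positions.Neutral
    virtual_trade_duration = 0
    if add_state_info and live:
        # Mirror the diff: pull (position, pnl, trade_duration) from the live trade
        # and normalize the raw position value onto the enum, defaulting to Neutral.
        position, _, trade_duration = fetch_live_state()
        try:
            virtual_position = Positions(position)
        except ValueError:
            virtual_position = Positions.Neutral
        virtual_trade_duration = int(trade_duration)
    return virtual_position, virtual_trade_duration


# Example: an open long trade (position value 1.0) that has lasted 12 candles.
print(seed_virtual_state(True, True, lambda: (1.0, 0.02, 12)))
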
ReforceXY/user_data/freqaimodels/ReforceXY.py

index ae74d82fafb5e3540bde102a4453abeba7269c8c..240f150218f40861f4840c957bba61b709a9e93f 100644
@@ -39,7 +39,7 @@ from optuna.storages import (
 )
 from optuna.storages.journal import JournalFileBackend
 from optuna.study import Study, StudyDirection
-from pandas import DataFrame, concat, merge
+from pandas import DataFrame, merge
 from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
 from sb3_contrib.common.maskable.utils import is_masking_supported
 from stable_baselines3.common.callbacks import (
@@ -721,8 +721,13 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param dk: FreqaiDatakitchen = data kitchen for the current pair
         :param model: Any = the trained model used to inference the features.
         """
+        add_state_info = self.rl_config.get("add_state_info", False)
         virtual_position: Positions = Positions.Neutral
         virtual_trade_duration: int = 0
+        if add_state_info and self.live:
+            position, _, trade_duration = self.get_state_info(dk.pair)
+            virtual_position = ReforceXY._normalize_position(position)
+            virtual_trade_duration = trade_duration
         np_dataframe: NDArray[np.float32] = dataframe.to_numpy(
             dtype=np.float32, copy=False
         )
@@ -731,7 +736,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         frame_stacking = self.frame_stacking
         frame_stacking_activated = bool(frame_stacking) and frame_stacking > 1
         inference_masking = self.action_masking and self.inference_masking
-        add_state_info = self.rl_config.get("add_state_info", False)
 
         def _update_virtual_position(action: int, position: Positions) -> Positions:
             if action == Actions.Long_enter.value and position == Positions.Neutral:
@@ -745,7 +749,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             return position
 
         def _update_virtual_trade_duration(
-            action: int,
             virtual_position: Positions,
             previous_virtual_position: Positions,
             current_virtual_trade_duration: int,
@@ -829,7 +832,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             previous_virtual_position = virtual_position
             virtual_position = _update_virtual_position(action, virtual_position)
             virtual_trade_duration = _update_virtual_trade_duration(
-                action,
                 virtual_position,
                 previous_virtual_position,
                 virtual_trade_duration,
@@ -1616,29 +1618,28 @@ class MyRLEnv(Base5ActionRLEnv):
         This may or may not be independent of action types, user can inherit
         this in their custom "MyRLEnv"
         """
-        start_idx = max(0, self._current_tick - self.window_size)
-        end_idx = self._current_tick
+        start_idx = max(self._start_tick, self._current_tick - self.window_size)
+        end_idx = min(self._current_tick, len(self.signal_features))
         features_window = self.signal_features.iloc[start_idx:end_idx]
-        if len(features_window) < self.window_size:
-            pad_size = self.window_size - len(features_window)
-            pad_df = DataFrame(
-                np.zeros((pad_size, features_window.shape[1]), dtype=np.float32),
-                columns=features_window.columns,
+        features_window_array = features_window.to_numpy(dtype=np.float32, copy=False)
+        if features_window_array.shape[0] < self.window_size:
+            pad_size = self.window_size - features_window_array.shape[0]
+            pad_array = np.zeros(
+                (pad_size, features_window_array.shape[1]), dtype=np.float32
             )
-            features_window = concat(
-                [pad_df, features_window], axis=0, ignore_index=True
+            features_window_array = np.concatenate(
+                [pad_array, features_window_array], axis=0
             )
-        features_window_array = features_window.to_numpy(dtype=np.float32)
         if self.add_state_info:
-            return np.concatenate(
+            observations = np.concatenate(
                 [
                     features_window_array,
                     np.tile(
                         np.array(
                             [
-                                self.get_unrealized_profit(),
-                                self._position.value,
-                                self.get_trade_duration(),
+                                float(self.get_unrealized_profit()),
+                                float(self._position.value),
+                                float(self.get_trade_duration()),
                             ],
                             dtype=np.float32,
                         ),
@@ -1648,7 +1649,9 @@ class MyRLEnv(Base5ActionRLEnv):
                 axis=1,
             )
         else:
-            return features_window_array
+            observations = features_window_array
+
+        return np.ascontiguousarray(observations)
 
     def _get_force_action(self) -> Optional[ForceActions]:
         if not self.force_actions or self._position == Positions.Neutral:
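
The second hunk reworks observation assembly in MyRLEnv._get_observation(): the pandas-based padding (a zero pad_df plus concat) is replaced by NumPy-only padding and concatenation, the optional state scalars are cast to float32, and the result is returned as a contiguous array. The following sketch reproduces that assembly under simplified assumptions: signal_features is passed as a NumPy array rather than a DataFrame, and build_observation and the state triple are illustrative names, not the actual environment API.

import numpy as np
from typing import Optional, Tuple


def build_observation(
    signal_features: np.ndarray,
    start_tick: int,
    current_tick: int,
    window_size: int,
    state: Optional[Tuple[float, float, float]] = None,
) -> np.ndarray:
    """Assemble a fixed-size observation window, NumPy end to end."""
    start_idx = max(start_tick, current_tick - window_size)
    end_idx = min(current_tick, len(signal_features))
    window = np.asarray(signal_features[start_idx:end_idx], dtype=np.float32)
    if window.shape[0] < window_size:
        # Left-pad with zero rows so the window always spans window_size ticks.
        pad = np.zeros((window_size - window.shape[0], window.shape[1]), dtype=np.float32)
        window = np.concatenate([pad, window], axis=0)
    if state is not None:
        # Tile the (unrealized_profit, position_value, trade_duration) scalars
        # across every row and append them as extra feature columns.
        state_cols = np.tile(np.array(state, dtype=np.float32), (window.shape[0], 1))
        window = np.concatenate([window, state_cols], axis=1)
    # Match the diff's final step: return a C-contiguous float32 array.
    return np.ascontiguousarray(window)


# Example: only 5 feature rows available, window of 8, state info appended.
features = np.random.rand(5, 4).astype(np.float32)
obs = build_observation(features, start_tick=0, current_tick=5, window_size=8,
                        state=(0.01, 1.0, 12.0))
print(obs.shape)  # (8, 7)

np.ascontiguousarray is a no-op when the concatenation already yields a contiguous buffer, and otherwise copies once so downstream consumers see a contiguous float32 block.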