Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): proper frame stacking implementation
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 15 Sep 2025 00:09:01 +0000 (02:09 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 15 Sep 2025 00:09:01 +0000 (02:09 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 5bb3080ff37ecb5d9e3f0c04e0981f64ffcbe1b0..a6deefce0ad0b5a8e220f5f5e3584ae67a5ba8fa 100644 (file)
@@ -514,15 +514,22 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             return True
 
-        def _action_masks(position: float):
+        def _action_masks(position: float) -> list[bool]:
             return [_is_valid(action.value, position) for action in Actions]
 
-        def _predict(window):
+        def _predict(window) -> int:
             observation: DataFrame = dataframe.iloc[window.index]
             action_masks_param: Dict[str, Any] = {}
 
-            if self.live and self.rl_config.get("add_state_info", False):
-                position, pnl, trade_duration = self.get_state_info(dk.pair)
+            if self.rl_config.get("add_state_info", False):
+                if self.live:
+                    position, pnl, trade_duration = self.get_state_info(dk.pair)
+                else:
+                    # TODO: state info handling for backtests
+                    position = Positions.Neutral
+                    pnl = 0.0
+                    trade_duration = 0
+
                 # STATE_INFO
                 observation["pnl"] = pnl
                 observation["position"] = position
@@ -531,23 +538,33 @@ class ReforceXY(BaseReinforcementLearningModel):
                 if self.is_maskable:
                     action_masks_param = {"action_masks": _action_masks(position)}
 
-            observation = observation.to_numpy(dtype=np.float32)
-
-            if self.frame_stacking:
-                # FIXME: proper observation stacking needs more work
-                observations = np.repeat(
-                    observation, axis=1, repeats=self.frame_stacking
-                )
+            np_observation = observation.to_numpy(dtype=np.float32)
+
+            frame_stacking = self.frame_stacking
+            if frame_stacking and frame_stacking > 1:
+                if not hasattr(_predict, "_frame_buffer"):
+                    _predict._frame_buffer = []
+                fb: list[np.ndarray] = getattr(_predict, "_frame_buffer")
+                fb.append(np_observation)
+                if len(fb) > frame_stacking:
+                    del fb[0 : len(fb) - frame_stacking]
+                if len(fb) < frame_stacking:
+                    pad_needed = frame_stacking - len(fb)
+                    fb_padded = [fb[0]] * pad_needed + fb
+                else:
+                    fb_padded = fb
+                stacked_observations = np.concatenate(fb_padded, axis=0)
+                observations = stacked_observations.flatten()
             else:
-                observations = observation
+                observations = np_observation.flatten()
 
             action, _ = model.predict(
                 observations, deterministic=True, **action_masks_param
             )
-            return action
+            return int(action)
 
-        output = DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
-        output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
+        actions = dataframe.iloc[:, 0].rolling(window=self.CONV_WIDTH).apply(_predict)
+        output = DataFrame({label: actions for label in dk.label_list})
         return output
 
     def get_storage(self, pair: Optional[str] = None) -> BaseStorage: