]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(refactor): shape at first properly observations window
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 15 Sep 2025 00:43:44 +0000 (02:43 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 15 Sep 2025 00:44:49 +0000 (02:44 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index b9efa731bd36b8d31fff6df4cbc3a47818207cfb..838ae51e2fee0524f43f7aafb1ec6ba369ae4412 100644 (file)
@@ -539,6 +539,14 @@ class ReforceXY(BaseReinforcementLearningModel):
                     action_masks_param = {"action_masks": _action_masks(position)}
 
             np_observation = observation.to_numpy(dtype=np.float32)
+            shape = getattr(self, "shape", None)
+            if shape and np_observation.shape != shape:
+                logger.error(
+                    "Frame shape mismatch: got %s expected %s",
+                    np_observation.shape,
+                    shape,
+                )
+                raise ValueError("Frame shape mismatch")
 
             frame_stacking = self.frame_stacking
             if frame_stacking and frame_stacking > 1:
@@ -553,15 +561,10 @@ class ReforceXY(BaseReinforcementLearningModel):
                     fb_padded = [fb[0]] * pad_needed + fb
                 else:
                     fb_padded = fb
-                stacked_observations = np.concatenate(fb_padded, axis=0)
-                observations = stacked_observations.flatten()
-            else:
-                observations = np_observation.flatten()
-
-            if observations.ndim == 1:
-                observations = observations.reshape(1, -1)
+                stacked_observations = np.stack(fb_padded, axis=0)
+                observations = stacked_observations.reshape(1, -1)
             else:
-                observations = observations
+                observations = np_observation.reshape(1, -1)
 
             action, _ = model.predict(
                 observations, deterministic=True, **action_masks_param