From: Jérôme Benoit Date: Mon, 15 Sep 2025 00:43:44 +0000 (+0200) Subject: refactor(refactor): shape at first properly observations window X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=29a09e54555efa5d7bb3c27c575453a2f6e38b56;p=freqai-strategies.git refactor(refactor): shape at first properly observations window Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index b9efa73..838ae51 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -539,6 +539,14 @@ class ReforceXY(BaseReinforcementLearningModel): action_masks_param = {"action_masks": _action_masks(position)} np_observation = observation.to_numpy(dtype=np.float32) + shape = getattr(self, "shape", None) + if shape and np_observation.shape != shape: + logger.error( + "Frame shape mismatch: got %s expected %s", + np_observation.shape, + shape, + ) + raise ValueError("Frame shape mismatch") frame_stacking = self.frame_stacking if frame_stacking and frame_stacking > 1: @@ -553,15 +561,10 @@ class ReforceXY(BaseReinforcementLearningModel): fb_padded = [fb[0]] * pad_needed + fb else: fb_padded = fb - stacked_observations = np.concatenate(fb_padded, axis=0) - observations = stacked_observations.flatten() - else: - observations = np_observation.flatten() - - if observations.ndim == 1: - observations = observations.reshape(1, -1) + stacked_observations = np.stack(fb_padded, axis=0) + observations = stacked_observations.reshape(1, -1) else: - observations = observations + observations = np_observation.reshape(1, -1) action, _ = model.predict( observations, deterministic=True, **action_masks_param