From 29a09e54555efa5d7bb3c27c575453a2f6e38b56 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 15 Sep 2025 02:43:44 +0200 Subject: [PATCH] refactor(refactor): properly shape the observation window up front MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index b9efa73..838ae51 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -539,6 +539,14 @@ class ReforceXY(BaseReinforcementLearningModel): action_masks_param = {"action_masks": _action_masks(position)} np_observation = observation.to_numpy(dtype=np.float32) + shape = getattr(self, "shape", None) + if shape and np_observation.shape != shape: + logger.error( + "Frame shape mismatch: got %s expected %s", + np_observation.shape, + shape, + ) + raise ValueError("Frame shape mismatch") frame_stacking = self.frame_stacking if frame_stacking and frame_stacking > 1: @@ -553,15 +561,10 @@ class ReforceXY(BaseReinforcementLearningModel): fb_padded = [fb[0]] * pad_needed + fb else: fb_padded = fb - stacked_observations = np.concatenate(fb_padded, axis=0) - observations = stacked_observations.flatten() - else: - observations = np_observation.flatten() - - if observations.ndim == 1: - observations = observations.reshape(1, -1) + stacked_observations = np.stack(fb_padded, axis=0) + observations = stacked_observations.reshape(1, -1) else: - observations = observations + observations = np_observation.reshape(1, -1) action, _ = model.predict( observations, deterministic=True, **action_masks_param -- 2.43.0