From c019c1635623882a17fae831e3285054b7de0200 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 15 Sep 2025 02:15:24 +0200 Subject: [PATCH] fix(reforcexy): fix stacked observations shape MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index a6deefc..b9efa73 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -558,6 +558,11 @@ class ReforceXY(BaseReinforcementLearningModel): else: observations = np_observation.flatten() + if observations.ndim == 1: + observations = observations.reshape(1, -1) + else: + observations = observations + action, _ = model.predict( observations, deterministic=True, **action_masks_param ) -- 2.43.0