From 9849bb19d59406dc588ae72c37d66608e42d7470 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 2 Oct 2025 11:05:28 +0200 Subject: [PATCH] refactor(reforcexy): cleanup prediction code path MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index ea3be6a..903f720 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -287,7 +287,7 @@ class ReforceXY(BaseReinforcementLearningModel): tensorboard_throttle, ) self.rl_config["tensorboard_throttle"] = 1 - if self.continual_learning and self.frame_stacking: + if self.continual_learning and bool(self.frame_stacking): logger.warning( "User tried to use continual_learning with frame_stacking=%s. " "Deactivating continual_learning", @@ -728,6 +728,9 @@ class ReforceXY(BaseReinforcementLearningModel): ) n = int(np_dataframe.shape[0]) window_length = int(self.CONV_WIDTH) + frame_stacking = self.frame_stacking + frame_stacking_activated = bool(frame_stacking) and frame_stacking > 1 + inference_masking = self.action_masking and self.inference_masking add_state_info = self.rl_config.get("add_state_info", False) def _update_virtual_position(action: int, position: Positions) -> Positions: @@ -789,8 +792,7 @@ class ReforceXY(BaseReinforcementLearningModel): np_observation = np.concatenate([np_observation, state_block], axis=1) fb: List[NDArray[np.float32]] = frame_buffer - frame_stacking = self.frame_stacking - if frame_stacking and frame_stacking > 1: + if frame_stacking_activated: fb.append(np_observation) if len(fb) > frame_stacking: del fb[0 : len(fb) - frame_stacking] @@ -810,7 +812,7 @@ class ReforceXY(BaseReinforcementLearningModel): 1, np_observation.shape[0], np_observation.shape[1] ) - if self.action_masking and self.inference_masking: + if inference_masking: action_masks_param["action_masks"] = ReforceXY.get_action_masks( self.can_short, virtual_position ) @@ -1118,7 +1120,7 @@ class ReforceXY(BaseReinforcementLearningModel): else: eval_env = DummyVecEnv(eval_fns) - if self.frame_stacking: + if bool(self.frame_stacking): train_env = VecFrameStack(train_env, n_stack=self.frame_stacking) eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking) -- 2.43.0