From: Jérôme Benoit Date: Thu, 2 Oct 2025 09:05:28 +0000 (+0200) Subject: refactor(reforcexy): cleanup prediction code path X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=9849bb19d59406dc588ae72c37d66608e42d7470;p=freqai-strategies.git refactor(reforcexy): cleanup prediction code path Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index ea3be6a..903f720 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -287,7 +287,7 @@ class ReforceXY(BaseReinforcementLearningModel): tensorboard_throttle, ) self.rl_config["tensorboard_throttle"] = 1 - if self.continual_learning and self.frame_stacking: + if self.continual_learning and bool(self.frame_stacking): logger.warning( "User tried to use continual_learning with frame_stacking=%s. " "Deactivating continual_learning", @@ -728,6 +728,9 @@ class ReforceXY(BaseReinforcementLearningModel): ) n = int(np_dataframe.shape[0]) window_length = int(self.CONV_WIDTH) + frame_stacking = self.frame_stacking + frame_stacking_activated = bool(frame_stacking) and frame_stacking > 1 + inference_masking = self.action_masking and self.inference_masking add_state_info = self.rl_config.get("add_state_info", False) def _update_virtual_position(action: int, position: Positions) -> Positions: @@ -789,8 +792,7 @@ class ReforceXY(BaseReinforcementLearningModel): np_observation = np.concatenate([np_observation, state_block], axis=1) fb: List[NDArray[np.float32]] = frame_buffer - frame_stacking = self.frame_stacking - if frame_stacking and frame_stacking > 1: + if frame_stacking_activated: fb.append(np_observation) if len(fb) > frame_stacking: del fb[0 : len(fb) - frame_stacking] @@ -810,7 +812,7 @@ class ReforceXY(BaseReinforcementLearningModel): 1, np_observation.shape[0], np_observation.shape[1] ) - if self.action_masking and self.inference_masking: + if inference_masking: action_masks_param["action_masks"] = ReforceXY.get_action_masks( self.can_short, virtual_position ) @@ -1118,7 +1120,7 @@ class ReforceXY(BaseReinforcementLearningModel): else: eval_env = DummyVecEnv(eval_fns) - if self.frame_stacking: + if bool(self.frame_stacking): train_env = VecFrameStack(train_env, n_stack=self.frame_stacking) eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)