]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): cleanup prediction code path
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 2 Oct 2025 09:05:28 +0000 (11:05 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 2 Oct 2025 09:05:28 +0000 (11:05 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index ea3be6a9ecdb57f528f5a0330cc7a69cba798967..903f7200a323e11bbd952d6d7df037f1093c30c8 100644 (file)
@@ -287,7 +287,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 tensorboard_throttle,
             )
             self.rl_config["tensorboard_throttle"] = 1
-        if self.continual_learning and self.frame_stacking:
+        if self.continual_learning and bool(self.frame_stacking):
             logger.warning(
                 "User tried to use continual_learning with frame_stacking=%s. "
                 "Deactivating continual_learning",
@@ -728,6 +728,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         )
         n = int(np_dataframe.shape[0])
         window_length = int(self.CONV_WIDTH)
+        frame_stacking = self.frame_stacking
+        frame_stacking_activated = bool(frame_stacking) and frame_stacking > 1
+        inference_masking = self.action_masking and self.inference_masking
         add_state_info = self.rl_config.get("add_state_info", False)
 
         def _update_virtual_position(action: int, position: Positions) -> Positions:
@@ -789,8 +792,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 np_observation = np.concatenate([np_observation, state_block], axis=1)
 
             fb: List[NDArray[np.float32]] = frame_buffer
-            frame_stacking = self.frame_stacking
-            if frame_stacking and frame_stacking > 1:
+            if frame_stacking_activated:
                 fb.append(np_observation)
                 if len(fb) > frame_stacking:
                     del fb[0 : len(fb) - frame_stacking]
@@ -810,7 +812,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                     1, np_observation.shape[0], np_observation.shape[1]
                 )
 
-            if self.action_masking and self.inference_masking:
+            if inference_masking:
                 action_masks_param["action_masks"] = ReforceXY.get_action_masks(
                     self.can_short, virtual_position
                 )
@@ -1118,7 +1120,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         else:
             eval_env = DummyVecEnv(eval_fns)
 
-        if self.frame_stacking:
+        if bool(self.frame_stacking):
             train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
             eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)