Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): cleanup logging
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 16 Sep 2025 00:12:27 +0000 (02:12 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 16 Sep 2025 00:12:27 +0000 (02:12 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 6289765869126bbb95f925a55e3d1d1415dd7f04..bc2d93d07da71505912f50b8946ceece36846ab4 100644 (file)
@@ -146,9 +146,23 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def unset_unsupported(self) -> None:
         """
-        If user has activated any custom function that may conflict, this
-        function will set them to false and warn them
+        If user has activated any features that may conflict, this
+        function will set them to proper values and warn them
         """
+        if not isinstance(self.n_envs, int) or self.n_envs < 1:
+            logger.warning("Invalid n_envs=%s. Forcing n_envs=1", self.n_envs)
+            self.n_envs = 1
+        if not isinstance(self.frame_stacking, int) or self.frame_stacking < 0:
+            logger.warning(
+                "Invalid frame_stacking=%s. Forcing frame_stacking=0",
+                self.frame_stacking,
+            )
+            self.frame_stacking = 0
+        if self.frame_stacking == 1:
+            logger.warning(
+                "Setting frame_stacking=%s is equivalent to no stacking; use >=2 or 0",
+                self.frame_stacking,
+            )
         if self.continual_learning and self.frame_stacking:
             logger.warning(
                 "User tried to use continual_learning with frame_stacking. \
@@ -168,25 +182,14 @@ class ReforceXY(BaseReinforcementLearningModel):
         """
         self.close_envs()
 
-        if not isinstance(self.n_envs, int) or self.n_envs < 1:
-            logger.warning("Invalid n_envs=%s. Forcing n_envs=1", self.n_envs)
-            self.n_envs = 1
-        if not isinstance(self.frame_stacking, int) or self.frame_stacking < 0:
-            logger.warning(
-                "Invalid frame_stacking=%s. Forcing frame_stacking=0",
-                self.frame_stacking,
-            )
-            self.frame_stacking = 0
-
         train_df = data_dictionary.get("train_features")
         test_df = data_dictionary.get("test_features")
         env_dict = self.pack_env_dict(dk.pair)
         seed = self.get_model_params().get("seed", 42)
         set_random_seed(seed)
-        logger.info("Seeding RNGs with seed=%s (train), %s (eval)", seed, seed + 10_000)
 
         if self.check_envs:
-            logger.info("Checking environments...")
+            logger.info("Checking environments")
             _train_env_check = self.MyRLEnv(
                 id="train_env_check", df=train_df, prices=prices_train, **env_dict
             )
@@ -231,23 +234,11 @@ class ReforceXY(BaseReinforcementLearningModel):
                 for i in range(self.n_envs)
             ]
         )
-        if self.frame_stacking == 1:
-            logger.warning(
-                "frame_stacking=%s is equivalent to no stacking; use >=2 or 0",
-                self.frame_stacking,
-            )
+
         if self.frame_stacking:
-            logger.info(
-                "Observation space shape pre-stacking: %s",
-                train_env.observation_space.shape,
-            )
             logger.info("Frame stacking: %s", self.frame_stacking)
             train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
             eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)
-            logger.info(
-                "Observation space shape post-stacking: %s",
-                train_env.observation_space.shape,
-            )
 
         self.train_env = VecMonitor(train_env)
         self.eval_env = VecMonitor(eval_env)
@@ -264,7 +255,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         if model_params.get("seed") is None:
             model_params["seed"] = 42
 
-        if self.lr_schedule:
+        if not self.hyperopt and self.lr_schedule:
             lr = model_params.get("learning_rate", 0.0003)
             if isinstance(lr, (int, float)):
                 lr = float(lr)
@@ -273,7 +264,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                     "Learning rate linear schedule enabled, initial value: %s", lr
                 )
 
-        if "PPO" in self.model_type and self.cr_schedule:
+        if not self.hyperopt and "PPO" in self.model_type and self.cr_schedule:
             cr = model_params.get("clip_range", 0.2)
             if isinstance(cr, (int, float)):
                 cr = float(cr)
@@ -427,11 +418,18 @@ class ReforceXY(BaseReinforcementLearningModel):
         test_days = steps_to_days(test_timesteps, self.config.get("timeframe"))
         total_days = steps_to_days(total_timesteps, self.config.get("timeframe"))
 
-        logger.info("Action masking: %s", self.is_maskable)
-        logger.info("Train: %s steps (%s days)", train_timesteps, train_days)
-        logger.info("Train cycles: %s", train_cycles)
-        logger.info("Train total: %s steps (%s days)", total_timesteps, total_days)
+        logger.info("Model: %s", self.model_type)
+        logger.info(
+            "Train: %s steps (%s days), %s cycles, %s envs -> total %s steps (%s days)",
+            train_timesteps,
+            train_days,
+            train_cycles,
+            self.n_envs,
+            total_timesteps,
+            total_days,
+        )
         logger.info("Test: %s steps (%s days)", test_timesteps, test_days)
+        logger.info("Action masking: %s", self.is_maskable)
         logger.info("Hyperopt: %s", self.hyperopt)
 
         start_time = time.time()